import { useMemo } from "react";

import { CartesianGrid, Label, Legend, Line, LineChart, ResponsiveContainer, Tooltip, XAxis, YAxis } from "recharts";
import { max, min, sampleVariance } from "simple-statistics";
import stc from "string-to-color";

import { formatValueToNumericString } from "../../../../utils/numbers";
import { capitalize } from "../../../../utils/strings";

// TODO: Generalize this component further to support GRPO?

/**
 * Builds a display label for the x-axis from a data key name.
 *
 * The name is split on underscores, each piece is passed through `capitalize`, and the pieces are re-joined with
 * spaces; leading/trailing whitespace of the result is trimmed.
 *
 * @param xAxisName - Name of the data field used for the x-axis.
 * @returns The formatted label, or `""` when `xAxisName` is empty/falsy.
 */
const formatXAxisLabel = (xAxisName: string) => {
    if (!xAxisName) {
        return "";
    }

    const label = xAxisName
        .toString()
        .split("_")
        .map((segment) => capitalize(segment))
        .join(" ");

    return label.trim();
};

/**
 * Computes a padded y-axis domain covering every finite numeric value of the selected metrics in `chartData`.
 *
 * Padding on each side is `sampleVariance(values) * 0.1`. The lower bound is floored at 0 when all values are
 * non-negative (the expected case for current metrics); if any value is negative no floor is applied, so data is
 * never clipped off the chart.
 *
 * @param chartData - Rows to scan; `undefined` is treated as no data.
 * @param selectedMetrics - Keys of each row whose values should fit inside the domain.
 * @returns `[low, high]`, or `undefined` when fewer than two finite values exist (recharts then picks the domain).
 */
function computeYDomain<T>(chartData: T[] | undefined, selectedMetrics: (keyof T)[]): [number, number] | undefined {
    // Only finite numbers contribute: `null`, `NaN` or non-numeric entries would otherwise be coerced inside
    // min/max/sampleVariance and skew (or NaN-poison) the domain.
    const values: number[] = [];
    for (const metric of selectedMetrics) {
        for (const row of chartData ?? []) {
            const value: unknown = row[metric];
            if (typeof value === "number" && Number.isFinite(value)) {
                values.push(value);
            }
        }
    }

    // sampleVariance needs at least two points; with fewer, let recharts set the domain:
    if (values.length < 2) {
        return undefined;
    }

    // NOTE(review): variance has squared units, so this padding is not scale-invariant (negligible for small-valued
    // metrics, very large for large-valued ones). The standard deviation may have been intended — confirm.
    const varMultiplier = 0.1;
    const padding = sampleVariance(values) * varMultiplier;
    const minY = min(values);
    const maxY = max(values);

    // Current metrics are expected to be non-negative, so floor the axis at 0 in that case; if negative values do
    // appear, skip the floor instead of hiding them below the axis.
    const low = minY >= 0 ? Math.max(0, minY - padding) : minY - padding;
    const high = maxY + padding;

    return [low, high];
}

/**
 * Line chart of the selected metrics plotted against the `xAxisName` field of each row.
 *
 * Renders nothing when `chartData` is undefined or empty. One line is drawn per selected metric, except metrics
 * whose value in the last row is `undefined` (assumed to be missing for the whole series). The y-axis domain is
 * computed by `computeYDomain`.
 *
 * @param props.selectedMetrics - Row keys to draw as lines.
 * @param props.xAxisName - Row key used for the x-axis values and label.
 * @param props.chartData - Rows to plot.
 */
const SFTLineChart = <T,>(props: { selectedMetrics: (keyof T)[]; xAxisName: keyof T; chartData: T[] | undefined }) => {
    // Parent state:
    const { selectedMetrics, chartData, xAxisName } = props;

    // Called before the early return below so the hook order is identical on every render.
    const yDomain = useMemo(() => computeYDomain(chartData, selectedMetrics), [chartData, selectedMetrics]);

    if (chartData === undefined || chartData.length === 0) {
        return null;
    }

    // Width below 100% works around a ResponsiveContainer resizing issue:
    // https://github.com/recharts/recharts/issues/172#issuecomment-307858843
    // NOTE(review): `overflow: "none"` is not a valid CSS overflow value, so browsers ignore it; `"visible"` was
    // presumably intended (e.g. for the negatively offset legend) — confirm before changing.
    return (
        <ResponsiveContainer width={"99.8%"} aspect={2}>
            <LineChart data={chartData} style={{ overflow: "none" }}>
                <XAxis dataKey={String(xAxisName)}>
                    <Label value={formatXAxisLabel(String(xAxisName))} position={"bottom"} />
                </XAxis>
                <YAxis tickFormatter={(val) => formatValueToNumericString(val, 2)} domain={yDomain} />
                {selectedMetrics.map((name) => {
                    // If the data for the selected metric is undefined, don't show the line (note that we assume it
                    //  would then be undefined for the whole data series):
                    const latestValue = chartData.at(-1)?.[name];
                    if (latestValue === undefined) {
                        return null;
                    }

                    const metricName = String(name);
                    return (
                        <Line
                            key={`line-${metricName}`}
                            type="monotone"
                            dataKey={metricName}
                            name={metricName}
                            // TODO: Generate better colors?
                            stroke={stc(metricName)}
                            strokeWidth={2}
                            dot={false}
                        />
                    );
                })}
                <Tooltip />
                <Legend wrapperStyle={{ bottom: -20, paddingBottom: "20px" }} />
                <CartesianGrid stroke="#eee" strokeDasharray="5 5" />
            </LineChart>
        </ResponsiveContainer>
    );
};

export default SFTLineChart;
