import { Dispatch, SetStateAction, Suspense, useEffect, useMemo, useRef, useState } from "react";

import { useQueryClient } from "@tanstack/react-query";
import { InputNumber, Slider } from "antd";
import chroma from "chroma-js";
import distinctColors from "distinct-colors";
import {
    CartesianGrid,
    Legend,
    LegendProps,
    Line,
    LineChart,
    ResponsiveContainer,
    Tooltip,
    TooltipProps,
    XAxis,
    YAxis,
} from "recharts";
import { NameType, ValueType } from "recharts/types/component/DefaultTooltipContent";
import { Button, Divider, Header, Icon, Loader, Modal, Popup, Segment } from "semantic-ui-react";

import { adapterVersion, finetuningJob, grpoJobMetrics, repo } from "@/autogen/openapi";

import { getDocsHome } from "../../../../utils/api";
import { formatValueToNumericString } from "../../../../utils/numbers";
import { isTerminalJobStatus } from "../../../misc/utils";
import { GET_ADAPTER_VERSION_QUERY_KEY, GET_FINETUNING_JOB_QUERY_KEY } from "../../../query";
import {
    ChartData,
    DEFAULT_EMA_FACTOR,
    EMA_PREFIX,
    RewardChartGroup,
    VALIDATION_PREFIX,
    getRewardChartData,
    useStreamingMetrics,
} from "../metrics/grpo_util";
import { lineChartFormatting } from "./TurboAccuracyCurves";

import "../../../../prompt/Sliders.css";

const EMA_SLIDER_MIN = 0;
const EMA_SLIDER_MAX = 1;
const EMA_SLIDER_STEP = 0.1;
// Thickness of the slider rail/track: 6px expressed in rem, assuming the 14px base font size the px→rem comments
// elsewhere in this file use. (Previously written as `{6/14}rem`, which is not interpolated and is invalid CSS.)
const EMA_SLIDER_BAR_HEIGHT = `${6 / 14}rem`;
// TODO: Let's get rid of antd if we can!
/**
 * Controls for the exponential-smoothing (EMA) factor applied to the reward charts.
 *
 * Renders a slider and a numeric input, both bound to `emaFactor` and limited to
 * [EMA_SLIDER_MIN, EMA_SLIDER_MAX] in increments of EMA_SLIDER_STEP. Clearing the numeric input resets the
 * factor to DEFAULT_EMA_FACTOR.
 */
export const EMASlider = (props: { emaFactor: number; setEmaFactor: Dispatch<SetStateAction<number>> }) => {
    const { emaFactor, setEmaFactor } = props;

    return (
        <div>
            {/* A div rather than a label: labels may not contain other labels. */}
            <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center", marginBottom: "0" }}>
                <label style={{ marginRight: "1rem" }}>
                    <Popup
                        className="transition-scale"
                        hoverable
                        wide={"very"}
                        position={"right center"}
                        trigger={<b>Smoothing</b>}
                        content={
                            <span>
                                Exponential smoothing is a common technique for smoothing time series data by
                                exponentially decaying the weight of previous points. The range is 0 to 1.
                            </span>
                        }
                    />
                </label>
                <div style={{ minWidth: "10rem", marginRight: "1rem" }}>
                    <Slider
                        value={emaFactor}
                        onChange={(val) => setEmaFactor(val)}
                        styles={{
                            rail: { height: EMA_SLIDER_BAR_HEIGHT },
                            track: { backgroundColor: "black", height: EMA_SLIDER_BAR_HEIGHT },
                        }}
                        step={EMA_SLIDER_STEP}
                        min={EMA_SLIDER_MIN}
                        max={EMA_SLIDER_MAX}
                    />
                </div>
                <span>
                    <InputNumber
                        min={EMA_SLIDER_MIN}
                        max={EMA_SLIDER_MAX}
                        step={EMA_SLIDER_STEP}
                        value={emaFactor}
                        // An emptied input reports null; fall back to the default smoothing factor.
                        onChange={(val) => setEmaFactor(val?.valueOf() ?? DEFAULT_EMA_FACTOR)}
                    />
                </span>
            </div>
        </div>
    );
};

/**
 * Turns an internal metric key into a human-readable label for the legend and tooltip.
 *
 * Keys starting with EMA_PREFIX get an "AVG " prefix, and keys containing VALIDATION_PREFIX get a " (Validation)"
 * suffix. The first occurrence of each marker is stripped from the displayed name.
 */
const formatMetricForLegendTooltip = (metricName: string) => {
    const label = metricName.startsWith(EMA_PREFIX) ? "AVG " : "";
    const qualifier = metricName.includes(VALIDATION_PREFIX) ? " (Validation)" : "";
    const bareName = [EMA_PREFIX, VALIDATION_PREFIX].reduce((name, marker) => name.replace(marker, ""), metricName);
    return label + bareName + qualifier;
};

/**
 * Legend rendered as a vertical list of clickable entries; clicking an entry toggles that series' visibility.
 *
 * `payload` is supplied by recharts when this element is passed as `<Legend content={…} />`. `visibleSeries` maps
 * series dataKeys to whether they are shown; entries whose key is missing or false are drawn dimmed. If the list is
 * taller than 30% of the nearest ancestor with the `segment` class, its height is capped and it scrolls.
 */
const CustomLegend = (
    props: LegendProps & {
        visibleSeries: Record<string, boolean>;
        setVisibleSeries: Dispatch<SetStateAction<Record<string, boolean>>>;
    },
) => {
    const { payload, visibleSeries, setVisibleSeries } = props;
    const containerRef = useRef<HTMLDivElement>(null);
    const [applyMaxHeight, setApplyMaxHeight] = useState(false);

    useEffect(() => {
        const container = containerRef.current;
        if (!container) return;

        // Walk up to the nearest ancestor carrying the `segment` class; if there is none, leave the height uncapped.
        let segmentElement = container.parentElement;
        while (segmentElement && !segmentElement.classList.contains("segment")) {
            segmentElement = segmentElement.parentElement;
        }
        if (!segmentElement) return;

        setApplyMaxHeight(container.scrollHeight > segmentElement.clientHeight * 0.3);
    }, [payload]);

    return (
        <div
            ref={containerRef}
            style={{
                display: "flex",
                flexDirection: "column",
                alignItems: "flex-start",
                marginTop: "1.8rem" /* Increased margin to prevent overlap with x-axis label (25px ÷ 14px = 1.8rem) */,
                padding: "0 2.1rem" /* 30px ÷ 14px = 2.1rem */,
                ...(applyMaxHeight ? { maxHeight: "7.1rem", overflowY: "auto" } : {}) /* 100px ÷ 14px = 7.1rem */,
            }}
        >
            {payload?.map((entry) => {
                // Only string dataKeys identify a series in `visibleSeries`.
                const metricName = typeof entry.dataKey === "string" ? entry.dataKey : undefined;
                if (!metricName) return null;
                const isActive = visibleSeries[metricName];

                return (
                    <div
                        key={metricName}
                        onClick={() => {
                            // Functional update: merge into the latest state instead of a possibly stale prop snapshot.
                            setVisibleSeries((prev) => ({ ...prev, [metricName]: !isActive }));
                        }}
                        style={{
                            marginBottom: "8px",
                            display: "flex",
                            alignItems: "center",
                            cursor: "pointer",
                            opacity: isActive ? 1 : 0.5,
                        }}
                    >
                        <div
                            style={{
                                width: "12px",
                                height: "12px",
                                backgroundColor: entry.color,
                                marginRight: "8px",
                            }}
                        />
                        <span>{formatMetricForLegendTooltip(metricName)}</span>
                    </div>
                );
            })}
        </div>
    );
};

// Custom tooltip listing the value (3 decimal places) of every series currently toggled on in the legend,
// raw and EMA alike. Series that are hidden or have no numeric value at the hovered step are skipped.
// ? NOTE: The type of `payload[0].payload` is `ChartDataPoint`. Useful to know, but not applicable at the moment.
const CustomTooltip = (props: TooltipProps<ValueType, NameType> & { visibleSeries: Record<string, boolean> }) => {
    const { active, payload, label, visibleSeries } = props;

    if (!active || !payload || !payload.length) return null;

    return (
        <div
            style={{
                backgroundColor: "white",
                padding: "8px",
                border: "1px solid #e0e0e0",
                boxShadow: "0 2px 4px rgba(0,0,0,0.1)",
                borderRadius: "3px",
            }}
        >
            <p style={{ color: "#333", margin: "0 0 5px 0" }}>{`Step: ${label}`}</p>
            {payload.map((entry) => {
                const metricName = typeof entry.dataKey === "string" ? entry.dataKey : undefined;
                const { value } = entry;
                // Skip hidden series and steps where this series has no numeric value (the lines use connectNulls,
                // so gaps are expected). Previously a missing value rendered as "undefined", and a non-number threw.
                if (!metricName || !visibleSeries[metricName] || typeof value !== "number") return null;

                return (
                    <p
                        key={metricName}
                        style={{
                            color: entry.color,
                            fontWeight: 500,
                            margin: "3px 0",
                        }}
                    >
                        {`${formatMetricForLegendTooltip(metricName)}: ${value.toFixed(3)}`}
                    </p>
                );
            })}
        </div>
    );
};

/**
 * Line chart of `selectedMetrics` (dataKeys into `chartData`, plotted against `step`), with a clickable legend
 * that toggles individual series and a tooltip for the visible ones.
 */
const Chart = (props: { selectedMetrics: string[]; chartData: ChartData }) => {
    // Parent state:
    const { selectedMetrics, chartData } = props;

    // NOTE: Messy, but for now we'll hard-code colors for EMA lines so that they are neutral and visible.
    const rawEMAColor = chroma("#000000");
    const validationEMAColor = chroma("#888888");
    // Regenerate the palette only when the number of series changes, not on every render.
    const palette = useMemo(
        () =>
            distinctColors({
                count: selectedMetrics.length,
                lightMin: 30,
            }),
        [selectedMetrics.length],
    );

    // Visibility overrides set by legend clicks. A metric with no entry counts as visible, so metrics that show up
    // after mount (e.g. while data is still streaming in) are drawn instead of being silently dropped.
    const [visibleSeries, setVisibleSeries] = useState<Record<string, boolean>>({});

    // Effective visibility for exactly the currently selected metrics, in selection order.
    const seriesVisibility = useMemo(
        () =>
            selectedMetrics.reduce<Record<string, boolean>>((acc, name) => {
                acc[name] = visibleSeries[name] ?? true;
                return acc;
            }, {}),
        [selectedMetrics, visibleSeries],
    );

    return (
        <div
            style={{
                display: "flex",
                flexDirection: "row",
                justifyContent: "flex-start",
                padding: "20px",
                position: "relative",
            }}
        >
            <div style={{ flex: "0 0 100%" }}>
                <ResponsiveContainer width={"100%"} aspect={2}>
                    <LineChart data={chartData} style={{ overflow: "none" }} syncId="reward-chart">
                        <XAxis dataKey={"step"} label={{ value: "Step", position: "bottom", offset: 5 }}></XAxis>
                        <YAxis tickFormatter={(val) => formatValueToNumericString(val, 2)} />
                        {selectedMetrics.map((metric, idx) => {
                            // TODO: The coloring can still be improved, but it requires some discussion. In general, it
                            // would be more intuitive to cluster colors by reward - so that each reach reward has its own
                            // color band, versions are equidistant samples within that band, and the EMA (computed across
                            // versions) is assigned the "median" or basal color of the band. This is complicated to compute
                            // dynamically, and we may change the UX here (e.g. multiple rewards on the same chart?), so
                            // we'll leave this as is for now. To make the EMA more visible, all raw data will have an
                            // opacity modifier.
                            let stroke: string | undefined;
                            if (metric.includes(EMA_PREFIX) && metric.includes(VALIDATION_PREFIX)) {
                                stroke = validationEMAColor.hex();
                            } else if (metric.includes(EMA_PREFIX)) {
                                stroke = rawEMAColor.hex();
                            } else {
                                // Optional chaining guards against a palette shorter than the series list.
                                stroke = palette[idx]?.alpha(0.5).hex();
                            }

                            return (
                                <Line
                                    hide={!seriesVisibility[metric]}
                                    key={`line-${metric}`}
                                    type="natural"
                                    dataKey={metric}
                                    stroke={stroke} // TODO: Use alpha (opacity), pair ema and raw again
                                    activeDot={true}
                                    isAnimationActive={false}
                                    name={metric}
                                    dot={false}
                                    connectNulls={true} // This is necessary for validation data to have lines
                                />
                            );
                        })}
                        <Tooltip content={<CustomTooltip visibleSeries={seriesVisibility} />} />
                        <Legend
                            content={
                                <CustomLegend visibleSeries={seriesVisibility} setVisibleSeries={setVisibleSeries} />
                            }
                        />
                        <CartesianGrid stroke="#eee" strokeDasharray="5 5" />
                    </LineChart>
                </ResponsiveContainer>
            </div>
        </div>
    );
};

/** Sort rank for reward groups: `total_reward` first, then `total_reward_std`, then everything else. */
const rewardGroupRank = (rewardName: string) => {
    if (rewardName === "total_reward") return 0;
    if (rewardName === "total_reward_std") return 1;
    return 2;
};

/**
 * Grid of reward charts, one `Segment` per reward group, plus a fullscreen modal showing one group expanded.
 *
 * Every group plots from the same `chartData`; a group only chooses which metric series (`metrics`) to draw.
 */
const Canvas = (props: { chartData: ChartData; chartDataGroups: RewardChartGroup[] }) => {
    const { chartData, chartDataGroups } = props;
    const [openModal, setOpenModal] = useState(false);
    // Reward name of the most recently expanded group; null until the first expand.
    const [modalRewardName, setModalRewardName] = useState<string | null>(null);

    const handleExpandClick = (rewardName: string) => {
        setModalRewardName(rewardName);
        setOpenModal(true);
    };

    // Copy before sorting: `chartDataGroups` belongs to the caller (possibly a memoized value) and must not be
    // mutated during render. Array.prototype.sort is stable, so groups with equal rank keep their incoming order.
    const sortedGroups = [...chartDataGroups].sort(
        (a, b) => rewardGroupRank(a.rewardName) - rewardGroupRank(b.rewardName),
    );

    // Resolve the modal's group from current props (not a snapshot taken on click) so the expanded chart keeps
    // updating as new data arrives.
    const modalGroup =
        modalRewardName === null ? undefined : chartDataGroups.find((group) => group.rewardName === modalRewardName);

    return (
        <div style={lineChartFormatting}>
            {sortedGroups.map(({ rewardName, metrics }) => (
                <Segment key={rewardName} style={{ margin: "0px", position: "relative" }}>
                    <div
                        style={{
                            display: "flex",
                            alignItems: "center",
                            justifyContent: "space-between",
                            marginBottom: "10px",
                        }}
                    >
                        <h3>{rewardName}</h3>
                        <Button icon="expand" onClick={() => handleExpandClick(rewardName)} />
                    </div>
                    <Chart selectedMetrics={metrics} chartData={chartData} />
                </Segment>
            ))}
            <Modal open={openModal} onClose={() => setOpenModal(false)} size="fullscreen">
                <Modal.Header>
                    <div style={{ display: "flex", justifyContent: "space-between", alignItems: "center" }}>
                        <h3>{modalRewardName}</h3>
                        <Button
                            icon="close"
                            onClick={() => setOpenModal(false)}
                            style={{ background: "none", padding: "0" }}
                        />
                    </div>
                </Modal.Header>
                <Modal.Content>
                    {/* Keyed by group so legend visibility state does not leak between different expanded groups. */}
                    {modalGroup && (
                        <Chart
                            key={modalGroup.rewardName}
                            selectedMetrics={modalGroup.metrics}
                            chartData={chartData}
                        />
                    )}
                </Modal.Content>
            </Modal>
        </div>
    );
};

/**
 * Reward charts for a finished job, built from the already-fetched metric history.
 *
 * Shows a loader while `historicalDataIsLoading` is true, and renders the chart canvas whenever
 * `getRewardChartData` yields chart data.
 */
const HistoricalCurves = ({
    historicalData,
    historicalDataIsLoading,
    emaFactor,
}: {
    historicalData: grpoJobMetrics[] | undefined;
    historicalDataIsLoading: boolean;
    emaFactor: number;
}) => {
    // Rebuild the chart series only when the raw metrics or the smoothing factor change.
    const rewardCharts = useMemo(() => getRewardChartData(historicalData, emaFactor), [historicalData, emaFactor]);

    const loadingOverlay = historicalDataIsLoading ? (
        <div className="loading-overlay" style={{ height: "10rem", background: "none" }}>
            <Loader active />
        </div>
    ) : null;

    const canvas = rewardCharts.chartData ? (
        <Canvas chartData={rewardCharts.chartData} chartDataGroups={rewardCharts.chartDataGroups} />
    ) : null;

    return (
        <>
            {loadingOverlay}
            {canvas}
        </>
    );
};

/**
 * Reward charts for a running job, built from metrics streamed over the websocket.
 *
 * When the latest streamed event reports completion, or the adapter version's status is terminal, it closes the
 * websocket and invalidates the finetuning-job and adapter-version queries so they are refetched.
 */
const LiveCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    setWebsocketOpen: Dispatch<SetStateAction<boolean>>;
    websocketData: grpoJobMetrics[] | undefined;
    emaFactor: number;
}) => {
    // Parent state:
    const { adapterVersion, repoUUID, job, setWebsocketOpen, websocketData, emaFactor } = props;
    const jobUUID = job.uuid;
    const jobStatus = adapterVersion.status;
    const versionTag = adapterVersion.tag;

    // Derived metrics state:
    const latestEvent = websocketData?.at(-1);
    const latestEventIsTerminal = latestEvent?.meta.is_completed;
    const { chartData, chartDataGroups } = useStreamingMetrics(websocketData, emaFactor);

    // Query state:
    const queryClient = useQueryClient();
    useEffect(() => {
        if (latestEventIsTerminal || isTerminalJobStatus(jobStatus)) {
            setWebsocketOpen(false);

            // Fire-and-forget refetches; `void` marks the returned promises as intentionally not awaited.
            void queryClient.invalidateQueries({ queryKey: GET_FINETUNING_JOB_QUERY_KEY(jobUUID) });
            void queryClient.invalidateQueries({ queryKey: GET_ADAPTER_VERSION_QUERY_KEY(repoUUID, versionTag) });
        }
    }, [latestEventIsTerminal, jobStatus, jobUUID, repoUUID, versionTag, setWebsocketOpen, queryClient]);

    return (
        <>
            <Suspense
                fallback={
                    <div className="loading-overlay" style={{ height: "10rem", background: "none" }}>
                        <Loader active />
                    </div>
                }
            >
                {latestEvent && chartData && <Canvas chartData={chartData} chartDataGroups={chartDataGroups} />}
            </Suspense>
        </>
    );
};

/**
 * Reward-graphs panel for a GRPO fine-tuning job.
 *
 * When any metrics are available (websocket or historical), it shows a docs link and the smoothing control;
 * otherwise it shows an empty-state placeholder. Below that, it renders historical curves if the adapter version's
 * status is terminal and live (websocket-driven) curves otherwise.
 */
const RewardGraphs = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    setWebsocketOpen: Dispatch<SetStateAction<boolean>>;
    websocketData: grpoJobMetrics[] | undefined;
    historicalData: grpoJobMetrics[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const { adapterVersion, repoUUID, job, setWebsocketOpen, websocketData, historicalData, historicalDataIsLoading } =
        props;

    // Smoothing factor shared by whichever set of curves is rendered below.
    const [emaFactor, setEmaFactor] = useState(DEFAULT_EMA_FACTOR);

    const hasData = (websocketData?.length ?? 0) > 0 || (historicalData?.length ?? 0) > 0;
    const showHistorical = isTerminalJobStatus(adapterVersion.status);

    const header = hasData ? (
        <div
            style={{
                display: "flex",
                justifyContent: "space-between",
                verticalAlign: "bottom",
                marginBottom: "1rem",
            }}
        >
            <a
                href={`${getDocsHome()}/user-guide/fine-tuning/grpo#how-do-i-interpret-my-reward-graphs`}
                target="_blank"
                rel="noreferrer"
                style={{ paddingTop: "5px" }}
            >
                <Icon name="book" />
                Learn how to interpret reward graphs and track model progress
            </a>

            <EMASlider emaFactor={emaFactor} setEmaFactor={setEmaFactor} />
        </div>
    ) : (
        <div style={{ display: "flex", flexDirection: "column", alignItems: "center" }}>
            <Divider hidden />
            <img src={"/model/emptyRepos.svg"} alt="" />
            <Header as="h2" size={"medium"} style={{ marginBottom: "0.5rem" }}>
                No metrics yet!
            </Header>
            <Divider hidden />
        </div>
    );

    const curves = showHistorical ? (
        <HistoricalCurves
            historicalData={historicalData}
            historicalDataIsLoading={historicalDataIsLoading}
            emaFactor={emaFactor}
        />
    ) : (
        <LiveCurves
            adapterVersion={adapterVersion}
            repoUUID={repoUUID}
            job={job}
            websocketData={websocketData}
            setWebsocketOpen={setWebsocketOpen}
            emaFactor={emaFactor}
        />
    );

    return (
        <>
            {header}
            <div style={{ minHeight: "50vh" }}>{curves}</div>
        </>
    );
};

export default RewardGraphs;
