import { useEffect, useState } from "react";

import { useQueryClient } from "@tanstack/react-query";
import _ from "lodash";
import { Loader, Table } from "semantic-ui-react";

import {
    adapterVersion,
    finetuningJob,
    finetuningJobStatus,
    repo,
    sftMetricsPayload,
    sftMetricsStepData,
} from "@/autogen/openapi";

import FormattedNumberTooltip from "../../../../components/FormattedNumberTooltip";
import InfoMessage from "../../../../components/InfoMessage";
import { SEMANTIC_GREY } from "../../../../utils/colors";
import { isTerminalJobStatus } from "../../../misc/utils";
import { GET_ADAPTER_VERSION_QUERY_KEY, GET_FINETUNING_JOB_QUERY_KEY } from "../../../query";
import SFTLineChart from "../metrics/SFTLineChart";
import { SFTMetricsStepDataKey, SFTMetricsStepDataType, metricNameMap, stepCounterText } from "../metrics/sft_util";

/** Grey text style shared by the value tooltips in the checkpoint summary table (both full and truncated forms). */
const tooltipStyle = { color: SEMANTIC_GREY };

/**
 * Metric keys listed in the summary table and used as the initial chart selection, in display order.
 * Typed as plain strings; call sites cast to `SFTMetricsStepDataKey` where needed.
 */
const CHART_METRIC_NAMES: string[] = ["train_metrics_loss", "validation_metrics_loss"];

/**
 * Compact key/value table summarising the most recent checkpoint.
 *
 * Shows the best checkpoint number (when one is reported), followed by the latest value of each metric in
 * `CHART_METRIC_NAMES` that is present on the checkpoint's data. Metrics without a value are skipped.
 *
 * @param props.latestCheckpoint - Newest checkpoint payload, or `undefined` if none has been received yet; in that
 *   case the table renders with an empty body.
 */
const SummaryTable = (props: { latestCheckpoint: sftMetricsPayload | undefined }) => {
    // Parent state:
    const { latestCheckpoint } = props;

    // Derived state:
    const bestCheckpointNumber = latestCheckpoint?.data.best_eval_metric_checkpoint_number;
    // Collapse to a real boolean before using it in JSX: a bare `bestCheckpointNumber && ...` short-circuits to the
    // number 0 when the value is 0, and React renders that as a stray "0" text node inside <tbody>. 0 and -1 stay
    // hidden exactly as before (-1 presumably being a "no best checkpoint yet" sentinel -- confirm with the backend).
    const showBestCheckpoint = Boolean(bestCheckpointNumber) && bestCheckpointNumber !== -1;

    return (
        <Table
            basic
            style={{
                background: "#F7F7F7",
                paddingLeft: "1rem",
                paddingRight: "1rem",
                border: "1px solid rgb(228 230 230)",
                borderRadius: ".28571429rem",
            }}
            collapsing
        >
            <Table.Body>
                {showBestCheckpoint && (
                    <Table.Row>
                        <Table.Cell style={{ paddingTop: "16px", paddingBottom: "16px", width: "100%" }}>
                            <b>Best Checkpoint</b>
                        </Table.Cell>
                        <Table.Cell>{bestCheckpointNumber}</Table.Cell>
                    </Table.Row>
                )}
                {CHART_METRIC_NAMES.map((name) => {
                    const metricName = metricNameMap[name];
                    const latestValue = latestCheckpoint?.data[name as SFTMetricsStepDataKey];

                    // Skip metrics the checkpoint does not report; `== null` also covers an explicit null value.
                    if (latestValue == null) {
                        return null;
                    }

                    return (
                        <Table.Row key={`${name}-metric-table-row`}>
                            <Table.Cell collapsing style={{ borderTop: "none", width: "100%" }}>
                                {metricName}
                            </Table.Cell>
                            <Table.Cell style={{ borderTop: "none" }}>
                                <FormattedNumberTooltip
                                    value={latestValue}
                                    style={tooltipStyle}
                                    truncatedStyle={tooltipStyle}
                                />
                            </Table.Cell>
                        </Table.Row>
                    );
                })}
            </Table.Body>
        </Table>
    );
};

// Layout for Canvas: summary table on the left (20%), chart on the right (80%), laid out in a single row.
const CANVAS_ROW_STYLE = { display: "flex", flexDirection: "row", justifyContent: "flex-start" } as const;
const CANVAS_SUMMARY_PANE_STYLE = { flex: "0 0 20%" } as const;
const CANVAS_CHART_PANE_STYLE = { flex: "0 0 80%" } as const;

/**
 * Side-by-side view of the checkpoint summary table and the metrics line chart.
 *
 * @param selectedMetrics - Metric keys forwarded to the chart.
 * @param latestCheckpoint - Newest checkpoint payload (may be `undefined`), forwarded to the summary table.
 * @param chartData - One data point per checkpoint; plotted against `checkpoint_number` on the x axis.
 */
const Canvas = ({
    selectedMetrics,
    latestCheckpoint,
    chartData,
}: {
    selectedMetrics: SFTMetricsStepDataKey[];
    latestCheckpoint: sftMetricsPayload | undefined;
    chartData: sftMetricsStepData[];
}) => (
    <div style={CANVAS_ROW_STYLE}>
        <div style={CANVAS_SUMMARY_PANE_STYLE}>
            <SummaryTable latestCheckpoint={latestCheckpoint} />
        </div>
        <div style={CANVAS_CHART_PANE_STYLE}>
            <SFTLineChart selectedMetrics={selectedMetrics} chartData={chartData} xAxisName={"checkpoint_number"} />
        </div>
    </div>
);

/**
 * Static loss curves for an adapter whose job has finished, built from previously fetched metrics events.
 *
 * Shows a loader overlay while `historicalDataIsLoading` is true, and renders the summary table + chart once at least
 * one event is available.
 *
 * @param props.adapterVersion - Currently unused (see TODO).
 * @param props.historicalData - Metrics events, assumed oldest first since the last element is treated as the latest.
 * @param props.historicalDataIsLoading - Whether the historical metrics request is still in flight.
 */
const HistoricalCurves = (props: {
    adapterVersion: adapterVersion; // TODO: unused
    historicalData: sftMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const { historicalData: events, historicalDataIsLoading: isLoading } = props;

    // Selection is fixed: only the getter is destructured, so it never changes after the initial render.
    const [selectedMetrics] = useState<SFTMetricsStepDataKey[]>(CHART_METRIC_NAMES as SFTMetricsStepDataKey[]);

    const newestEvent = events?.at(-1);
    // TODO: This filter is currently a no-op in practice: we seem to only receive checkpoint events, and golang sends
    // a "zeroed" data struct on every event rather than omitting it entirely.
    const eventsWithData = events?.filter((event) => !_.isEmpty(event.data));
    const newestCheckpoint = eventsWithData?.at(-1);
    const chartPoints: SFTMetricsStepDataType[] | undefined = eventsWithData?.map(({ data }) => data);

    const chart =
        newestEvent && chartPoints ? (
            <Canvas selectedMetrics={selectedMetrics} latestCheckpoint={newestCheckpoint} chartData={chartPoints} />
        ) : null;

    return (
        <>
            {isLoading ? (
                <div className="loading-overlay" style={{ height: "10rem", background: "none" }}>
                    <Loader active />
                </div>
            ) : null}
            {chart}
        </>
    );
};

/**
 * Live loss curves for a fine-tuning job that is still running, driven by metrics events streamed over a websocket.
 *
 * While `websocketOpen` is true, a status banner with a spinner and step counter is shown. When the newest event
 * reports completion (`meta.is_completed`) or the adapter status is terminal, the component:
 * - calls `setWebsocketOpen(false)`;
 * - invalidates the finetuning-job and adapter-version queries so they are refetched.
 *
 * @param props.adapterVersion - Adapter version being trained; its `status`, `repo`, and `tag` feed the banner, the
 *   terminal check, and the adapter-version query key.
 * @param props.repoUUID - Repository UUID, used to build the adapter-version query key.
 * @param props.job - Finetuning job; only its `uuid` is read (for the job query key).
 * @param props.websocketOpen - Whether the metrics websocket is considered open; controls the status banner.
 * @param props.setWebsocketOpen - State setter, flipped to `false` once training is detected as finished.
 * @param props.websocketData - Events received so far (assumed oldest first, since the last element is treated as
 *   the latest), or `undefined` before any arrive.
 */
const LiveCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
    websocketData: sftMetricsPayload[] | undefined;
}) => {
    // Parent state:
    const { adapterVersion, repoUUID, job, websocketOpen, setWebsocketOpen, websocketData } = props;
    const jobUUID = job.uuid;
    const jobStatus = adapterVersion.status;
    const versionTag = adapterVersion.tag;

    // Local state (fixed selection; no setter is used):
    const [selectedMetrics] = useState<SFTMetricsStepDataKey[]>(CHART_METRIC_NAMES as SFTMetricsStepDataKey[]);

    // Derived metrics state:
    const latestEvent = websocketData?.at(-1);
    const latestEventIsTerminal = latestEvent?.meta.is_completed;
    // TODO: Right now, checkpoint events are the only ones with `data` set:
    const checkpointEvents = websocketData?.filter((payload) => {
        return !_.isEmpty(payload.data);
    });
    const latestCheckpoint = checkpointEvents?.at(-1);
    const chartData: SFTMetricsStepDataType[] | undefined = checkpointEvents?.map((payload) => {
        return payload.data;
    });

    // Query state:
    const queryClient = useQueryClient();
    useEffect(() => {
        if (latestEventIsTerminal || isTerminalJobStatus(jobStatus)) {
            setWebsocketOpen(false);

            // Refetch the job and adapter version so their terminal status propagates; presumably the refreshed
            // adapterVersion makes LossCurves switch to HistoricalCurves. The refetch is fire-and-forget: nothing here
            // awaits it, so the promises are explicitly discarded with `void`.
            void queryClient.invalidateQueries({ queryKey: GET_FINETUNING_JOB_QUERY_KEY(jobUUID) });
            void queryClient.invalidateQueries({ queryKey: GET_ADAPTER_VERSION_QUERY_KEY(repoUUID, versionTag) });
        }
        // `queryClient` is listed so the dependency array is exhaustive; it is normally stable across renders, so this
        // does not cause extra effect runs.
    }, [latestEventIsTerminal, jobStatus, jobUUID, repoUUID, versionTag, setWebsocketOpen, queryClient]);

    return (
        <>
            {websocketOpen && (
                <InfoMessage
                    header={`Adapter [${adapterVersion.repo}/${adapterVersion.tag}] is ${adapterVersion.status}`}
                    infoMessage={
                        <>
                            <br />
                            <Loader active inline size={"small"} />
                            &nbsp;
                            {stepCounterText(latestEvent, latestCheckpoint)}
                        </>
                    }
                />
            )}
            {latestEvent && chartData && (
                <Canvas selectedMetrics={selectedMetrics} latestCheckpoint={latestCheckpoint} chartData={chartData} />
            )}
        </>
    );
};

/**
 * Picks which loss-curve view to show for an adapter version, based on `adapterVersion.status`.
 *
 * Selection rules, checked in this order:
 * - `CANCELED`: renders nothing. This is checked first, before the terminal-status test.
 * - any other terminal status: `HistoricalCurves`, fed from the previously fetched metrics.
 * - otherwise: `LiveCurves`, fed from the websocket stream.
 */
const LossCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
    websocketData: sftMetricsPayload[] | undefined;
    historicalData: sftMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const { adapterVersion } = props;
    const status = adapterVersion.status;

    // Canceled runs get no curves at all.
    if (status === finetuningJobStatus.CANCELED) {
        return null;
    }

    // Finished runs: show what was recorded.
    if (isTerminalJobStatus(status)) {
        return (
            <HistoricalCurves
                adapterVersion={adapterVersion}
                historicalData={props.historicalData}
                historicalDataIsLoading={props.historicalDataIsLoading}
            />
        );
    }

    // Still running: stream metrics live.
    return (
        <LiveCurves
            adapterVersion={adapterVersion}
            repoUUID={props.repoUUID}
            job={props.job}
            websocketData={props.websocketData}
            websocketOpen={props.websocketOpen}
            setWebsocketOpen={props.setWebsocketOpen}
        />
    );
};

export default LossCurves;
