import { useQueryClient } from "@tanstack/react-query";
import _ from "lodash";
import { useEffect, useState } from "react";
import { Loader } from "semantic-ui-react";
import {
    adapterVersion,
    finetuningJob,
    finetuningJobStatus,
    predifineMetricsPayload,
    predifineMetricsPayloadData,
    repo,
} from "../../../../api_generated";
import InfoMessage from "../../../../components/InfoMessage";
import { isTerminalJobStatus } from "../../../misc/utils";
import { GET_ADAPTER_VERSION_QUERY_KEY, GET_FINETUNING_JOB_QUERY_KEY } from "../../../query";
import MetricsLineGraph from "./MetricsLineGraph";
import MetricsTable from "./MetricsTable";
import { CHART_METRIC_NAMES, MetricsPayloadData, MetricsPayloadDataKey } from "./util";

/**
 * Builds the progress label shown next to the loader while a job is running.
 *
 * @param latestEvent - Most recent metrics payload, or `undefined` if none has arrived.
 * @param latestCheckpoint - Most recent checkpoint payload, or `undefined` if none has arrived.
 * @returns A generic waiting message when there is no event yet; progress towards
 *   the first checkpoint while no checkpoint exists; otherwise overall step progress.
 */
const stepCounterText = (
    latestEvent: predifineMetricsPayload | undefined,
    latestCheckpoint: predifineMetricsPayload | undefined,
) => {
    // Nothing has been received yet, so there are no step counts to report.
    if (latestEvent === undefined) {
        return "Waiting for metrics...";
    }

    const { steps, steps_per_checkpoint: checkpointInterval, total_steps: totalSteps } = latestEvent.meta;

    // Before the first checkpoint the counter runs against the checkpoint interval;
    // afterwards it runs against the total number of steps.
    return latestCheckpoint === undefined
        ? `Waiting for metrics... Step ${steps} / ${checkpointInterval} until first checkpoint`
        : `Step ${steps} / ${totalSteps}`;
};

/**
 * Lays out the metrics table and the metrics line graph side by side in a flex row:
 * the table sits in a fixed 20% column and the graph in a fixed 80% column.
 *
 * - `selectedMetrics` and `chartData` are forwarded to `MetricsLineGraph`.
 * - `latestCheckpoint` is forwarded to `MetricsTable`.
 */
const TableAndChart = (props: {
    selectedMetrics: MetricsPayloadDataKey[];
    latestCheckpoint: predifineMetricsPayload | undefined;
    chartData: predifineMetricsPayloadData[];
}) => {
    // Build the two children first so the layout markup below stays compact.
    const table = <MetricsTable latestCheckpoint={props.latestCheckpoint} />;
    const graph = <MetricsLineGraph selectedMetrics={props.selectedMetrics} chartData={props.chartData} />;

    return (
        <div style={{ display: "flex", flexDirection: "row", justifyContent: "flex-start" }}>
            <div style={{ flex: "0 0 20%" }}>{table}</div>
            <div style={{ flex: "0 0 80%" }}>{graph}</div>
        </div>
    );
};

/**
 * Renders the learning curves from already-fetched historical metrics.
 *
 * Shows a loader while `historicalDataIsLoading` is true, and shows the table and
 * chart once at least one historical event is available.
 */
const HistoricalCurves = (props: {
    adapterVersion: adapterVersion; // TODO: unused
    historicalData: predifineMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const events = props.historicalData;
    const isLoading = props.historicalDataIsLoading;

    // The set of charted metrics is fixed for now (no setter is used).
    const [selectedMetrics] = useState<MetricsPayloadDataKey[]>(CHART_METRIC_NAMES);

    const newestEvent = events?.at(-1);

    // TODO: This filter doesn't mean anything right now, because we seem to only get checkpoint events anyway
    // and because we would be receiving a "zeroed" data struct on each event from golang (instead of an
    // entirely absent data struct):
    const eventsWithData = events?.filter((event) => !_.isEmpty(event.data));
    const latestCheckpoint = eventsWithData?.at(-1);
    const chartData: MetricsPayloadData[] | undefined = eventsWithData?.map((event) => event.data);

    return (
        <>
            {isLoading ? (
                <div className="loading-overlay" style={{ height: "10rem", background: "none" }}>
                    <Loader active />
                </div>
            ) : null}
            {newestEvent && chartData ? (
                <TableAndChart
                    selectedMetrics={selectedMetrics}
                    latestCheckpoint={latestCheckpoint}
                    chartData={chartData}
                />
            ) : null}
        </>
    );
};

/**
 * Renders the learning curves for a job that is still running, driven by the
 * metrics payloads received over the websocket.
 *
 * While `websocketOpen` is true an info banner with a step counter is shown.
 * Once the latest event reports completion, or the adapter version's status is
 * terminal, the component sets `websocketOpen` to false through
 * `setWebsocketOpen` and invalidates the cached job and adapter-version queries.
 */
const LiveCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
    websocketData: predifineMetricsPayload[] | undefined;
}) => {
    // Parent state:
    const { adapterVersion, repoUUID, job, websocketOpen, setWebsocketOpen, websocketData } = props;
    const jobUUID = job.uuid;
    const jobStatus = adapterVersion.status;
    const versionTag = adapterVersion.tag;

    // Local state (the set of charted metrics is fixed for now; no setter is used):
    const [selectedMetrics] = useState<MetricsPayloadDataKey[]>(CHART_METRIC_NAMES);

    // Derived metrics state:
    const latestEvent = websocketData?.at(-1);
    const latestEventIsTerminal = latestEvent?.meta.is_completed;
    // TODO: Right now, checkpoint events are the only ones with `data` set:
    const checkpointEvents = websocketData?.filter((payload) => {
        return !_.isEmpty(payload.data);
    });
    const latestCheckpoint = checkpointEvents?.at(-1);
    const chartData: MetricsPayloadData[] | undefined = checkpointEvents?.map((payload) => {
        return payload.data;
    });

    // Query state:
    const queryClient = useQueryClient();
    useEffect(() => {
        if (latestEventIsTerminal || isTerminalJobStatus(jobStatus)) {
            setWebsocketOpen(false);

            // Invalidate the cached job and adapter version queries so they are refetched.
            queryClient.invalidateQueries({ queryKey: GET_FINETUNING_JOB_QUERY_KEY(jobUUID) });
            queryClient.invalidateQueries({ queryKey: GET_ADAPTER_VERSION_QUERY_KEY(repoUUID, versionTag) });
        }
        // `queryClient` is read inside the effect, so it is listed as a dependency.
    }, [latestEventIsTerminal, jobStatus, jobUUID, repoUUID, versionTag, setWebsocketOpen, queryClient]);

    return (
        <>
            {websocketOpen && (
                <InfoMessage
                    header={`Adapter [${adapterVersion.repo}/${adapterVersion.tag}] is ${adapterVersion.status}`}
                    infoMessage={
                        <>
                            <br />
                            <Loader active inline size={"small"} />
                            &nbsp;
                            {stepCounterText(latestEvent, latestCheckpoint)}
                        </>
                    }
                />
            )}
            {latestEvent && chartData && (
                <TableAndChart
                    selectedMetrics={selectedMetrics}
                    latestCheckpoint={latestCheckpoint}
                    chartData={chartData}
                />
            )}
        </>
    );
};

/**
 * Entry point for the learning-curves section. Picks what to render from the
 * adapter version's status:
 *
 * - canceled: renders nothing;
 * - any other terminal status: renders `HistoricalCurves` from `historicalData`;
 * - otherwise: renders `LiveCurves` from the websocket data.
 */
const LearningCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
    websocketData: predifineMetricsPayload[] | undefined;
    historicalData: predifineMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const { adapterVersion } = props;
    const status = adapterVersion.status;

    // Canceled jobs show no curves at all.
    if (status === finetuningJobStatus.CANCELED) {
        return null;
    }

    // Jobs in a terminal status are rendered from the historical data.
    if (isTerminalJobStatus(status)) {
        return (
            <HistoricalCurves
                adapterVersion={adapterVersion}
                historicalData={props.historicalData}
                historicalDataIsLoading={props.historicalDataIsLoading}
            />
        );
    }

    // Everything else is still in progress and is rendered from the websocket data.
    return (
        <LiveCurves
            adapterVersion={adapterVersion}
            repoUUID={props.repoUUID}
            job={props.job}
            websocketData={props.websocketData}
            websocketOpen={props.websocketOpen}
            setWebsocketOpen={props.setWebsocketOpen}
        />
    );
};

export default LearningCurves;
