import { useEffect, useState } from "react";

import { useQueryClient } from "@tanstack/react-query";
import _ from "lodash";
import { Loader, Message, Popup } from "semantic-ui-react";

import {
    adapterVersion,
    finetuningJob,
    finetuningJobStatus,
    repo,
    sftMetricsPayload,
    sftMetricsStepData,
} from "@/autogen/openapi";

import InfoMessage from "../../../../components/InfoMessage";
import { isTerminalJobStatus } from "../../../misc/utils";
import { GET_ADAPTER_VERSION_QUERY_KEY, GET_FINETUNING_JOB_QUERY_KEY } from "../../../query";
import SFTLineChart from "../metrics/SFTLineChart";
import {
    SFTMetricsStepDataKey,
    SFTMetricsStepDataType,
    TURBO_METRIC_NAMES,
    stepCounterText,
} from "../metrics/sft_util";

/**
 * Inline style for the container that holds the turbo metric line charts.
 *
 * Lays the charts out as a full-width CSS grid of three equal-width columns separated by a
 * 5px gap; anything that does not fit is clipped rather than scrolled.
 */
export const lineChartFormatting = {
    // Span the whole parent and clip overflowing chart content.
    width: "100%",
    overflow: "hidden",
    // Three equal columns, one chart per cell.
    display: "grid",
    gridTemplateColumns: "repeat(3, 1fr)",
    gap: "5px",
};

/**
 * Expected inference speedup from the top-1 accuracies of the three turbo speculation heads.
 *
 * speedup = 3 / min(1 + F0 + F0*F1 + F0*F1*F2, 3), where Fi = 1 - accuracy_i.
 *
 * Capping the denominator at 3 keeps the result at or above 1x. Without the cap, all-zero
 * accuracies (typical at the start of training) would give a denominator of 4.
 *
 * @param accuracy0 - Top-1 accuracy of speculation head 0.
 * @param accuracy1 - Top-1 accuracy of speculation head 1.
 * @param accuracy2 - Top-1 accuracy of speculation head 2.
 * @returns The expected speedup multiplier.
 */
function computeExpectedSpeedup(accuracy0: number, accuracy1: number, accuracy2: number): number {
    const failureRates = [accuracy0, accuracy1, accuracy2].map((accuracy) => 1 - accuracy);

    // Accumulate 1 + F0 + F0*F1 + F0*F1*F2 as a running product, preserving the original
    // left-to-right evaluation order so floating-point results are unchanged.
    let runningProduct = 1;
    let total = 1;
    for (const failureRate of failureRates) {
        runningProduct *= failureRate;
        total += runningProduct;
    }

    return 3 / Math.min(total, 3);
}

/**
 * Banner that displays the expected inference speedup. Hovering it opens a popup explaining
 * how the number is derived from the speculation-head accuracies.
 *
 * @param props.expectedSpeedup - Speedup multiplier to display, rendered with two decimal places.
 */
const ExpectedSpeedup = ({ expectedSpeedup }: { expectedSpeedup: number }) => {
    // Explanation shown inside the popup.
    const explanation = (
        <>
            <p>
                <strong>Expected Speedup Calculation</strong>
            </p>
            <p>
                The expected speedup is calculated using the accuracies of the three speculation heads. It
                estimates the improvement in token generation speed by reducing the need to fall back to the
                base model.
            </p>
            <p>
                <code>Expected Speedup = 3 / [1 + (1 - A₀) + (1 - A₀)(1 - A₁) + (1 - A₀)(1 - A₁)(1 - A₂)]</code>
                , where A₀, A₁, and A₂ are the top-1 accuracies of the three speculation heads.
            </p>
            <p>
                <strong>Interpretation:</strong> A higher speedup indicates better speculative efficiency,
                resulting in faster token generation.
            </p>
        </>
    );

    // Always-visible banner that acts as the popup trigger.
    const banner = (
        <Message size="large">
            Expected inference speedup based on turbo speculation head accuracies is ~
            <strong>{expectedSpeedup.toFixed(2)}x</strong>.
        </Message>
    );

    return <Popup content={explanation} trigger={banner} size="large" wide="very" />;
};

/**
 * One padded, full-width SFT line chart of the given metric keys, plotted against checkpoint number.
 *
 * @param props.selectedMetrics - Metric keys to draw as series.
 * @param props.chartData - Per-checkpoint metric rows to plot.
 */
const LineChart = ({
    selectedMetrics,
    chartData,
}: {
    selectedMetrics: SFTMetricsStepDataKey[];
    chartData: sftMetricsStepData[];
}) => (
    <div style={{ display: "flex", flexDirection: "row", justifyContent: "flex-start", padding: "20px" }}>
        <div style={{ flex: "0 0 100%" }}>
            <SFTLineChart selectedMetrics={selectedMetrics} chartData={chartData} xAxisName="checkpoint_number" />
        </div>
    </div>
);

/**
 * Renders the expected speedup and turbo speculation-head accuracy curves from previously
 * recorded metrics payloads. TurboAccuracyCurves uses it once the job status is terminal.
 *
 * @param props.historicalData - Recorded metrics payloads; the last entry is treated as the latest
 *     (presumably chronological order — confirm with the data source). `undefined` when not available.
 * @param props.historicalDataIsLoading - Whether to show a loading indicator.
 */
const HistoricalCurves = (props: {
    historicalData: sftMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    // Parent state:
    const { historicalData, historicalDataIsLoading } = props;

    // Derived metrics state:
    const latestEvent = historicalData?.at(-1);
    // Only payloads with a non-empty `data` field are checkpoints that carry metrics.
    const checkpointEvents = historicalData?.filter((payload) => {
        return !_.isEmpty(payload.data);
    });
    const latestCheckpoint = checkpointEvents?.at(-1);
    const chartData: SFTMetricsStepDataType[] | undefined = checkpointEvents?.map((payload) => {
        return payload.data;
    });

    // Local state:
    const [expectedSpeedup, setExpectedSpeedup] = useState<number>(1);

    // Derive the speedup from the latest *checkpoint*, not the latest event. The final payload may
    // carry no metrics data, which would otherwise leave the speedup stuck at its initial 1x even
    // though earlier checkpoints exist.
    const latestCheckpointData = latestCheckpoint?.data;
    useEffect(() => {
        // The isEmpty guard also narrows away undefined/null for the accesses below.
        if (!_.isEmpty(latestCheckpointData)) {
            // `|| 0` treats missing values (and NaN) as zero accuracy.
            const accuracy0 = latestCheckpointData.train_metrics_turbo_0_top1_accuracy || 0;
            const accuracy1 = latestCheckpointData.train_metrics_turbo_1_top1_accuracy || 0;
            const accuracy2 = latestCheckpointData.train_metrics_turbo_2_top1_accuracy || 0;
            setExpectedSpeedup(computeExpectedSpeedup(accuracy0, accuracy1, accuracy2));
        }
    }, [latestCheckpointData]);

    return (
        <>
            {historicalDataIsLoading && (
                <div className="loading-overlay" style={{ height: "10rem", background: "none" }}>
                    <Loader active />
                </div>
            )}
            {latestEvent && <ExpectedSpeedup expectedSpeedup={expectedSpeedup} />}
            {latestEvent && chartData && (
                <div style={lineChartFormatting}>
                    {TURBO_METRIC_NAMES.map((metrics, index) => (
                        <LineChart key={index} selectedMetrics={metrics} chartData={chartData} />
                    ))}
                </div>
            )}
        </>
    );
};

/**
 * Renders the expected speedup and turbo speculation-head accuracy curves for a job that is
 * still running, using metrics payloads received over a websocket.
 *
 * A cleanup runs when either of these holds:
 * - the latest payload reports `meta.is_completed`, or
 * - the adapter version status is terminal.
 *
 * The cleanup closes the websocket and invalidates the cached finetuning-job and adapter-version
 * queries.
 *
 * @param props.websocketData - Payloads received so far; the last entry is treated as the latest.
 * @param props.websocketOpen - Whether the websocket is open. While true, a progress banner is shown.
 * @param props.setWebsocketOpen - Setter used to close the websocket once the job finishes.
 */
const LiveCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketData: sftMetricsPayload[] | undefined;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
}) => {
    // Parent state:
    const { adapterVersion, repoUUID, job, websocketOpen, setWebsocketOpen, websocketData } = props;
    const jobUUID = job.uuid;
    const jobStatus = adapterVersion.status;
    const versionTag = adapterVersion.tag;

    // Derived metrics state:
    const latestEvent = websocketData?.at(-1);
    const latestEventIsTerminal = latestEvent?.meta.is_completed;
    // Only payloads with a non-empty `data` field are checkpoints that carry metrics.
    const checkpointEvents = websocketData?.filter((payload) => {
        return !_.isEmpty(payload.data);
    });
    const latestCheckpoint = checkpointEvents?.at(-1);
    const chartData: SFTMetricsStepDataType[] | undefined = checkpointEvents?.map((payload) => {
        return payload.data;
    });

    // Local state:
    const [expectedSpeedup, setExpectedSpeedup] = useState<number>(1);

    // The latest checkpoint *is* the latest event whenever that event carries data; otherwise it is
    // the most recent event that did. Keying the effect on its data therefore covers both cases,
    // and the dependency list names every value the effect reads (no stale closure).
    const latestCheckpointData = latestCheckpoint?.data;
    useEffect(() => {
        // The isEmpty guard also narrows away undefined/null for the accesses below.
        if (!_.isEmpty(latestCheckpointData)) {
            // `|| 0` treats missing values (and NaN) as zero accuracy.
            const accuracy0 = latestCheckpointData.train_metrics_turbo_0_top1_accuracy || 0;
            const accuracy1 = latestCheckpointData.train_metrics_turbo_1_top1_accuracy || 0;
            const accuracy2 = latestCheckpointData.train_metrics_turbo_2_top1_accuracy || 0;
            setExpectedSpeedup(computeExpectedSpeedup(accuracy0, accuracy1, accuracy2));
        }
    }, [latestCheckpointData]);

    // Query state:
    const queryClient = useQueryClient();
    useEffect(() => {
        if (latestEventIsTerminal || isTerminalJobStatus(jobStatus)) {
            setWebsocketOpen(false);

            // Refetch job and adapter version so the rest of the page sees the final status.
            queryClient.invalidateQueries({ queryKey: GET_FINETUNING_JOB_QUERY_KEY(jobUUID) });
            queryClient.invalidateQueries({ queryKey: GET_ADAPTER_VERSION_QUERY_KEY(repoUUID, versionTag) });
        }
    }, [latestEventIsTerminal, jobStatus, jobUUID, repoUUID, versionTag, setWebsocketOpen, queryClient]);

    return (
        <>
            {websocketOpen && (
                <InfoMessage
                    header={`Adapter [${adapterVersion.repo}/${adapterVersion.tag}] is ${adapterVersion.status}`}
                    infoMessage={
                        <>
                            <br />
                            <Loader active inline size={"small"} />
                            &nbsp;
                            {stepCounterText(latestEvent, latestCheckpoint)}
                        </>
                    }
                />
            )}
            {latestEvent && <ExpectedSpeedup expectedSpeedup={expectedSpeedup} />}
            {latestEvent && chartData && (
                <div style={lineChartFormatting}>
                    {TURBO_METRIC_NAMES.map((metrics, index) => (
                        <LineChart key={index} selectedMetrics={metrics} chartData={chartData} />
                    ))}
                </div>
            )}
        </>
    );
};

/**
 * Chooses how to display turbo accuracy curves for an adapter version, based on its status:
 * - canceled: nothing is rendered;
 * - terminal (per `isTerminalJobStatus`): curves are rendered from the recorded historical data;
 * - otherwise: live curves are rendered from websocket data.
 */
const TurboAccuracyCurves = (props: {
    adapterVersion: adapterVersion;
    repoUUID: repo["uuid"];
    job: finetuningJob;
    websocketOpen: boolean;
    setWebsocketOpen: React.Dispatch<React.SetStateAction<boolean>>;
    websocketData: sftMetricsPayload[] | undefined;
    historicalData: sftMetricsPayload[] | undefined;
    historicalDataIsLoading: boolean;
}) => {
    const { adapterVersion } = props;
    const status = adapterVersion.status;

    // Checked before the terminal-status branch so canceled jobs render no curves at all.
    if (status === finetuningJobStatus.CANCELED) {
        return null;
    }

    if (isTerminalJobStatus(status)) {
        const { historicalData, historicalDataIsLoading } = props;
        return <HistoricalCurves historicalData={historicalData} historicalDataIsLoading={historicalDataIsLoading} />;
    }

    const { repoUUID, job, websocketData, websocketOpen, setWebsocketOpen } = props;
    return (
        <LiveCurves
            adapterVersion={adapterVersion}
            repoUUID={repoUUID}
            job={job}
            websocketData={websocketData}
            websocketOpen={websocketOpen}
            setWebsocketOpen={setWebsocketOpen}
        />
    );
};

export default TurboAccuracyCurves;
