import { AxiosInstance, AxiosResponse } from "axios";
import { useEffect, useMemo, useRef, useState } from "react";
import { useAuth0TokenOptions } from "../../data";
import metrics from "../../metrics/metrics";
import { createV1APIServer, redirectIfSessionInvalid } from "../../utils/api";
import { getErrorMessage } from "../../utils/errors";
import {
    getBestModelMetricsFromRawMetrics,
    getMetricKeysFromRuns,
    getSpecificMetricsFromRawMetrics,
} from "../utilLearningCurves";
import ModelMetricsGraph from "../version/modelmetrics/ModelMetricsGraph";
import ModelMetricsTable from "../version/modelmetrics/ModelMetricsTable";

/**
 * Renders a metrics table next to a learning-curve graph for a set of models being compared.
 *
 * For every entry in `props.models` the component fetches the model version together with its runs,
 * keeps the run whose `run_id` equals the version's `activeRunID`, and then fetches the per-metric
 * history of each selected metric for those runs. The selected metric keys, the fetched runs and the
 * merged metric history are held in component state and passed down to `ModelMetricsTable` and
 * `ModelMetricsGraph`.
 *
 * @param props.models Models currently selected for comparison.
 */
function ModelCompareLearningCurves(props: { models: Model[] }) {
    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    // Shared rejection handler for every request issued by this component.
    const handleRequestError = (error: Parameters<typeof getErrorMessage>[0]) => {
        const errorMsg = getErrorMessage(error) ?? "";
        redirectIfSessionInvalid(errorMsg);
    };

    const [apiServer, setAPIServer] = useState<AxiosInstance | null>(null);

    useEffect(() => {
        // Guards against calling setState after the component has unmounted.
        let isActive = true;
        const getAPIServer = async () => {
            const v1APIServer = await createV1APIServer(auth0TokenOptions);
            // NOTE: The value returned by axios.create is callable (it is a wrapped function, not a plain
            // AxiosInstance object), so passing it straight to setState would make React treat it as an
            // updater callback and invoke it. Wrapping it in a thunk stores the instance itself.
            // See: [1], [2]:
            // [1]: https://github.com/axios/axios/issues/4365
            // [2]: https://stackoverflow.com/questions/64427195/calling-setstate-will-execute-the-function-value-instead-of-passing-it
            if (isActive) {
                setAPIServer(() => v1APIServer);
            }
        };
        getAPIServer().catch(handleRequestError);
        return () => {
            isActive = false;
        };
    }, []);

    const [experimentRuns, setExperimentRuns] = useState<ModelRun[]>([]);

    const [metricHistory, setMetricHistory] = useState<MetricHistory>({});
    const [specificMetrics, setSpecificMetrics] = useState<SpecificModelMetrics>({});
    const [bestModelMetrics, setBestModelMetrics] = useState<BestModelMetrics>({});

    const [selectedMetrics, setSelectedMetrics] = useState<string[]>([]);
    const selectedRuns = useMemo(() => experimentRuns.map((x) => x.info.run_id), [experimentRuns]);

    // Maps a run id to the numeric repo version of the model version it belongs to.
    const runToNumberMap = useRef<Record<string, number>>({});

    /**
     * Fetches the history of every metric in `metricKeys` for every run in `runKeys` and merges the
     * result into `mh`, then stores the merged object as the new metric history.
     *
     * NOTE: the merge starts from the `mh` snapshot supplied when the request was issued, so a state
     * update that lands while the requests are in flight is overwritten by this one.
     *
     * @param metricKeys Metric keys to fetch.
     * @param runKeys Run ids to fetch the metrics for.
     * @param mh Metric history to merge the responses into; defaults to the current state.
     */
    const fetchMetricHistory = (metricKeys: string[], runKeys: string[], mh = metricHistory) => {
        const numMetrics = metricKeys.length;
        Promise.all(
            // Requests are ordered run-major: all metrics of the first run, then all of the second, ...
            runKeys.flatMap((run) =>
                metricKeys.map((metricKey) => apiServer?.get("models/metrics/history/" + run + "/" + metricKey)),
            ),
        )
            .then((responses) => {
                const newMetricHistory: MetricHistory = { ...mh };
                responses.forEach((runMetric, i) => {
                    if (!runMetric?.data.metrics) {
                        return;
                    }
                    // Recover the run from the response position (see request ordering above).
                    const run_id = runKeys[Math.floor(i / numMetrics)];
                    const existingRows = newMetricHistory[run_id] ?? [];
                    const fetched = runMetric.data.metrics as ModelRunMetrics[];

                    // Merge the fetched values row by row into fresh objects.
                    const mergedRows = fetched.map((metric, step) => ({
                        ...(existingRows[step] ?? {}),
                        [metric.key]: metric.value,
                    }));
                    // Keep rows beyond the fetched series so a shorter series does not discard values
                    // already stored for other metrics.
                    newMetricHistory[run_id] = [...mergedRows, ...existingRows.slice(fetched.length)];
                });
                setMetricHistory(newMetricHistory);
            })
            .catch(handleRequestError);
    };

    // Keeps `experimentRuns` in step with `props.models`: drops runs of deselected models and fetches
    // runs (plus metric history) for newly selected ones.
    useEffect(() => {
        if (apiServer === null) {
            return;
        }
        if (props.models.length === experimentRuns.length) {
            return;
        }

        if (props.models.length < experimentRuns.length) {
            const activeRunIDs = new Set(props.models.map((x) => x?.activeRunID));
            setExperimentRuns((runs) => runs.filter((run) => activeRunIDs.has(run.info.run_id)));
            return;
        }

        const newModels = props.models.filter(
            (model) => !experimentRuns.some((run) => run.info.run_id === model.activeRunID),
        );

        // We already have a list of models in props.models, but none of these contain run metric information. Every time
        // a model is selected, we need to fetch the run information for that model.
        Promise.all(newModels.map((model) => apiServer.get("models/version/" + model.id + "?withRuns=true")))
            .then((responses) => {
                const newRuns: ModelRun[] = [];
                responses.forEach((res: AxiosResponse<GetModelWithRunsResponse>) => {
                    const { modelVersion, runs } = res.data;
                    const activeRunID = modelVersion.activeRunID;
                    if (activeRunID) {
                        runToNumberMap.current[activeRunID] = Number(modelVersion.repoVersion);
                    }
                    const newRun = runs?.find((x) => x.info.run_id === activeRunID);
                    if (newRun) {
                        newRuns.push(newRun);
                    }
                });

                const newRunIDs = newRuns.map((x) => x.info.run_id);

                // Auto-select all non-combined loss metrics on first load
                if (experimentRuns.length === 0) {
                    const sm =
                        selectedMetrics.length > 0
                            ? selectedMetrics
                            : getMetricKeysFromRuns(newRuns).filter(
                                  (x) =>
                                      !x.startsWith("best.") && x.includes(".loss") && !x.includes("combined.loss"),
                              );
                    setSelectedMetrics(sm);
                    fetchMetricHistory(sm, newRunIDs, metricHistory);
                } else {
                    fetchMetricHistory(selectedMetrics, newRunIDs, metricHistory);
                }
                setExperimentRuns((runs) => [...runs, ...newRuns]);
            })
            .catch(handleRequestError);
        // `apiServer` is a dependency so that models already present at mount are fetched once the
        // client becomes available. `experimentRuns` is deliberately left out: the effect updates it, and
        // a model without a matching run would keep the lengths unequal and re-trigger the fetch.
    }, [props.models, apiServer]);

    // Rebuilds the table's metric structures whenever the set of runs changes. Every metric key found
    // on the runs is passed with a null value.
    useEffect(() => {
        const rawMetrics: Partial<RawCompareModelMetrics> = {};
        getMetricKeysFromRuns(experimentRuns).forEach((key) => {
            rawMetrics[key] = null;
        });

        // Each helper gets its own copy so neither can observe changes made by the other.
        setSpecificMetrics(getSpecificMetricsFromRawMetrics({ ...rawMetrics }));
        setBestModelMetrics(getBestModelMetricsFromRawMetrics({ ...rawMetrics }));
    }, [experimentRuns]);

    /**
     * Toggles a single metric: deselects it when it is currently selected, otherwise selects it and
     * fetches its history for all selected runs. The resulting selection is reported to `metrics`.
     */
    const handleMetricClick = (metricKey: string) => {
        const metricName = "Model.ModelCompare.Metric";
        if (selectedMetrics.includes(metricKey)) {
            // remove
            const newSelectedMetrics = selectedMetrics.filter((x) => x !== metricKey);
            setSelectedMetrics(newSelectedMetrics);
            metrics.captureRemove(metricName, { name: metricKey, value: newSelectedMetrics, metricType: "normal" });
        } else {
            // add
            const newSelectedMetrics = [...selectedMetrics, metricKey];
            setSelectedMetrics((old) => [...old, metricKey]);
            // Report the selection as it is after the click, like the remove branch does.
            metrics.captureAdd(metricName, { name: metricKey, value: newSelectedMetrics, metricType: "normal" });
            fetchMetricHistory([metricKey], selectedRuns);
        }
    };

    /**
     * Selects or deselects several metrics at once.
     *
     * @param metricKeys Metric keys to act on.
     * @param force `true` selects every key that is not yet selected; `false` (the default) deselects
     *     every key that is currently selected.
     */
    const batchHandleMetricClick = (metricKeys: string[], force = false) => {
        const metricName = "Model.ModelCompare.Metric.Batch";

        const existingMetrics = new Set(selectedMetrics);
        let addedMetrics = new Set<string>();
        let removedMetrics = new Set<string>();

        if (force === true) {
            addedMetrics = new Set<string>(metricKeys.filter((x) => !existingMetrics.has(x)));
        } else if (force === false) {
            removedMetrics = new Set<string>(metricKeys.filter((x) => existingMetrics.has(x)));
        } else {
            // Only reachable when `force` is neither true, false nor undefined (e.g. null from an
            // untyped caller): toggle each key individually.
            metricKeys.forEach((x) => {
                if (existingMetrics.has(x)) {
                    removedMetrics.add(x);
                } else {
                    addedMetrics.add(x);
                }
            });
        }

        const addedMetricsArr = Array.from(addedMetrics);
        const removedMetricsArr = Array.from(removedMetrics);

        const newSelectedMetrics = selectedMetrics
            .filter((item) => !removedMetrics.has(item))
            .concat(addedMetricsArr);
        setSelectedMetrics(newSelectedMetrics);

        metrics.captureRemove(metricName, {
            added: addedMetricsArr,
            removed: removedMetricsArr,
            value: newSelectedMetrics,
            metricType: "normal",
        });

        // Strip the removed metrics from copies of the rows; the row objects held in React state must
        // not be mutated in place.
        const newMetricHistory: MetricHistory = {};
        Object.keys(metricHistory).forEach((trial) => {
            newMetricHistory[trial] = metricHistory[trial].map((epoch) => {
                const epochCopy = { ...epoch };
                removedMetricsArr.forEach((metric) => {
                    delete epochCopy[metric];
                });
                return epochCopy;
            });
        });

        // Also called when nothing was added: with no requests to make it resolves at once and stores
        // the stripped history.
        fetchMetricHistory(addedMetricsArr, selectedRuns, newMetricHistory);
    };

    return (
        <div style={{ display: "flex", flexDirection: "row", justifyContent: "flex-start" }}>
            <div style={{ display: "flex" }}>
                <ModelMetricsTable
                    metricHistory={metricHistory}
                    metrics={specificMetrics}
                    bestModelMetrics={bestModelMetrics}
                    selectedMetrics={selectedMetrics}
                    handleMetricClick={handleMetricClick}
                    batchHandleMetricClick={batchHandleMetricClick}
                    isCompare
                />
            </div>
            <div style={{ display: "flex", width: "100%", paddingLeft: `${16 / 14}rem` }}>
                <ModelMetricsGraph
                    models={props.models}
                    metricHistory={metricHistory}
                    bestModelMetrics={bestModelMetrics}
                    selectedRuns={props.models.filter((x) => x != null && x.activeRunID).map((x) => x.activeRunID!)}
                    selectedMetrics={selectedMetrics}
                    runToNumberMap={runToNumberMap}
                    isHyperopt={true}
                />
            </div>
        </div>
    );
}

export default ModelCompareLearningCurves;
