import React from "react";
import { Tab } from "semantic-ui-react";
import ModelMetricsBarGraph from "./metricsgraphs/ModelMetricsBarGraph";
import ModelMetricsLineGraph from "./metricsgraphs/ModelMetricsLineGraph";

/**
 * Renders the metric history of the selected runs as a tabbed "Line" / "Bar" chart.
 *
 * In hyperopt mode the history of every run is merged into one chart-data array whose
 * metric keys are prefixed with the run's trial number from `runToNumberMap`
 * (e.g. `2.train_metrics.combined.loss`). Otherwise the chart data is
 * `metricHistory.default`, falling back to the history of the first selected run.
 *
 * Returns `null` when there is no metric history or no run is selected.
 */
function ModelMetricsGraph(props: {
    models: Model[];
    selectedRuns: string[];
    selectedMetrics: string[];
    metricHistory: MetricHistory;
    bestModelMetrics?: BestModelMetrics;
    runToNumberMap: React.MutableRefObject<Record<string, number>>;
    isHyperopt: boolean;
}) {
    /**
     * Prefixes `metricName` with `<trial number>.` for the given run.
     * Runs without an entry in `runToNumberMap` get no prefix.
     */
    const getHyperoptMetricKey = (runID: string, metricName: string): string => {
        const runNumber = props.runToNumberMap.current[runID];
        // Test for presence rather than truthiness: a trial numbered 0 still needs its prefix.
        const runPrefix = runNumber != null ? `${runNumber}.` : "";
        return runPrefix + metricName;
    };

    /**
     * Builds the chart-data key for one metric of one run.
     *
     * @param runID - run whose metric is addressed (only used in hyperopt mode).
     * @param metricName - full metric name, e.g. `train_metrics.combined.loss`.
     * @param shortened - when true, the first occurrence of `train_metrics.` /
     *   `validation_metrics.` / `test_metrics.` is abbreviated to `train.` / `vali.` / `test.`.
     * @returns the (optionally shortened) name, prefixed with the trial number in hyperopt mode.
     */
    const getMetricKey = (runID: string, metricName: string, shortened = false): string => {
        const name = shortened
            ? metricName
                  .replace("train_metrics.", "train.")
                  .replace("validation_metrics.", "vali.")
                  .replace("test_metrics.", "test.")
            : metricName;
        return props.isHyperopt ? getHyperoptMetricKey(runID, name) : name;
    };

    /**
     * `props.metricHistory` maps runID -> per-epoch metrics, but the charting components take a
     * single epoch-indexed array. Merge every run present in `metricHistory` (not only the
     * selected ones) into that array, keeping runs apart by prefixing metric keys with the
     * trial number:
     * [{ "1.train_metrics.combined.loss": …, "2.train_metrics.combined.loss": … }, …]
     *
     * The result is as long as the longest run; entry i holds the metrics of every run that
     * has data for epoch i.
     */
    const convertMetricHistoryHyperopt = () => {
        const maxEpochs = Math.max(0, ...Object.values(props.metricHistory).map((history) => history.length));
        const graphHistory = Array(maxEpochs).fill(null);

        for (const [runID, runHistory] of Object.entries(props.metricHistory)) {
            runHistory.forEach((epochMetrics, epoch) => {
                if (graphHistory[epoch] === null) {
                    graphHistory[epoch] = {};
                }
                for (const [metricKey, metricValue] of Object.entries(epochMetrics)) {
                    graphHistory[epoch][getHyperoptMetricKey(runID, metricKey)] = metricValue;
                }
            });
        }
        return graphHistory;
    };

    if (!props.metricHistory || props.selectedRuns === undefined || props.selectedRuns.length === 0) {
        return null;
    }

    const chartData: ModelChartHistory = props.isHyperopt
        ? convertMetricHistoryHyperopt()
        : props.metricHistory.default ?? props.metricHistory[props.selectedRuns[0]];

    const panes = [
        {
            key: "line",
            menuItem: "Line",
            render: () => (
                <Tab.Pane>
                    <ModelMetricsLineGraph
                        isHyperopt={props.isHyperopt}
                        models={props.models}
                        selectedRuns={props.selectedRuns}
                        selectedMetrics={props.selectedMetrics}
                        chartData={chartData}
                        bestModelMetrics={props.bestModelMetrics}
                        runToNumberMap={props.runToNumberMap}
                        getMetricKey={getMetricKey}
                    />
                </Tab.Pane>
            ),
        },
        {
            key: "bar",
            menuItem: "Bar",
            render: () => (
                <Tab.Pane>
                    <ModelMetricsBarGraph
                        isHyperopt={props.isHyperopt}
                        selectedRuns={props.selectedRuns}
                        selectedMetrics={props.selectedMetrics}
                        runToNumberMap={props.runToNumberMap}
                        metricHistory={props.metricHistory}
                        chartData={chartData}
                        getMetricKey={getMetricKey}
                    />
                </Tab.Pane>
            ),
        },
    ];

    return <Tab panes={panes} style={{ width: "100%", height: "100%" }} menu={{ secondary: true, pointing: true }} />;
}

export default ModelMetricsGraph;
