import React from "react";
import { Bar, BarChart, Legend, ResponsiveContainer, Tooltip, XAxis, YAxis } from "recharts";
import { formatValueToNumericString } from "../../../../utils/numbers";
import { metricsRegex } from "../util";

/** A single metric value read from a metric-history entry (undefined when that metric was not logged). */
type MetricValue = number | string | undefined;

/** One x-axis category of the bar chart: a run/target/metric key plus its train/validation/test values. */
type MetricBarDatum = {
    name: string;
    train?: MetricValue;
    vali?: MetricValue;
    test?: MetricValue;
};

/**
 * Maps the metric-type prefix captured by `metricsRegex` to the bar series it populates.
 * Any other prefix yields `undefined`, meaning no series is set for that metric.
 */
function seriesForMetricType(metricType: string | undefined): "train" | "vali" | "test" | undefined {
    switch (metricType) {
        case "train_metrics":
            return "train";
        case "validation_metrics":
            return "vali";
        case "test_metrics":
            return "test";
        default:
            return undefined;
    }
}

/** Returns the last element of `series`, or `undefined` when the series is missing or empty. */
function lastEntry<T>(series: readonly T[] | null | undefined): T | undefined {
    if (series == null || series.length === 0) {
        return undefined;
    }
    return series[series.length - 1];
}

/**
 * Picks the most recent metrics entry recorded for `run`.
 *
 * For streaming, the key in `metricHistory` will be "default" while the selected run ID may be the
 * actual trial uuid, so outside hyperopt a run with no history of its own falls back to "default".
 * In hyperopt mode there is no fallback.
 *
 * @returns the last history entry, or `undefined` when no non-empty history is available.
 */
function latestMetricsForRun(
    metricHistory: MetricHistory,
    run: string,
    isHyperopt: boolean,
): Record<string, number | string> | undefined {
    const runHistory = metricHistory[run];
    if (runHistory !== undefined) {
        return lastEntry(runHistory);
    }
    return isHyperopt ? undefined : lastEntry(metricHistory.default);
}

/**
 * Bar chart of the latest train / validation / test value of each selected metric for each selected run.
 *
 * Each x-axis category is keyed by `getMetricKey(run, "<target>.<metricName>")`, where target and
 * metric name come from parsing the selected metric with `metricsRegex`. Its bars are the values
 * from the last `metricHistory` entry for that run. `chartData` is used here only as a gate: when
 * it is missing or empty, nothing is rendered. Runs with no usable history are skipped instead of
 * throwing.
 */
function ModelMetricsBarGraph(props: {
    isHyperopt: boolean;
    selectedRuns: string[];
    selectedMetrics: string[];
    runToNumberMap: React.MutableRefObject<Record<string, number>>;
    metricHistory: MetricHistory;
    chartData: ModelChartHistory;
    getMetricKey: (runID: string, metricName: string) => string;
}) {
    if (!props.chartData || props.chartData.length === 0) {
        return null;
    }

    // A Map (rather than a plain object) avoids collisions with Object.prototype keys such as "constructor".
    const rowsByKey = new Map<string, MetricBarDatum>();
    props.selectedRuns?.forEach((run) => {
        const lastMetrics = latestMetricsForRun(props.metricHistory, run, props.isHyperopt);
        if (lastMetrics === undefined) {
            // Unknown hyperopt trial, or a run whose history is missing/empty: nothing to plot.
            return;
        }

        props.selectedMetrics?.forEach((metric) => {
            // NOTE(review): if metricsRegex carries the g or y flag, exec() is stateful via lastIndex
            // and repeated calls may fail to match — confirm it does not.
            const match = metricsRegex.exec(metric);
            if (match === null) {
                return;
            }
            const [, metricType, target, metricName] = match;
            const key = props.getMetricKey(run, target + "." + metricName);

            let row = rowsByKey.get(key);
            if (row === undefined) {
                row = { name: key };
                rowsByKey.set(key, row);
            }

            const series = seriesForMetricType(metricType);
            if (series !== undefined) {
                row[series] = lastMetrics[metric];
            }
        });
    });

    const data = Array.from(rowsByKey.values());

    // https://github.com/recharts/recharts/issues/172#issuecomment-307858843
    // The container width is deliberately just under 100%; presumably this is the workaround from the linked issue.
    // NOTE(review): `overflow: "none"` is not a valid CSS overflow value, so browsers ignore it —
    // confirm the intended value (e.g. "visible") before changing it.
    return (
        <ResponsiveContainer width={"99.8%"} aspect={2}>
            <BarChart data={data} style={{ overflow: "none" }}>
                <XAxis dataKey={"name"} angle={-30} textAnchor="end" interval={0} />
                <YAxis tickFormatter={(value) => formatValueToNumericString(value, 2)} />
                <Tooltip />
                <Legend wrapperStyle={{ bottom: -70, paddingBottom: "70px" }} />
                <Bar dataKey="train" fill="#a83045" />
                <Bar dataKey="vali" fill="#f57c0e" />
                <Bar dataKey="test" fill="#4787c5" />
            </BarChart>
        </ResponsiveContainer>
    );
}

export default ModelMetricsBarGraph;
