import React, { useEffect, useState } from "react";
import { Link } from "react-router-dom";
import { Checkbox, Divider, Header, Popup, Table } from "semantic-ui-react";
import FormattedNumberTooltip from "../../components/FormattedNumberTooltip";
import GrayDash from "../../components/GrayDash";
import metrics from "../../metrics/metrics";
import { cssTruncate } from "../../utils/overflow";
import { alphabeticalCollator } from "../../utils/sort";
import { formatMetricName, generalMetricsKeys, isGBMModel, isTrainValTestMetric } from "../util";

/**
 * Metric names that never get their own comparison column, even when a model
 * reports them (checked against `metricName`, not `runMetricName`).
 */
const ignoreMetricNames: readonly string[] = [
    "last_improvement_steps",
    "tune_checkpoint_num",
];

// NOTE(review): React inline styles cannot express `!important`; the browser most
// likely rejects this value, so the border probably comes from CSS (if anywhere).
// Kept as-is to avoid changing rendering — confirm and remove or move to a class.
const generalMetricCellStyle: React.CSSProperties = { borderRight: "blue !important" };

/**
 * Renders the metric cells for one model row.
 *
 * The cells come out in this order: one cell per `generalMetricsKeys` entry, then
 * one per `specificMetricKeys` entry. Each cell shows the value of the model
 * metric whose `runMetricName` matches the column key, or an empty cell when the
 * model has no such metric. GBM models show a gray dash in the `batch_size`
 * column, because the `isGBMModel(model.config)` branch special-cases it.
 *
 * @param model - The model version whose `modelMetrics` populate the row.
 * @param specificMetricKeys - Ordered `runMetricName`s of the extra columns; must
 *   match the header columns rendered by the table.
 * @returns A fragment containing exactly one cell per column.
 */
const getMetricsCells = (model: Model, specificMetricKeys: string[]) => {
    // Index metrics once instead of rebuilding and scanning the list for every column.
    // Only the first metric per runMetricName is kept, matching the previous `indexOf` lookup.
    const metricByRunName = new Map<string, NonNullable<Model["modelMetrics"]>[number]>();
    for (const metric of model.modelMetrics ?? []) {
        if (!metricByRunName.has(metric.runMetricName)) {
            metricByRunName.set(metric.runMetricName, metric);
        }
    }

    const isGBM = isGBMModel(model.config);

    // Every cell returned from a `.map` carries a key, including the empty fallback.
    const renderValueCell = (key: string, style?: React.CSSProperties) => {
        const metric = metricByRunName.get(key);
        if (metric === undefined) {
            return <Table.Cell key={key} />;
        }
        return (
            <Table.Cell key={key} style={style}>
                <FormattedNumberTooltip value={metric.metricValue} />
            </Table.Cell>
        );
    };

    return (
        <>
            {generalMetricsKeys.map((key) =>
                isGBM && key === "batch_size" ? (
                    <Table.Cell key={key} style={generalMetricCellStyle}>
                        <GrayDash />
                    </Table.Cell>
                ) : (
                    renderValueCell(key, generalMetricCellStyle)
                ),
            )}
            {specificMetricKeys.map((key) => renderValueCell(key))}
        </>
    );
};

/**
 * Collects the `runMetricName`s used as model-specific columns in the compare table.
 *
 * A metric is included when all of the following hold:
 * - its `metricName` is not one of the general columns;
 * - its `metricName` does not start with `num_`;
 * - its `metricName` is not in `ignoreMetricNames`;
 * - its `runMetricName` starts with `best.`.
 *
 * The result is deduplicated across all model versions and sorted with
 * `alphabeticalCollator`.
 *
 * @param modelVersions - The model versions shown in the table.
 * @returns Sorted, unique `runMetricName`s.
 */
const collectSpecificMetricKeys = (modelVersions: Model[]): string[] => {
    const keys = new Set<string>();
    for (const model of modelVersions) {
        for (const metric of model.modelMetrics ?? []) {
            if (
                !generalMetricsKeys.includes(metric.metricName) &&
                !metric.metricName.startsWith("num_") &&
                !ignoreMetricNames.includes(metric.metricName) &&
                metric.runMetricName.startsWith("best.")
            ) {
                keys.add(metric.runMetricName);
            }
        }
    }

    // Omit the train/val/test prefix when sorting (i.e. <train_metrics.Survived.accuracy>).
    // NOTE(review): every key here starts with "best.", so slicing at the first "." strips
    // "best." rather than a train/val/test prefix — confirm against isTrainValTestMetric.
    const sortKey = (key: string) => (isTrainValTestMetric(key) ? key.slice(key.indexOf(".") + 1) : key);
    return Array.from(keys).sort((a, b) => alphabeticalCollator.compare(sortKey(a), sortKey(b)));
};

/**
 * Table comparing model versions of a repo.
 *
 * The columns are: the model (checkbox plus version/description links), the
 * general metrics, then the model-specific "best." metrics. When there are no
 * model versions, an empty-state row links to training the first model.
 *
 * @param props.modelVersions - Rows to display.
 * @param props.repo - Repo used to build the "train your first model" link.
 * @param props.checkedModels - Currently selected models; a row is checked when its id is present.
 * @param props.setCheckedModels - Setter updated when a row checkbox is toggled.
 * @param props.disabledCriteria - Optional. When it returns a truthy node for an unchecked
 *   model, that model's checkbox is disabled and the node is shown in a popup.
 */
function ModelCompareTable(props: {
    modelVersions: Model[];
    repo: ModelRepo;
    checkedModels: (Model | null)[];
    setCheckedModels: React.Dispatch<React.SetStateAction<(Model | null)[]>>;
    disabledCriteria?: (model: Model) => React.ReactNode;
}) {
    const [specificMetricKeys, setSpecificMetricKeys] = useState<string[]>();

    useEffect(() => {
        setSpecificMetricKeys(collectSpecificMetricKeys(props.modelVersions));
    }, [props.modelVersions]);

    // Model column + general metric columns + model-specific metric columns; keeps the
    // empty-state cell spanning the real header width instead of a hard-coded count.
    const columnCount = 1 + generalMetricsKeys.length + (specificMetricKeys?.length ?? 0);

    const captureVersionClick = (model: Model) => () =>
        metrics.captureClick("Link.ModelVersion", {
            value: model.id,
        });

    return (
        <div style={{ width: "100%", overflowX: "auto", maxHeight: "50vh", overflowY: "auto" }}>
            <Table selectable={props.modelVersions.length > 0} className={"model-compare-table-freeze"}>
                <Table.Header>
                    <Table.Row>
                        <Table.HeaderCell>Model</Table.HeaderCell>
                        {generalMetricsKeys.map((x) => (
                            <Table.HeaderCell key={x}>{x}</Table.HeaderCell>
                        ))}
                        {specificMetricKeys?.map((x) => (
                            <Table.HeaderCell key={x}>{formatMetricName(x)}</Table.HeaderCell>
                        ))}
                    </Table.Row>
                </Table.Header>
                <Table.Body>
                    {props.modelVersions.length === 0 ? (
                        <Table.Row>
                            <Table.Cell colSpan={columnCount} textAlign={"center"} verticalAlign={"middle"}>
                                <Divider hidden />
                                <img src={"/model/emptyRepos.svg"} alt="" />
                                <Header as="h2" size={"medium"} style={{ marginBottom: "0.5rem" }}>
                                    Looks like you don't have any models yet!
                                </Header>
                                <Link style={{ fontSize: "0.9em" }} to={"/models/edit/train/repo/" + props.repo.id}>
                                    Train your first model version
                                </Link>
                                <Divider hidden />
                            </Table.Cell>
                        </Table.Row>
                    ) : null}
                    {props.modelVersions.map((model) => {
                        const disabled = props.disabledCriteria?.(model);
                        const checked = (props.checkedModels ?? []).some((x) => x?.id === model.id);
                        const onVersionClick = captureVersionClick(model);
                        return (
                            <Table.Row key={model.id}>
                                <Table.Cell
                                    collapsing
                                    verticalAlign={"middle"}
                                    style={{
                                        maxWidth: "20vw",
                                        ...cssTruncate,
                                        display: undefined,
                                        fontWeight: "normal",
                                    }}
                                >
                                    {disabled && !checked ? (
                                        <Popup
                                            position={"left center"}
                                            flowing
                                            className={"transition-scale"}
                                            trigger={
                                                <Checkbox
                                                    style={{ verticalAlign: "middle" }}
                                                    className={"gray-disabled-checkbox"}
                                                    disabled
                                                />
                                            }
                                            content={disabled}
                                        />
                                    ) : (
                                        <Checkbox
                                            style={{ verticalAlign: "middle" }}
                                            checked={checked}
                                            onClick={(event, data) => {
                                                if (data.checked) {
                                                    props.setCheckedModels((old) => [...old, model]);
                                                } else {
                                                    props.setCheckedModels((old) =>
                                                        old.filter((x) => x?.id !== model.id),
                                                    );
                                                }
                                            }}
                                        />
                                    )}
                                    &emsp;
                                    <Link
                                        className={metrics.BLOCK_AUTO_CAPTURE}
                                        onClick={onVersionClick}
                                        to={"/models/version/" + model.id}
                                    >
                                        {model.repoVersion}
                                    </Link>
                                    &emsp;
                                    <Link
                                        className={metrics.BLOCK_AUTO_CAPTURE + " black-link"}
                                        onClick={onVersionClick}
                                        to={"/models/version/" + model.id}
                                    >
                                        {model.description}
                                    </Link>
                                </Table.Cell>
                                {getMetricsCells(model, specificMetricKeys || [])}
                            </Table.Row>
                        );
                    })}
                </Table.Body>
            </Table>
        </div>
    );
}

export default ModelCompareTable;
