import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import React, { useState } from "react";
import { Link } from "react-router-dom";
import { useRecoilState } from "recoil";
import { Loader, MenuItem, Message, Popup, Tab, TabProps } from "semantic-ui-react";
import InfoMessage from "../../components/InfoMessage";
import { track } from "../../metrics/june";
import metrics from "../../metrics/metrics";
import { USER_STATE } from "../../state/global";
import { ModelStatus } from "../../types/model/modelStatus";
import { modelStatusValueMap } from "../../utils/sort";
import {
    calibrationPlotsDisabledCriteria,
    checkIfHyperoptEnabled,
    confusionMatrixDisabledCriteria,
    f1DisabledCriteria,
    featureImportanceDisabledCriteria,
    isFineTunedModel,
    isLLMModel,
    isTrainingStatus,
    prCurveDisabledCriteria,
    rocDisabledCriteria,
} from "../util";
import CalibrationPlotsViewer from "../viz/CalibrationPlotsViewer";
import ConfusionMatrix from "../viz/ConfusionMatrixViewer";
import EvalStatsViewer from "../viz/EvalStatsViewer";
import F1Viewer from "../viz/F1Viewer";
import FeatureImportanceViewer from "../viz/FeatureImportanceViewer";
import HyperoptReportViewer from "../viz/HyperoptReportViewer";
import ModelComputeViewer from "../viz/ModelComputeViewer";
import ModelGraphViewer from "../viz/ModelGraphViewer";
import PRCurveViewer from "../viz/PRCurveViewer";
import ROCViewer from "../viz/ROCViewer";
import { VizType } from "../viz/util";
import ModelConfigViewer from "./ModelConfigViewer";
import ModelMetricsGraph from "./modelmetrics/ModelMetricsGraph";
import ModelMetricsTable from "./modelmetrics/ModelMetricsTable";

/**
 * Wraps tab content in a `Tab.Pane` carrying the shared `model-tab` class plus `metrics.BLOCK_AUTO_CAPTURE`
 * (presumably opts the pane out of automatic metrics capture — confirm against the metrics module).
 */
const renderModelTabPane = (key: string, content: React.ReactNode) => (
    <Tab.Pane className={`model-tab ${metrics.BLOCK_AUTO_CAPTURE}`} key={key}>
        {content}
    </Tab.Pane>
);

/** Like `renderModelTabPane`, but places the content inside a container with a 500px minimum height. */
const renderMinHeightTabPane = (key: string, content: React.ReactNode) =>
    renderModelTabPane(key, <div style={{ minHeight: "500px" }}>{content}</div>);

/** Config tab: renders the model's configuration. */
const ModelConfigViewerPane = (model: Model) =>
    renderModelTabPane("modelConfigViewer", <ModelConfigViewer config={model.config} />);

/** Graph tab. */
const ModelGraphViewerPane = (model: Model) =>
    renderMinHeightTabPane("modelGraphViewer", <ModelGraphViewer model={model} />);

/** Compute tab. */
const ModelComputeViewerPane = (model: Model) =>
    renderMinHeightTabPane("modelComputeViewer", <ModelComputeViewer model={model} />);

/** Logs tab: shows `model.errorText` in a negative message, preserving its whitespace and line breaks. */
const LogsViewerPane = (model: Model) =>
    renderModelTabPane(
        "logsViewerPane",
        <Message negative>
            <Message.Header>Error in model training</Message.Header>
            <div style={{ whiteSpace: "pre-wrap" }}>{model.errorText}</div>
        </Message>,
    );

// Evaluation / visualization tabs. Each forwards an optional error message (from the relevant pipeline step)
// to its viewer.

const ConfusionMatrixViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane("confusionMatrixViewer", <ConfusionMatrix model={model} errorMessage={errorMessage} />);

const RocViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane("rocViewer", <ROCViewer model={model} errorMessage={errorMessage} />);

const PRCurveViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane("prCurveViewer", <PRCurveViewer model={model} errorMessage={errorMessage} />);

const F1ViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane("f1Viewer", <F1Viewer model={model} errorMessage={errorMessage} />);

const HyperoptReportViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane(
        "hyperoptReportViewer",
        <HyperoptReportViewer model={model} errorMessage={errorMessage} />,
    );

const FeatureImportancePane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane(
        "featureImportance",
        <FeatureImportanceViewer model={model} errorMessage={errorMessage} />,
    );

const CalibrationPlotsViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane(
        "calibrationPlots",
        <CalibrationPlotsViewer model={model} errorMessage={errorMessage} />,
    );

const EvalStatsViewerPane = (model: Model, errorMessage?: string) =>
    renderMinHeightTabPane("allMetrics", <EvalStatsViewer model={model} errorMessage={errorMessage} />);

/**
 * Renders a disabled tab menu item labelled `name`, with `reason` shown in a popup attached to the label.
 */
const renderDisabledMenuItem = (name: string, reason: React.ReactNode) => (
    <MenuItem disabled key={name}>
        <Popup className={"transition-scale"} position={"top center"} content={reason} trigger={<span>{name}</span>} />
    </MenuItem>
);

/**
 * Menu item for a visualization tab.
 *
 * @param name - Tab label.
 * @param model - Model whose status gates the tab.
 * @param disabledMsg - When truthy, the tab is disabled and this explains why.
 * @param disabledStep - Pipeline step the model must be past before the tab is enabled
 *     (defaults to `ModelStatus.VISUALIZING`).
 * @returns A disabled item (with an explanatory popup) while the model has not moved past `disabledStep`
 *     per `modelStatusValueMap`, or when `disabledMsg` is set; otherwise the plain `name` string.
 */
const VizMenuItem = (name: string, model: Model, disabledMsg?: React.ReactNode, disabledStep?: ModelStatus) => {
    const vizStep = disabledStep ?? ModelStatus.VISUALIZING;
    if (modelStatusValueMap[model.status] <= modelStatusValueMap[vizStep]) {
        return renderDisabledMenuItem(name, `Visualizations are not available until model has finished ${vizStep}.`);
    }
    if (disabledMsg) {
        return renderDisabledMenuItem(name, disabledMsg);
    }
    return name;
};

/**
 * Menu item for an evaluation tab.
 *
 * Disabled (with an explanatory popup) until the model has moved past `ModelStatus.EVALUATING` per
 * `modelStatusValueMap`, or whenever `disabledMsg` is truthy; otherwise returns the plain `name` string.
 */
const EvalMenuItem = (name: string, model: Model, disabledMsg?: React.ReactNode) => {
    if (modelStatusValueMap[model.status] <= modelStatusValueMap[ModelStatus.EVALUATING]) {
        return renderDisabledMenuItem(
            name,
            "Evaluation statistics are not available until model has finished evaluation.",
        );
    }
    if (disabledMsg) {
        return renderDisabledMenuItem(name, disabledMsg);
    }
    return name;
};

/**
 * Chooses which tab ModelVersionSubmenu opens on for a model in `status`.
 *
 * The indices point into the non-LLM pane layouts that ModelVersionSubmenu builds:
 * - 0 is Config, for stopping/canceled models;
 * - 4 is Logs in the failed-model layout;
 * - 2 is Learning Curves in the default layout.
 */
const getTabIndexForStatus = (status: string) => {
    switch (status) {
        case ModelStatus.STOPPING:
        case ModelStatus.CANCELED:
            return 0;
        case ModelStatus.FAILED:
            return 4;
        default:
            return 2;
    }
};

/**
 * Builds the training-progress text for the steps banner.
 *
 * @returns `""` when any count is missing. Otherwise, while `currentStep` is within the first epoch
 *     (`0 <= currentStep <= stepsPerEpoch`), a "waiting for first checkpoint" message. Otherwise a plain
 *     `Step current / total` string.
 */
const stepsPerMessageText = (currentStep?: number, stepsPerEpoch?: number, totalSteps?: number) => {
    const hasAllCounts = currentStep !== undefined && stepsPerEpoch !== undefined && totalSteps !== undefined;
    if (!hasAllCounts) {
        return "";
    }

    const withinFirstEpoch = 0 <= currentStep && currentStep <= stepsPerEpoch;
    return withinFirstEpoch
        ? `Waiting for model metrics... Step ${currentStep} / ${stepsPerEpoch} for first checkpoint.`
        : `Step ${currentStep} / ${totalSteps}`;
};

/**
 * Training-progress banner (rendered at the top of the Learning Curves pane).
 *
 * Renders nothing unless `isTraining` is true and `currentStep`, `stepsPerEpoch` and `totalSteps` are all
 * known. The header includes the model name when one is available.
 */
const StepsMessage = (props: {
    isTraining: boolean;
    currentStep?: number;
    stepsPerEpoch?: number;
    totalSteps?: number;
    modelName?: string;
}) => {
    const { isTraining, currentStep, stepsPerEpoch, totalSteps, modelName } = props;

    // Missing step counts are handled below: stepsPerMessageText returns "" if any of them is undefined.
    if (!isTraining) {
        return null;
    }

    const stepsPerMessage = stepsPerMessageText(currentStep, stepsPerEpoch, totalSteps);
    if (!stepsPerMessage) {
        return null;
    }

    // modelName is optional (callers may pass an undefined repo name); don't render "Model [undefined]".
    const header = modelName ? `Model [${modelName}] is training` : "Model is training";

    return (
        <InfoMessage
            header={header}
            infoMessage={
                <React.Fragment>
                    <br />
                    <Loader active inline size={"small"} />
                    &nbsp;
                    {stepsPerMessage}
                </React.Fragment>
            }
        />
    );
};

/**
 * Parses the error message recorded for the VISUALIZING step.
 *
 * The message is either a JSON object mapping a visualization type to its own error, or a single error that
 * applies to every visualization. Returns the map in the first case. Returns `{}` otherwise, including for
 * valid JSON that is not a plain object (`null`, numbers, arrays, quoted strings). Those values would either
 * throw in `Object.keys` or mask the real error behind meaningless keys.
 */
const parseVizErrorMap = (vizErrorMessage?: string | null): Record<string, string> => {
    if (!vizErrorMessage) {
        return {};
    }
    try {
        const parsed: unknown = JSON.parse(vizErrorMessage);
        if (parsed !== null && typeof parsed === "object" && !Array.isArray(parsed)) {
            return parsed as Record<string, string>;
        }
    } catch {
        // Not JSON: a single plain-text error, which callers use as-is when the map is empty.
    }
    return {};
};

/**
 * Tabbed detail view for a single model version.
 *
 * The tab layout depends on the model:
 * - LLM models: Config, Graph, Learning Curves (fine-tuned only), Compute, and Logs (when `errorText` is set).
 * - Stopping/canceled models: Config, Graph, Learning Curves (when metrics exist), Compute.
 * - Failed models: the same tabs as stopping/canceled, plus Logs.
 * - All other models: the evaluation/visualization tabs, an optional Hyperopt tab, All Metrics, Compute,
 *   and a "Compare" link.
 */
function ModelVersionSubmenu(props: {
    modelVersion: Model;
    inProgress: boolean;
    metricHistory: MetricHistory;
    selectedMetrics: string[];
    handleMetricClick: (metricKey: string) => void;
    batchHandleMetricClick: (metricKeys: string[], force?: boolean) => void;
    specificMetrics: SpecificModelMetrics;
    bestModelMetrics?: BestModelMetrics;
    selectedRuns: string[];
    experimentRuns: ModelRun[];
    runToNumberMap: React.MutableRefObject<Record<string, number>>;
    timeline?: ModelTimeline;
    modelSteps: ModelSteps;
    llmSampleOutputs?: LlmSampleOutput[];
}) {
    // Lazy initializer: the initial tab only depends on the status at mount time.
    const [tabIndex, setTabIndex] = useState(() => getTabIndexForStatus(props.modelVersion.status));
    const [user] = useRecoilState(USER_STATE);
    const isHyperopt = checkIfHyperoptEnabled(props.modelVersion.config);
    let showCompareLink = false;

    let learningCurves: JSX.Element | null = null;
    if (props.specificMetrics && Object.keys(props.specificMetrics).length > 0) {
        learningCurves = (
            <div style={{ display: "flex", flexDirection: "row", justifyContent: "flex-start" }}>
                <div style={{ display: "flex" }}>
                    <ModelMetricsTable
                        metricHistory={props.metricHistory}
                        metrics={props.specificMetrics}
                        bestModelMetrics={props.bestModelMetrics}
                        selectedMetrics={props.selectedMetrics}
                        handleMetricClick={props.handleMetricClick}
                        batchHandleMetricClick={props.batchHandleMetricClick}
                    />
                </div>
                <div style={{ display: "flex", width: "100%", paddingLeft: `${16 / 14}rem`, maxHeight: "40vh" }}>
                    <ModelMetricsGraph
                        metricHistory={props.metricHistory}
                        models={[props.modelVersion]}
                        isHyperopt={
                            /**
                             * For streaming, the key in metricHistory will be "default", but the key in selectedMetrics
                             * will be the actual trial uuid.
                             */
                            props.experimentRuns?.length > 1 ||
                            (props.metricHistory &&
                                !props.metricHistory.default &&
                                Object.keys(props.metricHistory).length > 1)
                        }
                        selectedRuns={props.selectedRuns}
                        selectedMetrics={props.selectedMetrics}
                        runToNumberMap={props.runToNumberMap}
                        bestModelMetrics={props.bestModelMetrics}
                    />
                </div>
            </div>
        );
    }

    const learningCurvesPane = (
        <Tab.Pane className={`model-tab ${metrics.BLOCK_AUTO_CAPTURE}`} key={"learningCurves"}>
            <StepsMessage
                isTraining={isTrainingStatus(props.modelVersion.status) && props.inProgress}
                currentStep={props.modelSteps.steps}
                stepsPerEpoch={props.modelSteps.stepsPerEpoch}
                totalSteps={props.modelSteps.totalSteps}
                modelName={props.modelVersion?.repo?.modelName}
            />
            {learningCurves}
        </Tab.Pane>
    );

    const vizErrorMessage = props.timeline?.[ModelStatus.VISUALIZING]?.errorMessage;
    const evalErrorMessage = props.timeline?.[ModelStatus.EVALUATING]?.errorMessage;
    const explainErrorMessage = props.timeline?.[ModelStatus.EXPLAINING]?.errorMessage;

    // The visualization error is either a map of viz_type -> error or a single error shared by every
    // visualization. When no per-type map is available, every visualization gets the raw message.
    const vizErrorMap = parseVizErrorMap(vizErrorMessage);
    const hasPerVizErrors = Object.keys(vizErrorMap).length > 0;
    const getVizError = (vizType: VizType) => (hasPerVizErrors ? vizErrorMap[vizType] : vizErrorMessage);

    // Entries may be falsy (the `condition && {...}` items below); see the safeTabIndex note further down.
    let panes: any[] = [];
    if (isLLMModel(props.modelVersion.config)) {
        panes = [
            { key: "modelConfigViewer", menuItem: "Config", render: () => ModelConfigViewerPane(props.modelVersion) },
            { key: "modelGraphViewer", menuItem: "Graph", render: () => ModelGraphViewerPane(props.modelVersion) },
            isFineTunedModel(props.modelVersion.config) && {
                key: "learningCurves",
                menuItem: "Learning Curves",
                render: () => learningCurvesPane,
            },
            {
                key: "modelComputeViewer",
                menuItem: "Compute",
                render: () => ModelComputeViewerPane(props.modelVersion),
            },
            props.modelVersion.errorText && {
                key: "logsViewer",
                menuItem: "Logs",
                render: () => LogsViewerPane(props.modelVersion),
            },
        ];
    } else if (
        props.modelVersion.status === ModelStatus.STOPPING ||
        props.modelVersion.status === ModelStatus.CANCELED
    ) {
        panes = [
            { key: "modelConfigViewer", menuItem: "Config", render: () => ModelConfigViewerPane(props.modelVersion) },
            { key: "modelGraphViewer", menuItem: "Graph", render: () => ModelGraphViewerPane(props.modelVersion) },
            learningCurves && { key: "learningCurves", menuItem: "Learning Curves", render: () => learningCurvesPane },
            {
                key: "modelComputeViewer",
                menuItem: "Compute",
                render: () => ModelComputeViewerPane(props.modelVersion),
            },
        ];
    } else if (props.modelVersion.status === ModelStatus.FAILED) {
        panes = [
            { key: "modelConfigViewer", menuItem: "Config", render: () => ModelConfigViewerPane(props.modelVersion) },
            { key: "modelGraphViewer", menuItem: "Graph", render: () => ModelGraphViewerPane(props.modelVersion) },
            learningCurves && { key: "learningCurves", menuItem: "Learning Curves", render: () => learningCurvesPane },
            {
                key: "modelComputeViewer",
                menuItem: "Compute",
                render: () => ModelComputeViewerPane(props.modelVersion),
            },
            { key: "logsViewer", menuItem: "Logs", render: () => LogsViewerPane(props.modelVersion) },
        ];
    } else {
        panes = [
            { key: "modelConfigViewer", menuItem: "Config", render: () => ModelConfigViewerPane(props.modelVersion) },
            { key: "modelGraphViewer", menuItem: "Graph", render: () => ModelGraphViewerPane(props.modelVersion) },
            { key: "learningCurves", menuItem: "Learning Curves", render: () => learningCurvesPane },
            {
                key: "confusionMatrixViewer",
                menuItem: VizMenuItem(
                    "Confusion Matrix",
                    props.modelVersion,
                    confusionMatrixDisabledCriteria(props.modelVersion),
                ),
                render: () => ConfusionMatrixViewerPane(props.modelVersion, getVizError(VizType.CONFUSION_MATRIX)),
            },
            {
                key: "rocViewer",
                menuItem: VizMenuItem("ROC Curve", props.modelVersion, rocDisabledCriteria(props.modelVersion)),
                render: () => RocViewerPane(props.modelVersion, getVizError(VizType.ROC_CURVES_FROM_TEST_STATISTICS)),
            },
            {
                key: "prCurveViewer",
                menuItem: VizMenuItem(
                    "Precision Recall Curve",
                    props.modelVersion,
                    prCurveDisabledCriteria(props.modelVersion),
                ),
                render: () =>
                    PRCurveViewerPane(
                        props.modelVersion,
                        getVizError(VizType.PRECISION_RECALL_CURVES_FROM_TEST_STATISTICS),
                    ),
            },
            {
                key: "f1Viewer",
                menuItem: VizMenuItem("F1", props.modelVersion, f1DisabledCriteria(props.modelVersion)),
                render: () => F1ViewerPane(props.modelVersion, getVizError(VizType.FREQUENCY_VS_F1)),
            },
            {
                key: "calibrationPlotsViewer",
                menuItem: VizMenuItem(
                    "Calibration Plots",
                    props.modelVersion,
                    calibrationPlotsDisabledCriteria(props.modelVersion),
                ),
                render: () => CalibrationPlotsViewerPane(props.modelVersion, getVizError(VizType.CALIBRATION_1_VS_ALL)),
            },
            {
                key: "featureImportanceViewer",
                menuItem: VizMenuItem(
                    "Feature Importance",
                    props.modelVersion,
                    featureImportanceDisabledCriteria(props.modelVersion),
                    ModelStatus.EXPLAINING,
                ),
                render: () => FeatureImportancePane(props.modelVersion, explainErrorMessage),
            },
        ];

        showCompareLink = true;

        // Add hyperopt tab before All Metrics tab
        if (isHyperopt) {
            panes = [
                ...panes,
                {
                    key: "hyperoptReportViewer",
                    menuItem: VizMenuItem("Hyperopt", props.modelVersion),
                    render: () => HyperoptReportViewerPane(props.modelVersion),
                },
            ];
        }

        panes = [
            ...panes,
            {
                key: "evalStatsViewer",
                menuItem: EvalMenuItem("All Metrics", props.modelVersion),
                render: () => EvalStatsViewerPane(props.modelVersion, evalErrorMessage),
            },
            // Always add the compute tab last
            {
                key: "modelComputeViewer",
                menuItem: "Compute",
                render: () => ModelComputeViewerPane(props.modelVersion),
            },
        ];
    }

    // Falsy entries in `panes` are not real tabs, but they still occupy an index: getTabIndexForStatus's
    // hard-coded 4 for the failed-model Logs tab assumes this. The stored index can land on such a placeholder
    // (e.g. index 2 for a non-fine-tuned LLM). It can also run past the end of `panes` after the layout shrinks
    // (e.g. the model fails while a later tab is selected). In both cases, fall back to the first real pane
    // instead of showing an empty tab.
    const safeTabIndex = panes[tabIndex] ? tabIndex : Math.max(0, panes.findIndex((pane) => Boolean(pane)));

    const onTabChange = (_e: React.MouseEvent<HTMLDivElement>, { activeIndex }: TabProps) => {
        setTabIndex(Number(activeIndex));
    };

    return (
        <div style={{ display: "flex" }}>
            <Tab
                style={{ position: "relative" }}
                menu={{ secondary: true, pointing: true }}
                panes={panes}
                activeIndex={safeTabIndex}
                onTabChange={onTabChange}
            />
            {showCompareLink && (
                <div style={{ float: "right", position: "absolute", right: 20, marginTop: "8px" }}>
                    <Link
                        to={`/models/repo/${props.modelVersion.repoID}/compare/${props.modelVersion.id}`}
                        className={"black-link"}
                        onClick={() => {
                            metrics.capture("model_compare", {
                                model_a: props.modelVersion.repoID,
                                model_b: props.modelVersion.id,
                            });
                            user && track(user, "model_compare");
                        }}
                    >
                        <b>
                            Compare&ensp;
                            <FontAwesomeIcon icon={"angle-right"} />
                        </b>
                    </Link>
                </div>
            )}
        </div>
    );
}

export default ModelVersionSubmenu;
