import type { JSONSchema7 } from "json-schema";
import { Link } from "react-router-dom";
import { Icon, Label } from "semantic-ui-react";
import { EngineServiceType } from "../types/engineServiceType";
import { ModelStatus } from "../types/model/modelStatus";
import { ModelTypes } from "../types/model/modelTypes";
import { SEMANTIC_BLUE, SEMANTIC_GREY } from "../utils/colors";
import { deletedUser } from "../utils/constants";
import { isLocal } from "../utils/environment";

import ecdSchema from "../assets/ecd-minified.json";
import gbmSchema from "../assets/gbm-minified.json";
import llmSchema from "../assets/llm-minified.json";

/**
 * Keys of the general (non-feature) training metrics, i.e. metrics that are not scoped to a particular
 * input or output feature.
 */
export const generalMetricsKeys: string[] = [
    "epoch",
    "steps",
    "batch_size",
    "learning_rate",
    "best_eval_metric_checkpoint_number",
];

/**
 * Preset model name -> model size, measured in billions of parameters.
 * Built from an object literal; none of the keys are integer-like, so Object.entries keeps insertion order.
 */
const presetModelSizeMap: Map<string, number> = new Map(
    Object.entries({
        "llama-2-7b": 7,
        "llama-2-13b": 13,
        "llama-2-70b": 70,
        "llama-2-7b-chat": 7,
        "llama-2-13b-chat": 13,
        "llama-2-70b-chat": 70,
        "vicuna-7b": 7,
        "vicuna-13b": 13,
        "bloomz-3b": 3,
        "bloomz-7b1": 7,
        "opt-350m": 0.35,
        "opt-1.3b": 1.3,
        "opt-6.7b": 6.7,
        "gpt-neo-2.7B": 2.7,
        "gpt-j-6b": 6,
        "pythia-2.8b": 2.8,
        "pythia-12b": 12,
    }),
);

/** Counts the input features in `config` whose `type` is "text". */
export const getNumTextInputFeaturesForModelConfig = (config: CreateModelConfig): number =>
    config.input_features.reduce((textCount, feature) => (feature.type === "text" ? textCount + 1 : textCount), 0);

/** Counts every input feature in `config`, regardless of its type. */
export const getNumInputFeaturesForModelConfig = (config: CreateModelConfig): number =>
    config.input_features.length;

/**
 * Reports whether `engine` may be used for training.
 *
 * Always true in a local environment. Otherwise BATCH engines qualify, and RAY engines qualify only when
 * `linkedOkay` is set and the engine's status is "linked".
 */
export const engineCanTrainModels = (engine: Engine, linkedOkay = false) => {
    // Local environments have no training engines, so the engine check is skipped there.
    if (isLocal()) {
        return true;
    }
    if (engine.serviceType === EngineServiceType.BATCH) {
        return true;
    }
    return linkedOkay && engine.serviceType === EngineServiceType.RAY && engine.engineStatus === "linked";
};

/** Returns the `name` of each output feature in `config`, or an empty array if there are none. */
export const getConfigOutputFeatureNames = (config?: CreateModelConfig): string[] => {
    const outputFeatures = config?.output_features;
    if (!outputFeatures) {
        return [];
    }
    return outputFeatures.map((feature: any) => feature.name);
};

/**
 * Predicate accepting a model (or a bare status string) whose status is READY, DEPLOYED, DEPLOYING or
 * UNDEPLOYING. Suitable as an `Array.prototype.filter` callback.
 */
export const modelStatusFilter = (x: Model | string) => {
    const currentStatus = typeof x === "string" ? x : x.status;
    switch (currentStatus) {
        case ModelStatus.READY:
        case ModelStatus.DEPLOYED:
        case ModelStatus.DEPLOYING:
        case ModelStatus.UNDEPLOYING:
            return true;
        default:
            return false;
    }
};

/** Predicate that keeps models whose `archived` flag is falsy. */
export const modelArchivedFilter = (x: Model) => !x.archived;

/**
 * True for the statuses this module treats as terminal: READY, DEPLOYED, DEPLOYING, UNDEPLOYING, FAILED
 * and CANCELED. An undefined status is not terminal.
 */
export const isTerminalStatus = (status?: string): boolean => {
    switch (status) {
        case ModelStatus.READY:
        case ModelStatus.DEPLOYED:
        case ModelStatus.DEPLOYING:
        case ModelStatus.UNDEPLOYING:
        case ModelStatus.FAILED:
        case ModelStatus.CANCELED:
            return true;
        default:
            return false;
    }
};

/** A status is cancelable when it is neither terminal (per isTerminalStatus) nor already STOPPING. */
export const isCancelableStatus = (status?: string): boolean =>
    !(isTerminalStatus(status) || status === ModelStatus.STOPPING);

/** True only for the FAILED and CANCELED statuses. */
export const isFailedOrCanceledStatus = (status?: string): boolean => {
    const unsuccessfulStatuses = [ModelStatus.FAILED, ModelStatus.CANCELED];
    return unsuccessfulStatuses.some((candidate) => candidate === status);
};

/** True while a model is QUEUED, PREPROCESSING, TRAINING or EVALUATING. */
export const isTrainingStatus = (status?: string): boolean => {
    switch (status) {
        case ModelStatus.QUEUED:
        case ModelStatus.PREPROCESSING:
        case ModelStatus.TRAINING:
        case ModelStatus.EVALUATING:
            return true;
        default:
            return false;
    }
};

/**
 * True when the status is neither a training status (per isTrainingStatus) nor FAILED/CANCELED
 * (per isFailedOrCanceledStatus). Note that this also yields true for an undefined status.
 */
export const isSuccessfulPastTrainingStatus = (status?: string): boolean =>
    !(isTrainingStatus(status) || isFailedOrCanceledStatus(status));

/** True while the status is TRAINING, EVALUATING, VISUALIZING or EXPLAINING. */
export const isMetricsStreamingStatus = (status?: string): boolean => {
    const streamingStatuses = [
        ModelStatus.TRAINING,
        ModelStatus.EVALUATING,
        ModelStatus.VISUALIZING,
        ModelStatus.EXPLAINING,
    ];
    return streamingStatuses.some((candidate) => candidate === status);
};

/**
 * Returns the model with the largest numeric `id`, or null if there is none.
 *
 * @param models - candidate models; anything that is not an array yields null.
 * @param mustBeReady - when truthy, only models accepted by modelStatusFilter are considered.
 * @returns the model with the highest `+id` (on ties or NaN comparisons the later element wins), or null.
 */
export const getLatestModel = (models?: Model[], mustBeReady?: boolean) => {
    if (!Array.isArray(models)) {
        return null;
    }
    const candidates = mustBeReady ? models.filter(modelStatusFilter) : models;
    if (candidates.length === 0) {
        return null;
    }
    return candidates.reduce((latest, model) => (+latest.id > +model.id ? latest : model));
};

/** Returns the config's `model_type`, falling back to the neural-network (ECD) type when it is missing or empty. */
export const getModelType = (config?: CreateModelConfig) => config?.model_type || ModelTypes.NEURAL_NETWORK;

/**
 * True when the config describes a neural-network (ECD) model.
 *
 * Goes through getModelType so that a config without `model_type` is classified as ECD. That matches
 * getModelType and getModelSchema, which both treat such a config as a neural network. A missing config is
 * never reported as ECD.
 */
export const isECDModel = (config?: CreateModelConfig) =>
    config != null && getModelType(config) === ModelTypes.NEURAL_NETWORK;

/** True when the config's model type is the decision-tree (GBM) type. */
export const isGBMModel = (config?: CreateModelConfig) => getModelType(config) === ModelTypes.DECISION_TREE;

/** True when the config's model type is the large-language-model type. */
export const isLLMModel = (config?: CreateModelConfig) => getModelType(config) === ModelTypes.LARGE_LANGUAGE_MODEL;

/** True when the trainer type is "finetune" or the config carries an `engine_version`. */
export const isFineTunedModel = (config?: CreateModelConfig) =>
    config?.trainer?.type === "finetune" || config?.engine_version !== undefined;

/**
 * Picks the bundled JSON schema matching the config's model type.
 *
 * Decision-tree configs get the GBM schema and large-language-model configs get the LLM schema. Every
 * other model type gets the ECD schema, including a missing type, which getModelType reports as a neural network.
 */
export const getModelSchema = (config?: CreateModelConfig) => {
    const modelType = getModelType(config);
    let schema: unknown = ecdSchema;
    if (modelType === ModelTypes.DECISION_TREE) {
        schema = gbmSchema;
    } else if (modelType === ModelTypes.LARGE_LANGUAGE_MODEL) {
        schema = llmSchema;
    }
    // TODO: Figure out why the imported JSON cannot be typed as JSONSchema7 directly.
    // NOTE(review): presumably because JSON-module typing widens literal values (e.g. `"type": "object"`
    // becomes `string`), which JSONSchema7's literal unions reject — confirm.
    return schema as JSONSchema7;
};

/**
 * Returns an explanation when `model` has no output feature named `outputFeature`, otherwise "".
 */
export const outputFeatureDisabledCriteria = (model: Model, outputFeature: string) => {
    const hasTarget = model.config.output_features.some((feature: any) => feature.name === outputFeature);
    return hasTarget ? "" : `Model does not contain target [${outputFeature}]`;
};

/**
 * Returns why calibration plots cannot be shown for `model`, or "" when they can.
 *
 * They are unavailable when the config has a truthy `hyperopt` section, or when no output feature has type
 * "category".
 */
export const calibrationPlotsDisabledCriteria = (model: Model) => {
    const { hyperopt, output_features: outputFeatures } = model.config;
    if (hyperopt) {
        return "Calibration plots are not available for hyperopt models.";
    }
    const hasCategoryOutput = outputFeatures.some((feature: any) => feature.type === "category");
    if (hasCategoryOutput) {
        return "";
    }
    return (
        <span>
            This model does not have any <b>category</b> output features, so calibration plots cannot be generated.
        </span>
    );
};

/**
 * Returns why a confusion matrix cannot be shown for `model`, or "" when it can.
 * A confusion matrix needs at least one output feature of type "binary" or "category".
 */
export const confusionMatrixDisabledCriteria = (model: Model) => {
    const hasClassificationOutput = model.config.output_features.some(
        (feature: any) => feature.type === "binary" || feature.type === "category",
    );
    if (hasClassificationOutput) {
        return "";
    }
    return (
        <span>
            This model does not have any <b>binary</b> or <b>category</b> output features, so confusion matrix
            cannot be generated.
        </span>
    );
};

/**
 * Returns why F1 visualizations cannot be shown for `model`, or "" when they can.
 * They need at least one output feature of type "category".
 */
export const f1DisabledCriteria = (model: Model) => {
    const hasCategoryOutput = model.config.output_features.some((feature: any) => feature.type === "category");
    if (hasCategoryOutput) {
        return "";
    }
    return (
        <span>
            This model does not have any <b>category</b> output features, so F1 visualizations cannot be generated.
        </span>
    );
};

/**
 * Feature-importance visualizations currently have no disabling criteria, so this always returns "".
 * `model` is accepted (and ignored) so the signature matches the other `*DisabledCriteria` helpers.
 */
export const featureImportanceDisabledCriteria = (model: Model): string => "";

/**
 * Returns why an ROC curve cannot be graphed for `model`, or "" when it can.
 * Graphing requires at least one output feature of type "binary".
 */
export const rocDisabledCriteria = (model: Model) => {
    const hasBinaryOutput = model.config.output_features.some((feature: any) => feature.type === "binary");
    if (hasBinaryOutput) {
        return "";
    }
    return (
        <span>
            <Icon name={"ban"} color={"grey"} />
            This model version cannot be graphed as it does not have any <b>binary</b> output features.
        </span>
    );
};

/** Precision-recall curves use exactly the same availability check (and message) as ROC curves. */
export const prCurveDisabledCriteria = (model: Model) => rocDisabledCriteria(model);

/**
 * Explains why visualizations are unavailable given the model's status, or returns "" if the status
 * does not block them.
 *
 * @param model - the model version; its id and repoVersion are rendered as a link.
 * @param step - step name shown while training is in progress; defaults to ModelStatus.VISUALIZING.
 * @param cannotBeDisplayedMsg - trailing text for the in-progress message; defaults to
 *   "no visualizations can be displayed."
 */
export const modelStatusDisabledCriteria = (model: Model, step?: string, cannotBeDisplayedMsg?: string) => {
    const versionLink = <Link to={"/models/version/" + model.id}>#{model.repoVersion}</Link>;
    const status = model.status;

    if (isTrainingStatus(status)) {
        return (
            <>
                Model version {versionLink} has not yet finished the{" "}
                <span style={{ color: SEMANTIC_BLUE }}>{step || ModelStatus.VISUALIZING}</span> step so{" "}
                {cannotBeDisplayedMsg || "no visualizations can be displayed."}
            </>
        );
    }
    if (status === ModelStatus.FAILED) {
        return <>Model version {versionLink} failed so no visualizations can be displayed.</>;
    }
    // A model that is still STOPPING gets the same message as one that is already CANCELED.
    if (status === ModelStatus.STOPPING || status === ModelStatus.CANCELED) {
        return <>Model version {versionLink} was canceled so no visualizations can be displayed.</>;
    }
    return "";
};

/**
 * Renders one label per contributor, showing how many of `models` are attributed to that user.
 *
 * Models whose `user.username` is missing or empty are attributed to `deletedUser`. Labels are sorted by
 * descending count.
 */
export const ModelContributors = (models: Model[]) => {
    const countsByUser: Record<string, number> = {};
    const modelList = models || [];
    for (const model of modelList) {
        const contributor = model.user?.username || deletedUser;
        countsByUser[contributor] = (countsByUser[contributor] ?? 0) + 1;
    }
    const rankedContributors = Object.entries(countsByUser).sort(([, countA], [, countB]) => countB - countA);

    return (
        <Label.Group style={{ gap: "0.5em" }}>
            {rankedContributors.map(([username, count]) => (
                <Label className={"centered-label"} basic key={username} style={{ whiteSpace: "nowrap" }}>
                    {username}
                    <Label.Detail
                        style={{
                            color: SEMANTIC_GREY,
                            marginLeft: "0.5em",
                            fontWeight: "normal",
                        }}
                    >
                        {count}
                    </Label.Detail>
                </Label>
            ))}
        </Label.Group>
    );
};

/**
 * True when `config.hyperopt` is a plain object. Arrays, null, primitives and a missing config all yield
 * false.
 */
export const checkIfHyperoptEnabled = (config?: CreateModelConfig) =>
    config ? Object.prototype.toString.call(config.hyperopt) === "[object Object]" : false;

/** True when the metric key starts with "train_metrics", "validation_metrics" or "test_metrics". */
export const isTrainValTestMetric = (metricName: string) => {
    const splitPrefixes = ["train_metrics", "validation_metrics", "test_metrics"];
    return splitPrefixes.some((prefix) => metricName.startsWith(prefix));
};

/** Display labels for the general (non-feature) metrics. */
const generalMetricLabels = new Map<string, string>([
    ["best_eval_metric_checkpoint_number", "Best Checkpoint"],
    ["learning_rate", "Learning Rate"],
    ["batch_size", "Batch Size"],
    ["steps", "Total Steps"],
    ["epoch", "Total Epochs"],
]);

/** Prefix removed from metric keys before formatting. */
const bestMetricPrefix = "best.";

/**
 * Converts a raw metric key into a display label, with a leading "best." prefix stripped.
 *
 * General (non-feature) metrics get fixed title-case labels. Train/validation/test metrics
 * (e.g. "test_metrics.Survived.loss") render the split name in grey, followed by the rest of the key.
 * Any other key is returned as-is, minus the stripped prefix.
 *
 * @param metricName - raw metric key.
 * @returns a string label, or a JSX element for train/validation/test metrics.
 */
export const formatMetricName = (metricName: string) => {
    // Strip "best." only when it is a prefix. String.prototype.replace would remove the first occurrence
    // anywhere in the key, so a feature named e.g. "house_best" would turn
    // "test_metrics.house_best.loss" into "test_metrics.house_loss".
    const name = metricName.startsWith(bestMetricPrefix) ? metricName.slice(bestMetricPrefix.length) : metricName;

    const label = generalMetricLabels.get(name);
    if (label !== undefined) {
        return label;
    }
    if (isTrainValTestMetric(name)) {
        const [splitName, remainder] = getSimpleMetricNameParts(name);
        return (
            <span>
                <span style={{ color: SEMANTIC_GREY }}>{splitName}</span>&nbsp;{remainder}
            </span>
        );
    }
    return name;
};

/**
 * Splits a train/validation/test metric key into its split name and the remainder.
 *
 * The first part is everything before the first "_". The second part is everything after the first ".".
 * Example: "test_metrics.Survived.loss" -> ["test", "Survived.loss"].
 *
 * @param metricName - metric key to split.
 * @returns a two-element array `[splitName, remainder]`.
 */
const getSimpleMetricNameParts = (metricName: string): string[] => {
    const underscoreIndex = metricName.indexOf("_");
    const dotIndex = metricName.indexOf(".");
    const splitName = metricName.slice(0, underscoreIndex);
    const remainder = metricName.slice(dotIndex + 1);
    return [splitName, remainder];
};
