import _ from "lodash";
import { getClientEndpoint } from "../../../../utils/api";
import { snakeToTitle } from "../../../../utils/config";
import { ModelOptimizationMapping, ModelOptimizationStrategy } from "./ModelRepoPerformanceGraphSpec";

/**
 * @param metricName metric name to be converted, e.g. "best.test_metrics.Survived.accuracy"
 *
 * @returns "<feature> <Metric>" label, just "<Metric>" when the name has no feature segment,
 * or undefined when the name is empty or its last dot-separated segment is empty
 *
 * @description
 * Returns a human-readable label for a metric name. The last dot-separated segment is treated as the
 * metric and the segment before it (if any) as the output feature.
 *
 * For example, "best.test_metrics.Survived.accuracy" becomes "Survived Accuracy".
 */
export const featureAndMetricNameLabel = (metricName: string) => {
    if (!metricName) {
        return;
    }

    const metricParts = metricName.split(".");
    const metric = _.nth(metricParts, -1);
    const feature = _.nth(metricParts, -2);

    if (!metric) {
        return;
    }

    // One-off for roc_auc, since the generic snake_case -> title conversion doesn't work for abbreviations
    const metricLabel = metric === "roc_auc" ? "ROC AUC" : snakeToTitle(metric);

    // A name without any "." has no feature segment; don't render it as the literal "undefined"
    return feature ? `${feature} ${metricLabel}` : metricLabel;
};

/**
 * @param modelVersions array of model versions in ready state present in repo
 *
 * @returns dropdown options ({ key, text, value }), one per distinct metric name
 *
 * @description
 * Gathers every metric name that is flagged as "best" and was recorded on the "test" split in any of the
 * given model versions. Duplicates are collapsed, and options keep the order in which names were first seen.
 * Each option uses the raw metric name as its key and value, and a human-readable label as its text.
 *
 * For example, if the model versions collectively contain a number, binary, and category output feature,
 * the result contains options for the metrics of each of those output features.
 */
export const collectMetrics = (modelVersions: Model[]) => {
    // A Set de-duplicates names reported by several versions while preserving first-seen order
    const uniqueMetricNames: Set<any> = new Set();

    modelVersions.forEach((model: Model) => {
        (model.modelMetrics ?? [])
            .filter((candidate: ModelMetric) => candidate.isBest && candidate.split === "test")
            .forEach((candidate: ModelMetric) => {
                uniqueMetricNames.add(candidate.runMetricName);
            });
    });

    const dropdownOptions: any[] = Array.from(uniqueMetricNames, (runMetricName: string) => ({
        key: runMetricName,
        text: featureAndMetricNameLabel(runMetricName),
        value: runMetricName,
    }));

    return dropdownOptions;
};

/**
 * @param metricName metric name selected in UI dropdown, e.g. "best.test_metrics.Survived.accuracy"
 * @param modelMetrics array of model metrics from a given model version
 *
 * @returns the value of the first metric whose runMetricName equals `metricName`, or undefined when
 * either argument is missing/empty or no metric matches
 *
 * @description
 * Looks up the value of the named metric within one model version. No "best" filtering happens here.
 * The caller (ModelRepoPerformanceGraph) only offers "best" metric names in its dropdown, so the
 * name itself selects the best-step metric.
 */
export const getMetricValueAtBestValidationStep = (metricName: string, modelMetrics?: ModelMetric[]) => {
    // Nothing to look up without both a name and a metrics list
    if (!metricName || !modelMetrics) {
        return;
    }

    const matchingMetric = modelMetrics.find((candidate: ModelMetric) => candidate.runMetricName === metricName);
    return matchingMetric?.metricValue;
};

/**
 * @param modelVersions array of model versions in ready state present in repo (not modified)
 * @param metricName metric name selected in UI dropdown
 *
 * @returns array of model performance data points to plot, ordered by creation time ascending
 *
 * @description
 * Sorts a copy of the given model versions by creation time, then walks them in order and collects the
 * value of the given metric from each. Versions that do not report the metric are skipped. Every data point
 * carries both the running max and the running min seen so far. Either one can then be used for the
 * "best model" line, depending on the metric's optimization direction.
 */
export const collectModelPerformanceData = (modelVersions: Model[], metricName: string) => {
    // Sort a copy by creation time ascending; Array.prototype.sort works in place and would otherwise
    // reorder the caller's array as a side effect.
    const sortedModelVersions = modelVersions.slice().sort((a, b) => {
        if (a.created < b.created) {
            return -1;
        }
        if (a.created > b.created) {
            return 1;
        }
        return 0;
    });

    const modelPerformanceData: ModelPerformanceDataPoint[] = [];

    // Start the running extremes at -/+ infinity so the first reported value always replaces them
    let runningMax = Number.NEGATIVE_INFINITY;
    let runningMin = Number.POSITIVE_INFINITY;

    sortedModelVersions.forEach((model: Model) => {
        const metricValue: number | undefined = getMetricValueAtBestValidationStep(metricName, model.modelMetrics);

        // Versions that did not report this metric are left off the plot
        if (metricValue === undefined) {
            return;
        }

        // Compare directly rather than guarding on truthiness: a legitimate value of 0 must still update
        // the running extremes. NaN compares false both ways, so it never becomes a running extreme.
        if (metricValue < runningMin) {
            runningMin = metricValue;
        }
        if (metricValue > runningMax) {
            runningMax = metricValue;
        }

        // One plotted point per model version that reported the metric
        modelPerformanceData.push({
            modelVersionNumber: model.repoVersion,
            modelCreatedAt: model.created,
            metricValue: metricValue,
            description: model.description,
            runningMax: runningMax,
            runningMin: runningMin,
            url: getClientEndpoint() + "/models/version/" + model.id,
        });
    });

    return modelPerformanceData;
};

/**
 * @param metricName Metric name currently selected in UI dropdown (e.g. "best.test_metrics.Survived.accuracy"),
 * or undefined when nothing is selected
 *
 * @returns Optimization strategy for the given metric
 *
 * @description
 * Looks up the last dot-separated segment of the metric name (e.g. "accuracy") in ModelOptimizationMapping.
 * The result decides whether the best-model line should track the running maximum or minimum. Falls back to
 * ModelOptimizationStrategy.MAXIMIZE when no name is given or the metric has no mapping entry.
 */
export const getOptimizationBehavior = (metricName: string | undefined) => {
    const fallbackStrategy = ModelOptimizationStrategy.MAXIMIZE;

    // Nothing selected yet (undefined or empty string)
    if (!metricName) {
        return fallbackStrategy;
    }

    const nameSegments = metricName.split(".");
    const bareMetricName = nameSegments[nameSegments.length - 1];

    return _.get(ModelOptimizationMapping, bareMetricName, fallbackStrategy);
};
