/**
 * Direction in which a model metric is optimized.
 *
 * The string values match the values used in `ModelOptimizationMapping`.
 */
export enum ModelOptimizationStrategy {
    /** The metric should be driven as low as possible (e.g. a loss). */
    MINIMIZE = "minimize",
    /** The metric should be driven as high as possible (e.g. an accuracy). */
    MAXIMIZE = "maximize",
}

/**
 * Maps a metric name to the direction in which it is optimized
 * (`"minimize"` or `"maximize"`).
 *
 * Looking up a metric name that is not listed here yields `undefined`.
 */
export const ModelOptimizationMapping = {
    root_mean_squared_error: "minimize",
    precision: "maximize",
    recall: "maximize",
    roc_auc: "maximize",
    specificity: "maximize",
    root_mean_squared_percentage_error: "minimize",
    r2: "maximize",
    loss: "minimize",
    binary_weighted_cross_entropy: "minimize",
    softmax_cross_entropy: "minimize",
    sequence_softmax_cross_entropy: "minimize",
    sigmoid_cross_entropy: "minimize",
    token_accuracy: "maximize",
    sequence_accuracy: "maximize",
    perplexity: "minimize",
    // NOTE(review): "percplexity" is a misspelling of "perplexity". It is kept
    // so existing lookups keep working; remove once no caller depends on it.
    percplexity: "minimize",
    char_error_rate: "minimize",
    hits_at_k: "maximize",
    mean_absolute_error: "minimize",
    mean_squared_error: "minimize",
    mean_absolute_percentage_error: "minimize",
    jaccard: "maximize",
    huber: "minimize",
};

/**
 * Builds the Vega-Lite (v5) spec for the repo-wide model performance graph.
 *
 * The spec reads from the named dataset `"table"` and has three layers that
 * share the same x encoding:
 *   1. circles at each row's `metricValue`, with a tooltip and an `href`
 *      taken from the `url` field;
 *   2. a line through the `runningMax` or `runningMin` field;
 *   3. text labels showing `modelVersionNumber` above each point.
 *
 * The chart width is 65% of `window.innerWidth`, read at call time, so this
 * function must be called where a global `window` exists.
 *
 * @param useTemporalScale - When true, x is `modelCreatedAt` on a temporal
 *   scale; otherwise x is `modelVersionNumber`.
 * @param metricName - Title for the y axis.
 * @param maximizeMetric - When truthy the line layer plots `runningMax`,
 *   otherwise `runningMin`.
 * @param dataLength - Number of data rows. Above 30, a non-temporal x axis
 *   switches from ordinal to quantitative and the circle layer's x axis is
 *   capped at 30 ticks.
 * @param yDomain - Optional explicit `[min, max]` domain for the y scale.
 * @returns A plain object holding the Vega-Lite spec.
 */
export const ModelRepoPerformanceGraphSpec = (
    useTemporalScale: boolean,
    metricName?: string,
    maximizeMetric?: boolean,
    dataLength?: number,
    yDomain?: [number, number],
) => {
    // Missing, zero, or NaN lengths all count as "not many points".
    const hasManyPoints = (dataLength ?? 0) > 30;

    const xFieldName = useTemporalScale ? "modelCreatedAt" : "modelVersionNumber";
    const xFieldTitle = useTemporalScale ? "Model Creation Time" : "Model Version Number";

    let xFieldType: string;
    if (useTemporalScale) {
        xFieldType = "temporal";
    } else {
        // An ordinal axis gets too crowded with many versions; use a continuous one.
        xFieldType = hasManyPoints ? "quantitative" : "ordinal";
    }

    const yFieldName = maximizeMetric ? "runningMax" : "runningMin";

    const timeUnit = useTemporalScale ? "monthdatehoursminutes" : undefined;
    const labelAngle = useTemporalScale ? -15 : 0;

    // Returns a fresh x encoding per layer so no objects are shared between layers.
    const buildXEncoding = (axisExtras: { tickCount?: number } = {}) => ({
        field: xFieldName,
        timeUnit: timeUnit,
        type: xFieldType,
        axis: {
            title: xFieldTitle,
            labelAngle: labelAngle,
            ...axisExtras,
        },
    });

    return {
        $schema: "https://vega.github.io/schema/vega-lite/v5.json",
        description: "Graph of model performance over entire repo",
        width: window.innerWidth * 0.65,
        data: { name: "table" },
        layer: [
            {
                mark: { type: "circle", size: 200 },
                encoding: {
                    // Only this layer's axis carries the tick cap.
                    x: buildXEncoding(hasManyPoints ? { tickCount: 30 } : {}),
                    y: {
                        field: "metricValue",
                        type: "quantitative",
                        // `zero` is a scale property in Vega-Lite; under `axis` it is
                        // ignored and the y scale may still be forced to include zero.
                        scale: { domain: yDomain, zero: false },
                        axis: {
                            title: metricName,
                            offset: 10,
                        },
                    },
                    tooltip: [
                        { field: "modelVersionNumber", title: "Model Version Number" },
                        {
                            field: "modelCreatedAt",
                            title: "Model Creation Time",
                            type: "temporal",
                            format: "%b %d, %Y %H:%M",
                        },
                        { field: "metricValue", title: "Metric Value", format: ".4f" },
                        { field: "description", title: "Description" },
                    ],
                    href: { field: "url", type: "nominal" },
                },
            },
            {
                name: yFieldName,
                mark: "line",
                encoding: {
                    x: buildXEncoding(),
                    y: {
                        field: yFieldName,
                        type: "quantitative",
                    },
                },
            },
            {
                mark: { type: "text", align: "center", dy: -10 },
                encoding: {
                    x: buildXEncoding(),
                    y: {
                        field: "metricValue",
                        type: "quantitative",
                        axis: {
                            title: metricName,
                        },
                    },
                    text: { field: "modelVersionNumber", type: "nominal" },
                },
            },
        ],
    };
};
