import { SyntheticEvent } from "react";
import { DropdownProps, Form } from "semantic-ui-react";

import metrics from "../../../metrics/metrics";
import { useConfigState, useDispatch } from "../store";

// Metrics offered in the hyperopt goal/metric dropdown, keyed by output feature type.
// Mirrors the `metric_functions` class variable(s) of
// ludwig.features.<feature_type>.<feature_type>OutputFeature; "combined" only offers "loss".
type MetricOptions = Record<string, string[]>;

const metricOptions: MetricOptions = {
    combined: ["loss"],
    binary: ["loss", "accuracy", "roc_auc"],
    category: ["loss", "accuracy", "hits_at_k"],
    number: [
        "loss",
        "mean_squared_error",
        "mean_absolute_error",
        "root_mean_squared_error",
        "root_mean_squared_percentage_error",
        "r2",
    ],
    sequence: [
        "loss",
        "token_accuracy",
        "sequence_accuracy",
        "last_accuracy",
        "perplexity",
        "edit_distance",
    ],
    set: ["loss", "jaccard"],
    text: ["loss", "token_accuracy", "last_accuracy", "perplexity", "edit_distance"],
    vector: ["loss", "mean_squared_error", "mean_absolute_error", "r2"],
};

// Optimization direction for each metric name, used to label goal/metric dropdown options.
// Relations taken from: https://github.com/ludwig-ai/ludwig/blob/master/ludwig/modules/metric_modules.py
// Every metric listed in metricOptions needs an entry here; a missing entry leaves the
// metric without a goal (the lookup yields undefined).
type GoalOptions = {
    [key: string]: "minimize" | "maximize";
};

const goalOptions: GoalOptions = {
    // Lower is better.
    loss: "minimize",
    mean_squared_error: "minimize",
    mean_absolute_error: "minimize",
    root_mean_squared_error: "minimize",
    root_mean_squared_percentage_error: "minimize",
    perplexity: "minimize",
    edit_distance: "minimize",
    // Higher is better.
    accuracy: "maximize",
    roc_auc: "maximize",
    hits_at_k: "maximize",
    r2: "maximize",
    token_accuracy: "maximize",
    sequence_accuracy: "maximize",
    last_accuracy: "maximize",
    jaccard: "maximize",
};

/** Option shape passed to the semantic-ui `Form.Select` dropdowns in GoalAndMetricForm. */
type DropdownOption = {
    text: string;
    value: string;
};

/**
 * Encodes the configured hyperopt goal and metric as "<goal> <metric>", the value format used by
 * the goal-and-metric dropdown options.
 *
 * Returns "" when either field is unset: interpolating missing fields would produce a value such
 * as "undefined undefined" that matches no option and hides the dropdown placeholder, whereas
 * semantic-ui treats "" as "no selection".
 */
function getCurrentValue(hyperoptConfig?: CreateModelHyperoptProps): string {
    const goal = hyperoptConfig?.goal;
    const metric = hyperoptConfig?.metric;
    if (!goal || !metric) {
        return "";
    }
    return `${goal} ${metric}`;
}

/**
 * Options for the output-feature dropdown: always "combined", plus the output feature's name when
 * the config has exactly one output feature. Returns no options when there is no config.
 */
function getOutputFeatureOptions(config?: CreateModelConfig): DropdownOption[] {
    if (!config) return [];

    const options: DropdownOption[] = [{ text: "combined", value: "combined" }];

    const outputFeatures = config.output_features;
    if (outputFeatures?.length === 1) {
        options.push({ text: outputFeatures[0].name, value: outputFeatures[0].name });
    }

    return options;
}

/**
 * Options for the goal-and-metric dropdown, labelled "<metric> (<goal>)" with value "<goal> <metric>".
 *
 * Offers the metrics for the selected output feature's type. Falls back to the "combined" metrics
 * when no output feature is selected, when the selection names no configured output feature
 * (e.g. "combined"), or when the feature's type has no entry in metricOptions.
 */
function getGoalAndMetricOptions(config?: CreateModelConfig, hyperoptConfig?: CreateModelHyperoptProps): DropdownOption[] {
    const selectedName = hyperoptConfig?.output_feature;
    const selectedFeature =
        selectedName === undefined
            ? undefined
            : config?.output_features?.find((feature) => feature.name === selectedName);

    const featureType = selectedFeature?.type ?? "combined";
    const featureMetrics = metricOptions[featureType] ?? metricOptions["combined"] ?? [];

    return (
        featureMetrics
            // Skip metrics with no known optimization direction instead of offering an option whose
            // goal would be the string "undefined".
            .filter((metric) => goalOptions[metric] !== undefined)
            .map((metric) => {
                const goal = goalOptions[metric];
                return { text: `${metric} (${goal})`, value: `${goal} ${metric}` };
            })
    );
}

/**
 * Hyperopt form section for choosing the output feature to optimize and the goal/metric pair to
 * optimize it by. Reads the config from the store and writes selections back via
 * UPDATE_CONFIG_PROPERTY actions on hyperopt.output_feature, hyperopt.goal and hyperopt.metric.
 */
function GoalAndMetricForm() {
    const dispatch = useDispatch();
    const { config, invalidFields } = useConfigState();
    const hyperoptConfig = config?.hyperopt;

    return (
        <Form style={{ marginTop: "2em" }}>
            <Form.Select
                className={metrics.BLOCK_AUTO_CAPTURE}
                name={"outputFeature"}
                label={"Output Feature"}
                options={getOutputFeatureOptions(config)}
                value={hyperoptConfig?.output_feature}
                style={{ width: "197px" }}
                placeholder="Select output feature."
                error={invalidFields["hyperopt/output_feature"]}
                onChange={(_event: SyntheticEvent, data: DropdownProps) => {
                    // NOTE(review): hyperopt.goal/metric are not reset here, so a previously chosen
                    // metric may not be offered for the newly selected feature's type.
                    const selection = data.value as string;
                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: "hyperopt.output_feature", value: selection });
                }}
            />

            <Form.Select
                className={metrics.BLOCK_AUTO_CAPTURE}
                name={"goalAndMetric"}
                label={"Goal and Metric"}
                options={getGoalAndMetricOptions(config, hyperoptConfig)}
                value={getCurrentValue(hyperoptConfig)}
                style={{ width: "197px" }}
                placeholder="Select goal and metric."
                error={invalidFields["hyperopt/goal"] || invalidFields["hyperopt/metric"]}
                onChange={(_event: SyntheticEvent, data: DropdownProps) => {
                    // Option values are encoded as "<goal> <metric>" by getGoalAndMetricOptions.
                    const [goal, metric] = (data.value as string).split(" ");
                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: "hyperopt.goal", value: goal });
                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: "hyperopt.metric", value: metric });
                }}
            />
        </Form>
    );
}

export default GoalAndMetricForm;
