import { AxiosInstance } from "axios";
import React, { useEffect, useState } from "react";
import { Checkbox, Divider, Dropdown, Form, Grid, Header, Segment } from "semantic-ui-react";
import { DropdownItemProps } from "semantic-ui-react/dist/commonjs/modules/Dropdown/DropdownItem";
import TargetSelector from "../../components/TargetSelector";
import { useAuth0TokenOptions } from "../../data";
import { createV1APIServer } from "../../utils/api";
import { getErrorMessage } from "../../utils/errors";
import { loadResultsDataset } from "../../utils/results";
import { featureImportanceDisabledCriteria, getConfigOutputFeatureNames, modelStatusDisabledCriteria } from "../util";
import FeatureImportanceViewer from "./FeatureImportanceViewer";
import { UseOutputFeaturesForSelectedModels } from "./util";

/**
 * Fetches the explanations dataset for one model / output-feature pair and publishes the outcome through
 * the supplied React state setters.
 *
 * A request is only issued when `target` is non-empty, names one of the model's output features, and
 * neither `modelStatusDisabledCriteria` nor `featureImportanceDisabledCriteria` flags the model. In every
 * other case the call is a no-op and none of the setters are invoked.
 *
 * @param apiServer - Axios client used for the request. While it is still `null` no request is sent, yet
 *   `setLoading(true)` is still invoked and is not reset by this call; callers are expected to call again
 *   once the client exists.
 * @param x - Model whose explanations are requested.
 * @param setDataset - Receives the parsed dataset on success, or `null` when the request or parsing fails.
 * @param setErrorMessage - Receives `null` on success, or a message describing the failure.
 * @param target - Name of the output feature to fetch explanations for.
 * @param setLoading - Set to `true` before the request and back to `false` once the request settles.
 */
const getExplanationsDatasetHelper = (
    apiServer: AxiosInstance | null,
    x: Model,
    setDataset: React.Dispatch<React.SetStateAction<ResultsDataset | null>>,
    setErrorMessage: React.Dispatch<React.SetStateAction<string | null>>,
    target: string,
    setLoading: React.Dispatch<React.SetStateAction<boolean>>,
) => {
    if (!target) {
        return;
    }

    const outputFeatureNames = x.config.output_features.map((feat: any) => feat.name);
    if (!outputFeatureNames.includes(target)) {
        return;
    }

    // Explanations are only available when BOTH checks pass. This is the same rule the callers in this
    // file apply before invoking the helper.
    if (modelStatusDisabledCriteria(x) || featureImportanceDisabledCriteria(x)) {
        return;
    }

    setLoading(true);
    apiServer
        ?.get("visualize/explanations", { params: { modelID: x.id, outputFeature: target } })
        .then((res) => {
            // Parsing failures are reported separately from transport failures (handled in `.catch` below).
            try {
                setDataset(loadResultsDataset(res.data));
                setErrorMessage(null);
            } catch (err: unknown) {
                setDataset(null);
                // `String` also copes with thrown `null` / `undefined`, where `.toString()` would throw.
                setErrorMessage(String(err));
            }
        })
        .catch((err) => {
            // Drop any previously loaded dataset so stale results are not shown next to the error.
            setDataset(null);
            setErrorMessage(getErrorMessage(err));
        })
        .finally(() => setLoading(false));
};

/** Explanations state tracked for one side of the comparison. */
interface SlotExplanations {
    dataset: ResultsDataset | null;
    loading: boolean;
    errorMessage: string | null;
    setErrorMessage: React.Dispatch<React.SetStateAction<string | null>>;
}

/** Wraps a React state setter so that calls made once `isStale()` returns true are dropped. */
function ignoreWhenStale<T>(
    isStale: () => boolean,
    setter: React.Dispatch<React.SetStateAction<T>>,
): React.Dispatch<React.SetStateAction<T>> {
    return (value) => {
        if (!isStale()) {
            setter(value);
        }
    };
}

/**
 * Keeps the explanations dataset of one comparison slot in sync with the selected model and target.
 *
 * The dataset is (re)requested whenever the model id, the target or the API client changes. Results of a
 * request that was started for an earlier model/target, or that settles after unmount, are ignored.
 */
function useSlotExplanations(
    apiServer: AxiosInstance | null,
    model: Model | null | undefined,
    target: string,
): SlotExplanations {
    const [dataset, setDataset] = useState<ResultsDataset | null>(null);
    const [loading, setLoading] = useState<boolean>(false);
    const [errorMessage, setErrorMessage] = useState<string | null>(null);

    useEffect(() => {
        let stale = false;
        const isStale = () => stale;

        // Any request started by a previous run of this effect is ignored from now on, so its `finally`
        // can no longer clear the flag. Reset it here; if a new request starts below, the helper raises
        // it again right away.
        setLoading(false);

        if (target) {
            if (
                model &&
                getConfigOutputFeatureNames(model.config).includes(target) &&
                !modelStatusDisabledCriteria(model) &&
                !featureImportanceDisabledCriteria(model)
            ) {
                getExplanationsDatasetHelper(
                    apiServer,
                    model,
                    ignoreWhenStale(isStale, setDataset),
                    ignoreWhenStale(isStale, setErrorMessage),
                    target,
                    ignoreWhenStale(isStale, setLoading),
                );
            } else {
                setDataset(null);
            }
        }

        return () => {
            stale = true;
        };
    }, [model?.id, target, apiServer]);

    return { dataset, loading, errorMessage, setErrorMessage };
}

/**
 * Side-by-side comparison of feature-importance explanations for two models.
 *
 * Renders an optional left column (target selector plus a radio list of class labels) followed by one
 * column per comparison slot. Each slot has a model dropdown and, once a model is selected, a
 * `FeatureImportanceViewer` for that model; otherwise `props.noModelSelected` is shown.
 *
 * @param props.modelDropdownOptions - Options offered by both model dropdowns.
 * @param props.modelSearch - Search function handed to both dropdowns.
 * @param props.allModels - Models that a dropdown value (model id) is resolved against.
 * @param props.selectedModels - Currently selected models; index 0 is the first slot, index 1 the second.
 * @param props.setSelectedModels - Setter used to replace the model of a single slot.
 * @param props.noModelSelected - Element rendered for a slot that has no model selected.
 */
function FeatureImportanceCompareViewer(props: {
    modelDropdownOptions: any[];
    modelSearch: (options: DropdownItemProps[], value: string) => DropdownItemProps[];
    allModels: Model[];
    selectedModels: (Model | null)[];
    setSelectedModels: React.Dispatch<React.SetStateAction<(Model | null)[]>>;
    noModelSelected: JSX.Element;
}) {
    const [target, setTarget] = useState<string>("");
    const [className, setClassName] = useState<string>();
    // Maps className -> datasetIdx -> activeIndex
    const [classToIndex, setClassToIndex] = useState<Record<string, Record<number, number>>>({});

    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    const [apiServer, setAPIServer] = useState<AxiosInstance | null>(null);
    // Set when the API client could not be created; no dataset can be requested in that case.
    const [apiServerError, setAPIServerError] = useState<string | null>(null);
    useEffect(() => {
        let unmounted = false;
        const getAPIServer = async () => {
            try {
                const v1APIServer = await createV1APIServer(auth0TokenOptions);
                if (unmounted) {
                    return;
                }
                // NOTE: The value returned by `axios.create` is callable (a wrapped function rather than a
                // plain object), so passing it straight to the state setter would make React treat it as
                // an updater callback and invoke it. Wrapping it in an arrow function stores the instance
                // itself. See [1], [2]:
                // [1]: https://github.com/axios/axios/issues/4365
                // [2]: https://stackoverflow.com/questions/64427195/calling-setstate-will-execute-the-function-value-instead-of-passing-it
                setAPIServer(() => v1APIServer);
            } catch (err: unknown) {
                // Surface the failure instead of leaving an unhandled rejection and a permanent spinner.
                if (!unmounted) {
                    setAPIServerError(String(err));
                }
            }
        };
        void getAPIServer();
        return () => {
            unmounted = true;
        };
    }, []);

    const slots: [SlotExplanations, SlotExplanations] = [
        useSlotExplanations(apiServer, props.selectedModels[0], target),
        useSlotExplanations(apiServer, props.selectedModels[1], target),
    ];
    const [dataset0, dataset1] = [slots[0].dataset, slots[1].dataset];

    // Keep the selected class valid: fall back to the first known class when nothing (or a class that no
    // longer exists) is selected.
    useEffect(() => {
        const keys = Object.keys(classToIndex);
        if (keys.length > 0 && (className === undefined || !keys.includes(className))) {
            setClassName(keys[0]);
        }
    }, [classToIndex]);

    // Rebuild the class lookup from the `targetIdx2str` metadata of each loaded dataset's first column.
    useEffect(() => {
        const localClassToIndex: Record<string, Record<number, number>> = {};
        [dataset0, dataset1].forEach((dataset, datasetIdx) => {
            if (dataset === null) {
                return;
            }
            const col = dataset.columns[0];
            // The metadata entry for the column may be absent, so every step of the lookup is optional.
            const idx2str = dataset.metadata?.[col]?.targetIdx2str;
            idx2str?.forEach((c: string, i: number) => {
                if (localClassToIndex[c] === undefined) {
                    localClassToIndex[c] = {};
                }
                localClassToIndex[c][datasetIdx] = i;
            });
        });
        setClassToIndex(localClassToIndex);
    }, [dataset0, dataset1]);

    const outputFeatures = UseOutputFeaturesForSelectedModels(target, setTarget, props.selectedModels);

    /**
     * Resolves the index of the selected class within the dataset of the given slot.
     *
     * Returns 0 when no class labels are known at all, and -1 when no class is selected or the selected
     * class does not occur in that slot's dataset.
     */
    const getActiveIndex = (firstOrSecond: number): number => {
        if (Object.keys(classToIndex).length === 0) {
            return 0;
        }
        if (className === undefined) {
            return -1;
        }
        // A class may exist in only one of the two datasets; report -1 for the other one.
        return classToIndex[className]?.[firstOrSecond] ?? -1;
    };

    const classLabels = Object.keys(classToIndex);
    const indexSelector =
        classLabels.length > 0 ? (
            <>
                <Header as="h5">Class</Header>
                <Form>
                    {classLabels.map((label) => (
                        <Form.Field key={label}>
                            <Checkbox
                                radio
                                label={label}
                                checked={label === className}
                                onChange={() => setClassName(label)}
                            />
                        </Form.Field>
                    ))}
                </Form>
            </>
        ) : null;

    /** Renders the dropdown and viewer column of one comparison slot. */
    const renderModelColumn = (slotIdx: 0 | 1) => {
        const model = props.selectedModels[slotIdx];
        const slot = slots[slotIdx];
        return (
            <Grid.Column style={{ overflowX: "auto" }}>
                <Dropdown
                    fluid
                    search={props.modelSearch}
                    selection
                    options={props.modelDropdownOptions}
                    onChange={(_event, data) => {
                        const selected = props.allModels.find((x) => x.id === data.value);
                        if (selected) {
                            // Replace only this slot's model and keep the other slot untouched.
                            props.setSelectedModels((prev) =>
                                slotIdx === 0 ? [selected, prev[1]] : [prev[0], selected],
                            );
                        }
                    }}
                    value={model?.id || ""}
                />
                <Divider hidden />
                {model ? (
                    <FeatureImportanceViewer
                        model={model}
                        errorMessage={apiServerError ?? slot.errorMessage}
                        setErrorMessage={slot.setErrorMessage}
                        className={className}
                        loading={apiServerError === null && slot.loading}
                        dataset={slot.dataset}
                        activeIndex={getActiveIndex(slotIdx)}
                        target={target}
                        setTarget={setTarget}
                    />
                ) : (
                    props.noModelSelected
                )}
            </Grid.Column>
        );
    };

    return (
        <Grid columns={"equal"} style={{ width: "100%", height: "100%" }} divided>
            {(outputFeatures.length > 1 || indexSelector) && (
                <Grid.Column width={2}>
                    <TargetSelector
                        outputFeatures={outputFeatures}
                        target={target}
                        setTarget={(nextTarget) => {
                            setTarget(nextTarget);
                            // Class labels belong to a target, so the selection is reset with it.
                            setClassName(undefined);
                        }}
                    />
                    <Segment style={{ backgroundColor: "rgba(0, 0, 0, 0.03)" }}>{indexSelector}</Segment>
                </Grid.Column>
            )}
            {renderModelColumn(0)}
            {renderModelColumn(1)}
        </Grid>
    );
}

export default FeatureImportanceCompareViewer;
