import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useNavigate } from "react-router-dom";
import { Button, Icon, Modal } from "semantic-ui-react";

import { adapterVersion, createFinetuningJobRequest, repo } from "@/autogen/openapi";

import { useAuth0TokenOptions } from "../../../data";
import { createFinetuningJob } from "../../data";
import { GET_ADAPTER_REPOS_QUERY_KEY, useAdapterRepoQuery, useAdapterVersionQuery } from "../../query";

/**
 * Confirmation modal that starts a new finetuning job reusing the params and dataset of an
 * existing adapter version, then navigates to the newly created version on success.
 *
 * The repo and version may be passed either as objects (`adapterRepo`, `adapterVersion`) or as
 * identifiers (`adapterRepoUuid`, `adapterVersionNumber`). Anything passed only by identifier is
 * fetched while the modal is open.
 */
const RetrainAdapterModal = (props: {
    // Repo object; if omitted, the repo is fetched using `adapterRepoUuid`.
    adapterRepo?: repo;
    // Version object; if set, the version is not fetched and `adapterVersionNumber` is only used for
    // the header (falling back to `adapterVersion.tag`). The repo must still come from `adapterRepo`
    // or `adapterRepoUuid`.
    adapterVersion?: adapterVersion;
    // Repo UUID; used when `adapterRepo` is not provided.
    adapterRepoUuid?: string;
    // Version number to fetch when `adapterVersion` is not provided (must be > 0 to trigger a fetch).
    adapterVersionNumber?: number;
    open: boolean;
    setOpen: React.Dispatch<React.SetStateAction<boolean>>;
}) => {
    // Parent state:
    const { adapterRepo, adapterRepoUuid, adapterVersion, adapterVersionNumber, open, setOpen } = props;

    // Route state:
    const navigate = useNavigate();

    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    // Query state:
    const queryClient = useQueryClient();

    // Prefer the UUID on the passed-in repo object, falling back to the explicit UUID prop.
    const repoUuid = adapterRepo?.uuid ? adapterRepo.uuid : (adapterRepoUuid ?? "");

    // Only fetch the version when it was not passed in and enough is known to identify it.
    const { data: adapterVersionLoaded } = useAdapterVersionQuery(repoUuid, adapterVersionNumber ?? 0, {
        enabled:
            open &&
            adapterVersion === undefined &&
            repoUuid !== "" &&
            typeof adapterVersionNumber === "number" &&
            adapterVersionNumber > 0,
        refetchOnWindowFocus: false,
        retry: false,
    });

    // Only fetch the repo when it was not passed in as an object.
    const { data: adapterRepoLoaded } = useAdapterRepoQuery(adapterRepoUuid ?? "", {
        enabled: open && adapterRepo === undefined && adapterRepoUuid !== undefined,
        refetchOnWindowFocus: false,
        retry: false,
    });

    const adapterVersionLocal = adapterVersion ? adapterVersion : adapterVersionLoaded;
    const adapterRepoLocal = adapterRepo ? adapterRepo : adapterRepoLoaded;
    const adapterVersionNumberLocal = adapterVersionNumber ?? adapterVersionLocal?.tag ?? 0;

    // The request below copies the source version's params/dataset and needs the repo name;
    // without both objects it would submit empty values, so block submission until they are known.
    const canRetrain = adapterVersionLocal !== undefined && adapterRepoLocal !== undefined;

    const { mutate: retrainAdapter, reset: resetMutation } = useMutation({
        mutationFn: (config: createFinetuningJobRequest) => createFinetuningJob(config, auth0TokenOptions),
        onSuccess: (data) => {
            queryClient.invalidateQueries({ queryKey: GET_ADAPTER_REPOS_QUERY_KEY() });
            resetMutation();
            // Use the resolved UUID: `adapterRepo` is undefined when only `adapterRepoUuid` was passed.
            navigate(`/adapters/repo/${repoUuid}/version/${data.targetVersionTag}`);
            setOpen(false);
        },
    });

    // Shared by the modal's own close handler and the Cancel button.
    const closeModal = () => {
        setOpen(false);
        resetMutation();
    };

    return (
        <Modal name="retrain-adapter-modal" open={open} size="mini" onClose={closeModal}>
            <Modal.Header>Retrain Adapter #{adapterVersionNumberLocal}</Modal.Header>
            <Modal.Content>Would you like to retrain this model?</Modal.Content>
            <Modal.Actions>
                <Button onClick={closeModal}>Cancel</Button>
                <Button
                    icon
                    color={"green"}
                    labelPosition={"right"}
                    size={"small"}
                    disabled={!canRetrain}
                    onClick={() => {
                        retrainAdapter({
                            params: adapterVersionLocal?.finetuningJob?.params ?? {},
                            dataset: adapterVersionLocal?.datasetName ?? "",
                            repo: adapterRepoLocal?.name ?? "",
                            description: "",
                        });
                    }}
                >
                    Retrain
                    <Icon name="checkmark" />
                </Button>
            </Modal.Actions>
        </Modal>
    );
};

export default RetrainAdapterModal;
