import { UseMutateFunction, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useState } from "react";
import { useLocation, useMatch, useNavigate } from "react-router-dom";
import { useRecoilState } from "recoil";
import { Button, Divider, Form, Header, Icon, List, Loader, Menu, Message, Popup } from "semantic-ui-react";
import {
    adapterVersion,
    baseModel,
    createFinetuningJobRequest,
    finetuningJob,
    repo,
    tier,
} from "../../../api_generated";
import { SubscriptionButton } from "../../../components/GlobalHeader/SubscriptionButtons";
import { useAuth0TokenOptions } from "../../../data";
import ConnectionSelector from "../../../data/ConnectionSelector";
import DatasetSelector from "../../../data/DatasetSelector";
import { defaultDatasetPreviewParams } from "../../../data/data";
import DatasetPreviewer from "../../../data/datasets/DatasetPreviewer";
import { commonGetDatasetPreviewQueryOptions, prefetchDatasetPreviewQuery } from "../../../data/query";
import { useBaseModelsQuery } from "../../../deployments/data/query";
import { USER_STATE } from "../../../state/global";
import { SEMANTIC_DARK_YELLOW, SEMANTIC_GREY } from "../../../utils/colors";
import { getErrorMessage } from "../../../utils/errors";
import { isKratosUserContext } from "../../../utils/kratos";
import Breadcrumbs from "../../Breadcrumbs";
import { createFinetuningJob } from "../../data";
import { getNextAdapterVersion, removeEmptyParameters } from "../../misc/utils";
import { GET_ADAPTER_REPOS_QUERY_KEY, useAdapterRepoQuery, useAdapterVersionQuery } from "../../query";
import ParametersForm from "./ParametersForm";
import { ConfigProvider, useConfigState, useDispatch } from "./store";
import { getModelValidationErrors } from "./utils";

import "./CreateAdapterView.css";

/**
 * Primary "Train" button that submits the fine-tuning job.
 *
 * Shows Semantic UI's loading spinner while `uploading` is true. It is disabled when the
 * configuration is `invalid`, while `validating`, while `uploading`, or when the current
 * user's context is marked as expired.
 */
const CreateAdapterButton = (props: {
    onSubmit: () => void;
    invalid: boolean;
    uploading: boolean;
    validating: boolean;
}) => {
    const { onSubmit, invalid, uploading, validating } = props;
    const [userContext] = useRecoilState(USER_STATE);

    // Any single one of these conditions blocks submission.
    const isDisabled = invalid || validating || uploading || userContext?.isExpired;

    // Custom white SVG glyph rendered in the button's right-hand icon slot.
    const trainIcon = (
        <i className="icon">
            <svg style={{ width: "24px", height: "100%" }} viewBox="0 0 24 21" fill="none" xmlns="http://www.w3.org/2000/svg">
                <path
                    d="M9.65484 8.99984C10.4551 8.83955 10.9765 8.05157 10.8142 7.24203C10.6513 6.43096 9.86819 5.90609 9.06709 6.06658C8.26481 6.22727 7.74616 7.01228 7.90871 7.82377C8.07041 8.63284 8.8524 9.15991 9.65484 8.99984"
                    fill="white"
                />
                <path d="M19.6018 0L20.2972 1.76881L18.5284 2.46419L17.833 0.695378L19.6018 0Z" fill="white" />
                <path d="M20.9924 3.5368L21.6878 5.30602L19.919 6.00125L19.2236 4.23203L20.9924 3.5368Z" fill="white" />
                <path d="M22.3835 7.07465L23.0787 8.84345L21.3095 9.53883L20.6143 7.77002L22.3835 7.07465Z" fill="white" />
                <path d="M16.0647 1.3894L16.7597 3.15806L14.9909 3.85308L14.2959 2.08442L16.0647 1.3894Z" fill="white" />
                <path d="M17.4548 4.92773L18.1501 6.69716L16.3809 7.39254L15.6855 5.62311L17.4548 4.92773Z" fill="white" />
                <path d="M18.8444 8.46606L19.54 10.2355L17.7708 10.9311L17.0752 9.16165L18.8444 8.46606Z" fill="white" />
                <path d="M18.5278 2.46307L19.2234 4.23173L17.4544 4.92752L16.7588 3.15886L18.5278 2.46307Z" fill="white" />
                <path d="M19.9184 6.00092L20.6136 7.76993L18.8446 8.46515L18.1494 6.69614L19.9184 6.00092Z" fill="white" />
                <path d="M14.9829 3.85718L15.6785 5.62599L13.9095 6.32177L13.2139 4.55297L14.9829 3.85718Z" fill="white" />
                <path d="M16.3725 7.396L17.0673 9.16521L15.2983 9.86003L14.6035 8.09081L16.3725 7.396Z" fill="white" />
                <path
                    d="M9.30656 19.8823L10.722 16.4623C10.9062 16.0176 10.7206 15.5056 10.294 15.2814L7.81249 13.9806L9.95714 11.89L13.5864 14.4836L12.2321 15.0099L12.4949 15.6866L24 11.2143L23.7371 10.5375L21.9721 11.2241L21.3094 9.53924L19.5404 10.2348L20.1998 11.9127L18.4199 12.6044L17.7629 10.9336L15.9943 11.6287L16.6487 13.2932L15.1894 13.8605L9.22551 9.42652C8.85933 9.18307 8.4835 9.13145 8.02456 9.33372L2.5867 11.7588C2.20285 11.9297 2.03073 12.3792 2.20164 12.7631C2.32778 13.0463 2.60558 13.214 2.89665 13.214C3.0001 13.214 3.10515 13.1929 3.20619 13.1479L5.8701 11.9598L0.21318 18.7532C0.194499 18.7759 0.179033 18.7968 0.165173 18.8173C0.152317 18.8357 0.142676 18.8554 0.131629 18.8747C0.129419 18.8795 0.126005 18.8844 0.123795 18.8888C-0.0764673 19.2393 -0.0354941 19.6922 0.259379 19.9966C0.441968 20.1854 0.685213 20.2802 0.928685 20.2802C1.16189 20.2802 1.3955 20.193 1.5765 20.018L5.83063 15.8995L8.60458 16.7089L7.58617 19.1696C7.44174 19.5183 7.52671 19.9034 7.76876 20.1617H6.08649V20.6923H16.1842V20.1617L9.12346 20.1621C9.19859 20.0821 9.26246 19.9893 9.30665 19.8821L9.30656 19.8823Z"
                    fill="white"
                />
            </svg>
        </i>
    );

    return (
        <Button
            primary
            content="Train"
            labelPosition="right"
            icon={trainIcon}
            onClick={onSubmit}
            loading={uploading}
            disabled={isDisabled}
        />
    );
};

/**
 * Tabbed area shown under the connection/dataset selectors.
 *
 * Starts on the "Parameters" tab, which renders the fine-tuning ParametersForm. The
 * "Dataset Preview" tab renders a DatasetPreviewer for the currently selected dataset.
 */
const AdapterTabPanel = (props: {
    baseModels?: baseModel[];
    dataset?: Dataset;
    modelDescription: string;
    nextVersion?: number;
    setModelDescription: React.Dispatch<React.SetStateAction<string>>;
}) => {
    const { baseModels, dataset, modelDescription, nextVersion, setModelDescription } = props;
    const [selectedTab, setSelectedTab] = useState("parameters");

    // Semantic UI invokes onClick with the clicked item's props; its `name` identifies the tab.
    const onTabClick = (_event: React.SyntheticEvent<HTMLElement, Event>, itemProps: any) => {
        setSelectedTab(itemProps.name);
    };
    const isSelected = (tabName: string) => selectedTab === tabName;

    return (
        <>
            <Menu pointing secondary>
                <Menu.Item
                    name="parameters"
                    content="Parameters"
                    active={isSelected("parameters")}
                    onClick={onTabClick}
                />
                <Menu.Item
                    name="dataset"
                    content="Dataset Preview"
                    active={isSelected("dataset")}
                    onClick={onTabClick}
                />
            </Menu>
            {isSelected("parameters") ? (
                <ParametersForm
                    baseModels={baseModels}
                    modelDescription={modelDescription}
                    nextVersion={nextVersion}
                    setModelDescription={setModelDescription}
                />
            ) : null}
            {isSelected("dataset") ? <DatasetPreviewer dataset={dataset} /> : null}
        </>
    );
};

/**
 * Popup body explaining why the "Train" button is disabled: one warning row per error.
 *
 * Items are rendered as direct children of `List`, which is Semantic UI's expected
 * `List > List.Item` nesting. Wrapping them in a single `List.Content` stops the list's
 * per-item styling from applying.
 */
const ValidationErrorList = (props: { validationErrors: any[] }) => (
    <List>
        {props.validationErrors.map((validationError: string, index: number) => (
            // Include the index so keys stay unique even if the same message appears twice.
            <List.Item key={`${index}:${validationError}`}>
                <Icon name="warning sign" />
                {validationError}
            </List.Item>
        ))}
    </List>
);

/**
 * ConnectionSelector callback: once connections are loaded, re-select the connection
 * whose id equals `previousConnectionId` (the call site passes the copied adapter
 * version's `connectionId`).
 *
 * Does nothing when `connections` is not an array, is empty, or has no matching id.
 * In those cases the current selection is left untouched.
 */
const onConnectionsLoaded = (
    connections: Connection[] | undefined,
    setConnection: React.Dispatch<React.SetStateAction<Connection | undefined>>,
    previousConnectionId?: number,
) => {
    if (!Array.isArray(connections)) {
        return;
    }

    // Select only the first match, then stop.
    for (const candidate of connections) {
        if (candidate.id === previousConnectionId) {
            setConnection(candidate);
            return;
        }
    }
};

/**
 * Body of the "create adapter" page. It reads and writes the ConfigProvider store through
 * `useConfigState` and `useDispatch`.
 *
 * The user picks a connection and dataset, edits fine-tuning parameters (held in the config
 * store), and submits a fine-tuning job through `mutationFn`. When `adapterVersion` is given:
 * - its fine-tuning params (with empty parameters removed) seed the config store;
 * - its connection and dataset ids pre-select the connection and dataset.
 *
 * @param props.adapterRepo - Repo the new version is created in; its name is sent as `repo`.
 * @param props.adapterVersion - Optional existing version whose settings are copied.
 * @param props.baseModels - Base models passed through to the parameters form.
 * @param props.mutationFn - Submits the `createFinetuningJobRequest`.
 * @param props.mutationIsPending - True while a submission is in flight (button spinner).
 * @param props.loading - True while parent queries load; replaces the selectors with a loader.
 * @param props.errorMessage - Error shown in a negative message above the form, if any.
 */
const CreateAdapter = (props: {
    adapterRepo?: repo;
    adapterVersion?: adapterVersion;
    baseModels?: baseModel[];
    mutationFn: UseMutateFunction<finetuningJob, Error, createFinetuningJobRequest, unknown>;
    mutationIsPending: boolean;
    loading: boolean;
    errorMessage: string | null;
}) => {
    // Auth0 state (used only for prefetching dataset preview):
    const auth0TokenOptions = useAuth0TokenOptions();

    // Parent state:
    const { adapterRepo, adapterVersion, baseModels, mutationFn, mutationIsPending, loading, errorMessage } = props;

    // Reducer state:
    const { config, invalidFields } = useConfigState();
    const dispatch = useDispatch();

    // Recoil state:
    const [userContext] = useRecoilState(USER_STATE);
    // Derived user state: Kratos contexts keep the tier under the tenant's subscription.
    let userTier: tier | undefined;
    if (userContext) {
        const isKratosContext = isKratosUserContext(userContext);
        userTier = isKratosContext ? userContext.tenant.subscription.tier : userContext.tenant.tier;
    }

    // Local state:
    const [connection, setConnection] = useState<Connection>();
    const [dataset, setDataset] = useState<Dataset>();
    const [modelDescription, setModelDescription] = useState<string>("");

    // Query state:
    const queryClient = useQueryClient();

    // Derived state:
    const nextAdapterVersion = getNextAdapterVersion(adapterRepo);

    // Event listeners:
    // A dataset belongs to a connection, so switching connections clears the dataset choice.
    const selectConnection = (connection: Connection) => {
        setConnection(connection);
        setDataset(undefined);
    };

    const selectDataset = (dataset: Dataset) => {
        setDataset(dataset);
    };

    // Seed the config store from the copied version's fine-tuning params, or start empty.
    // NOTE(review): `dispatch` is intentionally omitted from the deps. Confirm that useDispatch
    // returns a stable function before adding it, or INIT could reset the user's edits.
    useEffect(() => {
        let initialConfig = adapterVersion?.finetuningJob?.params ?? {};
        initialConfig = removeEmptyParameters(initialConfig);
        dispatch({ type: "INIT", config: initialConfig, featureFlags: {} });
    }, [adapterVersion]);

    // Prefetch the dataset preview so it is ready when the user opens the "Dataset Preview" tab.
    // This runs in an effect, after commit, rather than during render. Render must stay free of
    // side effects; previously this cache write ran on every render (and twice in StrictMode).
    useEffect(() => {
        if (dataset?.id && dataset?.status === "connected") {
            prefetchDatasetPreviewQuery(
                queryClient,
                dataset.id,
                defaultDatasetPreviewParams,
                commonGetDatasetPreviewQueryOptions(dataset),
                auth0TokenOptions,
            );
        }
    }, [dataset, queryClient, auth0TokenOptions]);

    const modelValidationErrors = getModelValidationErrors(invalidFields, connection, dataset);

    return (
        <div style={{ padding: "20px", height: "100vh", maxHeight: "100vh", overflowY: "scroll" }}>
            <div className={"builder-header"}>
                <Breadcrumbs
                    adapterRepo={props.adapterRepo}
                    adapterVersionText={`New Adapter${
                        props.adapterVersion?.tag ? ` (from v${props.adapterVersion?.tag})` : ""
                    }`}
                />
                <SubscriptionButton isExpired={userContext?.isExpired} currentTier={userTier} />
            </div>
            {userContext?.isExpired && (
                <Message warning color="yellow" className="disabled-warning">
                    <p style={{ color: SEMANTIC_DARK_YELLOW }}>
                        <Icon name={"warning sign"} />
                        <b>You may no longer train models.</b>{" "}
                        <a href="https://predibase.com/contact-us" style={{ fontWeight: "bold" }}>
                            Contact us
                        </a>{" "}
                        to upgrade.
                    </p>
                </Message>
            )}
            <Divider hidden />
            <div style={{ display: "flex", justifyContent: "space-between" }}>
                <Header className="header" as="h2">
                    {adapterRepo?.name}
                    &nbsp;
                    <span style={{ color: SEMANTIC_GREY, fontWeight: "normal" }}>Version {nextAdapterVersion}</span>
                </Header>
                {modelValidationErrors.length > 0 ? (
                    // Disabled buttons swallow hover events, so the popup trigger is a wrapping div.
                    <Popup
                        className="transition-scale"
                        content={<ValidationErrorList validationErrors={modelValidationErrors} />}
                        position={"right center"}
                        trigger={
                            <div style={{ display: "inline-block" }}>
                                <CreateAdapterButton
                                    onSubmit={() => {}}
                                    invalid={true}
                                    uploading={false}
                                    validating={false}
                                />
                            </div>
                        }
                    />
                ) : (
                    <CreateAdapterButton
                        onSubmit={() => {
                            mutationFn({
                                params: config as Record<string, any>,
                                dataset: `${connection?.name ?? ""}/${dataset?.name ?? ""}`,
                                repo: adapterRepo?.name ?? "",
                                description: modelDescription,
                            });
                        }}
                        invalid={modelValidationErrors.length > 0}
                        uploading={mutationIsPending}
                        validating={loading}
                    />
                )}
            </div>
            {errorMessage && (
                <Message negative>
                    <Message.Header>Error in model configuration</Message.Header>
                    <p>{errorMessage}</p>
                </Message>
            )}
            <Form>
                <Form.Group widths="equal">
                    {loading ? (
                        <Loader active />
                    ) : (
                        <>
                            <ConnectionSelector
                                connection={connection}
                                selectConnection={selectConnection}
                                onConnectionsLoaded={(connections: Connection[] | undefined) => {
                                    onConnectionsLoaded(connections, setConnection, adapterVersion?.connectionId);
                                }}
                            />
                            <DatasetSelector
                                connection={connection}
                                dataset={dataset}
                                initialDatasetId={adapterVersion?.datasetId}
                                selectDataset={selectDataset}
                            />
                        </>
                    )}
                </Form.Group>
            </Form>
            <Divider hidden />
            <AdapterTabPanel
                baseModels={baseModels}
                dataset={dataset}
                modelDescription={modelDescription}
                nextVersion={nextAdapterVersion}
                setModelDescription={setModelDescription}
            />
        </div>
    );
};

/**
 * Route component for `/adapters/create/:repoUUID`.
 *
 * It reads the target repo UUID from the route and an optional `adapterVersionID` query
 * parameter, which names an existing version whose settings are copied. It then:
 * - loads the repo, the base models and that version;
 * - wires up the create-fine-tuning-job mutation.
 *
 * On success, it invalidates the adapter-repos query and navigates to the new version's page.
 * Loading flags and the first available error message are passed down to CreateAdapter,
 * which is rendered inside a fresh ConfigProvider.
 */
const CreateAdapterView = () => {
    // Meta state:
    const { search } = useLocation();
    const params = useMemo(() => new URLSearchParams(search), [search]);
    const match = useMatch("/adapters/create/:repoUUID");
    const adapterRepoUUID = match?.params?.repoUUID ?? "";
    // Explicit radix 10: without it, parseInt honors a "0x" prefix (e.g. "0x1f" would become 31).
    // A missing parameter yields NaN, which keeps the adapter-version query disabled below.
    const adapterVersionID = parseInt(params.get("adapterVersionID") ?? "", 10);

    const navigate = useNavigate();

    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    // Query state:
    const queryClient = useQueryClient();
    const {
        data: baseModels,
        isLoading: isLoadingBaseModels,
        error: baseModelsError,
    } = useBaseModelsQuery({
        refetchOnWindowFocus: false,
    });
    const {
        data: adapterRepo,
        isLoading: isLoadingAdapterRepo,
        error: adapterError,
    } = useAdapterRepoQuery(adapterRepoUUID);
    const {
        data: adapterVersion,
        isLoading: isLoadingAdapterVersion,
        error: adapterVersionError,
    } = useAdapterVersionQuery(adapterRepoUUID, adapterVersionID, {
        // Only fetch when copying from an explicit, numeric version of a known repo.
        enabled: Boolean(adapterRepoUUID) && !Number.isNaN(adapterVersionID),
        refetchOnWindowFocus: false,
    });

    const {
        mutate: mutationFn,
        isPending: mutationIsPending,
        error: mutationError,
    } = useMutation({
        mutationFn: (config: createFinetuningJobRequest) => createFinetuningJob(config, auth0TokenOptions),
        onSuccess: (data) => {
            // Refresh the repo listing so the new version shows up, then jump to it.
            queryClient.invalidateQueries({ queryKey: GET_ADAPTER_REPOS_QUERY_KEY() });
            navigate(`/adapters/repo/${adapterRepoUUID}/version/${data.targetVersionTag}`);
        },
    });

    // Derived state:
    const isLoading = isLoadingAdapterRepo || isLoadingBaseModels || isLoadingAdapterVersion;
    // Show only the first error, in this priority order: repo, base models, version, submission.
    const errorMessage =
        getErrorMessage(adapterError) ||
        getErrorMessage(baseModelsError) ||
        getErrorMessage(adapterVersionError) ||
        getErrorMessage(mutationError);

    // TODO: do we show some sort of disabled view when there's an error? Look at comps.

    return (
        <ConfigProvider>
            <CreateAdapter
                adapterRepo={adapterRepo}
                adapterVersion={adapterVersion}
                baseModels={baseModels}
                mutationFn={mutationFn}
                mutationIsPending={mutationIsPending}
                loading={isLoading}
                errorMessage={errorMessage}
            />
        </ConfigProvider>
    );
};

// Default export: the page component for the `/adapters/create/:repoUUID` route.
export default CreateAdapterView;
