import { UseMutateFunction, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useState } from "react";
import { useMatch, useNavigate, useSearchParams } from "react-router-dom";
import { useRecoilState } from "recoil";
import {
    Button,
    Divider,
    DropdownProps,
    Form,
    Header,
    Icon,
    List,
    Loader,
    Menu,
    Message,
    Popup,
} from "semantic-ui-react";
import {
    adapterVersion,
    baseModel,
    createFinetuningJobRequest,
    finetuningJob,
    repo,
    tier,
} from "../../../api_generated";
import { SubscriptionButton } from "../../../components/GlobalHeader/SubscriptionButtons";
import { useAuth0TokenOptions } from "../../../data";
import DatasetSelector from "../../../data/DatasetSelector";
import { defaultDatasetPreviewParams } from "../../../data/data";
import DatasetPreviewer from "../../../data/datasets/DatasetPreviewer";
import {
    commonGetDatasetPreviewQueryOptions,
    prefetchDatasetPreviewQuery,
    useConnectionsQuery,
} from "../../../data/query";
import { useBaseModelsQuery } from "../../../deployments/data/query";
import { USER_STATE } from "../../../state/global";
import { SEMANTIC_DARK_YELLOW, SEMANTIC_GREY } from "../../../utils/colors";
import { getErrorMessage } from "../../../utils/errors";
import { isKratosUserContext } from "../../../utils/kratos";
import { getOverflowItem } from "../../../utils/overflow";
import { rawtextSearch } from "../../../utils/search";
import AdapterBreadcrumb from "../../AdapterBreadcrumb";
import { createFinetuningJob } from "../../data";
import { getNextAdapterVersion } from "../../misc/utils";
import { GET_ADAPTER_REPOS_QUERY_KEY, useAdapterRepoQuery, useAdapterVersionQuery } from "../../query";
import ParametersForm from "./ParametersForm";
import { AdapterConfig, AdapterConfigSchema } from "./schema";
import { ConfigProvider, useConfigState } from "./store";
import { adapterTypes, getConfigValidationErrors } from "./utils";

import ContinueTrainingForm from "./ContinueTrainingForm";
import "./CreateAdapterView.css";

/** SVG path data for the white "train" glyph drawn inside the Train button, in paint order. */
const TRAIN_ICON_PATHS: string[] = [
    "M9.65484 8.99984C10.4551 8.83955 10.9765 8.05157 10.8142 7.24203C10.6513 6.43096 9.86819 5.90609 9.06709 6.06658C8.26481 6.22727 7.74616 7.01228 7.90871 7.82377C8.07041 8.63284 8.8524 9.15991 9.65484 8.99984",
    "M19.6018 0L20.2972 1.76881L18.5284 2.46419L17.833 0.695378L19.6018 0Z",
    "M20.9924 3.5368L21.6878 5.30602L19.919 6.00125L19.2236 4.23203L20.9924 3.5368Z",
    "M22.3835 7.07465L23.0787 8.84345L21.3095 9.53883L20.6143 7.77002L22.3835 7.07465Z",
    "M16.0647 1.3894L16.7597 3.15806L14.9909 3.85308L14.2959 2.08442L16.0647 1.3894Z",
    "M17.4548 4.92773L18.1501 6.69716L16.3809 7.39254L15.6855 5.62311L17.4548 4.92773Z",
    "M18.8444 8.46606L19.54 10.2355L17.7708 10.9311L17.0752 9.16165L18.8444 8.46606Z",
    "M18.5278 2.46307L19.2234 4.23173L17.4544 4.92752L16.7588 3.15886L18.5278 2.46307Z",
    "M19.9184 6.00092L20.6136 7.76993L18.8446 8.46515L18.1494 6.69614L19.9184 6.00092Z",
    "M14.9829 3.85718L15.6785 5.62599L13.9095 6.32177L13.2139 4.55297L14.9829 3.85718Z",
    "M16.3725 7.396L17.0673 9.16521L15.2983 9.86003L14.6035 8.09081L16.3725 7.396Z",
    "M9.30656 19.8823L10.722 16.4623C10.9062 16.0176 10.7206 15.5056 10.294 15.2814L7.81249 13.9806L9.95714 11.89L13.5864 14.4836L12.2321 15.0099L12.4949 15.6866L24 11.2143L23.7371 10.5375L21.9721 11.2241L21.3094 9.53924L19.5404 10.2348L20.1998 11.9127L18.4199 12.6044L17.7629 10.9336L15.9943 11.6287L16.6487 13.2932L15.1894 13.8605L9.22551 9.42652C8.85933 9.18307 8.4835 9.13145 8.02456 9.33372L2.5867 11.7588C2.20285 11.9297 2.03073 12.3792 2.20164 12.7631C2.32778 13.0463 2.60558 13.214 2.89665 13.214C3.0001 13.214 3.10515 13.1929 3.20619 13.1479L5.8701 11.9598L0.21318 18.7532C0.194499 18.7759 0.179033 18.7968 0.165173 18.8173C0.152317 18.8357 0.142676 18.8554 0.131629 18.8747C0.129419 18.8795 0.126005 18.8844 0.123795 18.8888C-0.0764673 19.2393 -0.0354941 19.6922 0.259379 19.9966C0.441968 20.1854 0.685213 20.2802 0.928685 20.2802C1.16189 20.2802 1.3955 20.193 1.5765 20.018L5.83063 15.8995L8.60458 16.7089L7.58617 19.1696C7.44174 19.5183 7.52671 19.9034 7.76876 20.1617H6.08649V20.6923H16.1842V20.1617L9.12346 20.1621C9.19859 20.0821 9.26246 19.9893 9.30665 19.8821L9.30656 19.8823Z",
];

/**
 * Primary "Train" button with a custom SVG icon on its right.
 *
 * Shows a loading spinner while `uploading`. It is disabled when the form is invalid, while
 * validating or uploading, or when the current user's context reports `isExpired`.
 */
const CreateAdapterButton = (props: {
    onSubmit: () => void;
    invalid: boolean;
    uploading: boolean;
    validating: boolean;
}) => {
    const [userContext] = useRecoilState(USER_STATE);
    const { onSubmit, invalid, uploading, validating } = props;

    // Expired users may not start training, regardless of form state.
    const isDisabled = invalid || validating || uploading || userContext?.isExpired;

    const trainIcon = (
        <i className="icon">
            <svg
                style={{ width: "24px", height: "100%" }}
                viewBox="0 0 24 21"
                fill="none"
                xmlns="http://www.w3.org/2000/svg"
            >
                {TRAIN_ICON_PATHS.map((pathData) => (
                    <path key={pathData} d={pathData} fill="white" />
                ))}
            </svg>
        </i>
    );

    return (
        <Button
            primary
            content="Train"
            labelPosition="right"
            icon={trainIcon}
            onClick={onSubmit}
            loading={uploading}
            disabled={isDisabled}
        />
    );
};

/**
 * Tabbed panel beneath the connection/dataset selectors.
 *
 * The "Parameters" tab shows either `ContinueTrainingForm` (continue mode) or `ParametersForm`.
 * The "Dataset Preview" tab shows `DatasetPreviewer` for the selected dataset. In continue mode,
 * for non-imported parent adapters, a loader replaces both tabs until base models and a dataset
 * are available.
 */
const AdapterTabPanel = (props: {
    baseModels?: baseModel[];
    selectedDataset?: Dataset;
    modelDescription: string;
    nextVersion?: number;
    setModelDescription: React.Dispatch<React.SetStateAction<string>>;
    continueMode?: boolean;
    parentAdapterVersion?: adapterVersion;
}) => {
    const {
        baseModels,
        selectedDataset,
        modelDescription,
        nextVersion,
        setModelDescription,
        continueMode,
        parentAdapterVersion,
    } = props;

    // Currently visible tab; matches the `name` of one of the Menu.Items below.
    const [activeTabPanel, setActiveTabPanel] = useState("parameters");

    // Each Menu.Item's `name` doubles as its tab identifier.
    const handleMenuItemClick = (_event: React.SyntheticEvent<HTMLElement, Event>, data: any) =>
        setActiveTabPanel(data.name);

    const isImportedAdapter = parentAdapterVersion?.importedAdapterProperties !== undefined;

    // Parameters used to prefill the form, taken from the parent version when there is one.
    const paramsFromJob = useMemo(() => {
        if (isImportedAdapter && continueMode) {
            // TODO: Hacky but works for now. Imported adapters are special-cased with a minimal
            // synthesized param set (presumably they have no finetuning job params to reuse — confirm).
            return { base_model: parentAdapterVersion?.baseModel, epochs: AdapterConfigSchema.properties.epochs.default };
        }
        return parentAdapterVersion?.finetuningJob?.params as Partial<AdapterConfig> | undefined;
    }, [parentAdapterVersion, continueMode]);

    const parentAndChildDatasetsMatch = parentAdapterVersion?.datasetId === selectedDataset?.id;

    // TODO: Hacky for now because the adapterType for an imported adapter is an empty string
    const parentAdapterType = (parentAdapterVersion?.adapterType as adapterTypes | "" | undefined) || "lora";

    const showLoaderForAllTabs =
        !isImportedAdapter && Boolean(continueMode) && (baseModels === undefined || selectedDataset === undefined);

    const parametersPanel = continueMode ? (
        <ContinueTrainingForm
            baseModels={baseModels}
            modelDescription={modelDescription}
            setModelDescription={setModelDescription}
            nextVersion={nextVersion}
            parentAdapterType={parentAdapterType}
            paramsFromJob={paramsFromJob}
            parentAndChildDatasetsMatch={parentAndChildDatasetsMatch}
        />
    ) : (
        <ParametersForm
            baseModels={baseModels}
            modelDescription={modelDescription}
            setModelDescription={setModelDescription}
            nextVersion={nextVersion}
            paramsFromJob={paramsFromJob}
        />
    );

    return (
        <>
            <Menu pointing secondary>
                <Menu.Item
                    name="parameters"
                    content="Parameters"
                    active={activeTabPanel === "parameters"}
                    onClick={handleMenuItemClick}
                />
                <Menu.Item
                    name="dataset"
                    content="Dataset Preview"
                    active={activeTabPanel === "dataset"}
                    onClick={handleMenuItemClick}
                />
            </Menu>
            {showLoaderForAllTabs ? (
                <div style={{ textAlign: "center", marginTop: "4em", marginBottom: "4em" }}>
                    <Loader active inline />
                </div>
            ) : (
                <>
                    {activeTabPanel === "parameters" && parametersPanel}
                    {activeTabPanel === "dataset" && <DatasetPreviewer dataset={selectedDataset} />}
                </>
            )}
        </>
    );
};

/** Lists each configuration validation message, prefixed with a warning icon (used as popup content). */
const ValidationErrorList = ({ validationErrors }: { validationErrors: any[] }) => {
    const entries = validationErrors.map((message: string) => (
        <List.Item key={message}>
            <Icon name="warning sign" />
            {message}
        </List.Item>
    ));

    return (
        <List>
            <List.Content>{entries}</List.Content>
        </List>
    );
};

// TODO: Move to a shared location after we delete the old CMV code.
/**
 * Searchable dropdown listing the connections whose status is "connected".
 *
 * Options are sorted by "type_name" using a numeric, case-insensitive collation. While nothing is
 * selected and connections are available, it auto-selects the connection whose id equals
 * `previousConnectionId`, falling back to the connection named "file_uploads". If neither exists,
 * nothing is auto-selected.
 */
const ImprovedConnectionSelector = (props: {
    selectedConnection?: Connection;
    onSelectConnection: (connection: Connection) => void;
    disabled?: boolean;
    previousConnectionId?: Connection["id"];
}) => {
    // Parent state:
    const { selectedConnection, onSelectConnection, disabled, previousConnectionId } = props;

    // Query state:
    const {
        data: connectionsResponse,
        isLoading: connectionsAreLoading,
        error: connectionsError,
    } = useConnectionsQuery(undefined, {
        refetchInterval: false,
    });

    // Memoized so the array keeps its identity between renders. Otherwise the auto-select effect
    // below (which depends on it) would re-run on every render, not just when the list changes.
    const availableConnections = useMemo(() => {
        if (connectionsError !== null || connectionsResponse === undefined) {
            return [];
        }
        // One collator for the whole sort instead of one per comparison.
        const collator = new Intl.Collator("en", { numeric: true, sensitivity: "base" });
        return (connectionsResponse.connections || [])
            .filter((x: Connection) => x.status === "connected")
            .sort((a: Connection, b: Connection) => {
                const first = a.type + "_" + a.name;
                const second = b.type + "_" + b.name;
                return collator.compare(first, second);
            });
    }, [connectionsResponse, connectionsError]);

    // Once connections are available and nothing is selected, auto-select the previous connection
    // (if provided and still connected), otherwise the "file_uploads" connection.
    // onSelectConnection is intentionally not a dependency: the parent recreates it every render.
    useEffect(() => {
        if (availableConnections.length === 0 || selectedConnection !== undefined) {
            return;
        }
        const candidate =
            availableConnections.find((connection) => connection.id === previousConnectionId) ??
            availableConnections.find((connection) => connection.name === "file_uploads");
        // Never hand `undefined` to onSelectConnection; its contract expects a real Connection.
        if (candidate !== undefined) {
            onSelectConnection(candidate);
        }
    }, [availableConnections, previousConnectionId, selectedConnection]);

    // Dropdown options:
    const dropdownOptions = availableConnections.map((connection) => {
        const fullName = connection.type + ": " + connection.name;
        return {
            rawtext: fullName,
            text: fullName,
            value: connection.id,
        };
    });

    return (
        <div style={{ width: "100%", paddingLeft: "0.5rem" }}>
            <Form.Select
                error={connectionsError !== null}
                name="connection"
                label="Connection"
                options={dropdownOptions}
                value={selectedConnection?.id}
                // @ts-expect-error Element | undefined is officially not allowed, but it works
                text={
                    // TODO: Copied from DatasetSelector because it renders better than original connection selector code, generalize later:
                    <div style={{ width: "100%" }}>
                        {getOverflowItem(
                            dropdownOptions.find((co) => co.value === selectedConnection?.id)?.rawtext,
                            true,
                            85,
                            "inline-block",
                        )}
                    </div>
                }
                fluid
                placeholder="Connection"
                onChange={(_, data: DropdownProps) => {
                    const newConnection = availableConnections.find((c) => c.id === (data.value as number));
                    if (newConnection) {
                        onSelectConnection(newConnection);
                    }
                }}
                disabled={disabled}
                search={rawtextSearch}
                selection
                loading={connectionsAreLoading}
            />
        </div>
    );
};

/**
 * Main create/continue-training form for an adapter version.
 *
 * The page contains:
 * - a header with a breadcrumb and subscription button, plus an expiry warning for expired users;
 * - the Train button, wrapped in a popup listing validation errors when the config is invalid;
 * - the connection and dataset selectors;
 * - the parameters/dataset-preview tab panel.
 *
 * Submitting calls `mutationFn` with the current config, the "connection/dataset" path, the repo
 * name, the description, and in continue mode the parent version as "repo/tag".
 */
const CreateAdapter = (props: {
    adapterRepo?: repo;
    parentAdapterVersion?: adapterVersion;
    baseModels?: baseModel[];
    mutationFn: UseMutateFunction<finetuningJob, Error, createFinetuningJobRequest, unknown>;
    mutationIsPending: boolean;
    loading: boolean;
    errorMessage: string | null;
    continueMode: boolean;
}) => {
    // Auth0 state (used only for prefetching dataset preview):
    const auth0TokenOptions = useAuth0TokenOptions();

    // Parent state:
    const {
        adapterRepo,
        parentAdapterVersion,
        baseModels,
        mutationFn,
        mutationIsPending,
        loading,
        errorMessage,
        continueMode,
    } = props;

    // Reducer state:
    const { config, invalidFields } = useConfigState();

    // Recoil state:
    const [userContext] = useRecoilState(USER_STATE);
    // Derived user state: the subscription tier lives in a different place for Kratos user contexts.
    let userTier: tier | undefined;
    if (userContext) {
        const isKratosContext = isKratosUserContext(userContext);
        userTier = isKratosContext ? userContext.tenant.subscription.tier : userContext?.tenant.tier;
    }

    // Local state:
    const [selectedConnection, setSelectedConnection] = useState<Connection>();
    const [selectedDataset, setSelectedDataset] = useState<Dataset>();
    const [modelDescription, setModelDescription] = useState<string>("");

    // Query state:
    const queryClient = useQueryClient();

    // Derived state:
    const nextAdapterVersion = getNextAdapterVersion(adapterRepo);
    const parentAndChildDatasetsMatch = selectedDataset?.id === parentAdapterVersion?.datasetId;

    // Event listeners:
    const onSelectConnection = (connection: Connection) => {
        setSelectedConnection(connection);
        // A dataset belongs to a connection, so changing the connection invalidates the dataset choice.
        setSelectedDataset(undefined);
    };

    const onSelectDataset = (dataset: Dataset) => {
        setSelectedDataset(dataset);
    };

    // Prefetch the dataset preview so that it's available when the user clicks the "Preview" button.
    // Done in an effect rather than in the render body so the fetch is not a render-time side effect.
    useEffect(() => {
        if (selectedDataset?.id && selectedDataset?.status === "connected") {
            prefetchDatasetPreviewQuery(
                queryClient,
                selectedDataset.id,
                defaultDatasetPreviewParams,
                commonGetDatasetPreviewQueryOptions(selectedDataset),
                auth0TokenOptions,
            );
        }
    }, [selectedDataset, queryClient, auth0TokenOptions]);

    const configValidationErrors = getConfigValidationErrors(
        invalidFields,
        selectedConnection,
        selectedDataset,
        continueMode,
        parentAndChildDatasetsMatch,
    );

    return (
        <div style={{ padding: "20px", height: "100vh", maxHeight: "100vh", overflowY: "scroll" }}>
            <div className={"builder-header"}>
                <AdapterBreadcrumb
                    adapterRepo={adapterRepo}
                    adapterVersionTag={parentAdapterVersion?.tag}
                    continueMode={continueMode}
                    newVersion={true}
                />
                <SubscriptionButton isExpired={userContext?.isExpired} currentTier={userTier} />
            </div>
            {userContext?.isExpired && (
                <Message warning color="yellow" className="disabled-warning">
                    <p style={{ color: SEMANTIC_DARK_YELLOW }}>
                        <Icon name={"warning sign"} />
                        <b>You may no longer train models.</b>{" "}
                        <a href="https://predibase.com/contact-us" style={{ fontWeight: "bold" }}>
                            Contact us
                        </a>{" "}
                        to upgrade.
                    </p>
                </Message>
            )}
            <Divider hidden />
            <div style={{ display: "flex", justifyContent: "space-between" }}>
                <Header className="header" as="h2">
                    {adapterRepo?.name}
                    &nbsp;
                    <span style={{ color: SEMANTIC_GREY, fontWeight: "normal" }}>Version {nextAdapterVersion}</span>
                </Header>
                {configValidationErrors.length > 0 ? (
                    <Popup
                        className="transition-scale"
                        content={<ValidationErrorList validationErrors={configValidationErrors} />}
                        position={"right center"}
                        trigger={
                            // Wrapper div: the popup needs a hover target, and the disabled button is the only child.
                            <div style={{ display: "inline-block" }}>
                                <CreateAdapterButton
                                    onSubmit={() => {}}
                                    invalid={true}
                                    uploading={false}
                                    validating={false}
                                />
                            </div>
                        }
                    />
                ) : (
                    <CreateAdapterButton
                        onSubmit={() => {
                            mutationFn({
                                params: config!,
                                dataset: `${selectedConnection?.name ?? ""}/${selectedDataset?.name ?? ""}`,
                                repo: adapterRepo?.name ?? "",
                                description: modelDescription,
                                continueFromVersion: continueMode
                                    ? `${adapterRepo?.name}/${parentAdapterVersion?.tag}`
                                    : undefined,
                            });
                        }}
                        invalid={configValidationErrors.length > 0}
                        uploading={mutationIsPending}
                        validating={loading}
                    />
                )}
            </div>
            {continueMode && (
                <h3 style={{ marginTop: "0.5rem", marginBottom: "2rem", fontWeight: "normal" }}>
                    Continue training from{" "}
                    <b>
                        {adapterRepo?.name}/{parentAdapterVersion?.tag}
                    </b>
                </h3>
            )}
            {errorMessage && (
                <Message negative>
                    <Message.Header>Error in model configuration</Message.Header>
                    <p>{errorMessage}</p>
                </Message>
            )}
            <Form>
                <Form.Group widths="equal">
                    {loading ? (
                        <Loader active />
                    ) : (
                        <>
                            <ImprovedConnectionSelector
                                selectedConnection={selectedConnection}
                                onSelectConnection={onSelectConnection}
                                previousConnectionId={parentAdapterVersion?.connectionId}
                            />
                            <DatasetSelector
                                connection={selectedConnection}
                                dataset={selectedDataset}
                                initialDatasetId={parentAdapterVersion?.datasetId}
                                selectDataset={onSelectDataset}
                            />
                        </>
                    )}
                </Form.Group>
            </Form>
            <Divider hidden />
            <AdapterTabPanel
                baseModels={baseModels}
                selectedDataset={selectedDataset}
                modelDescription={modelDescription}
                nextVersion={nextAdapterVersion}
                setModelDescription={setModelDescription}
                continueMode={continueMode}
                parentAdapterVersion={parentAdapterVersion}
            />
        </div>
    );
};

/**
 * Route view for creating a new adapter version, or continuing training from an existing one.
 *
 * Reads these values from the URL:
 * - the repo UUID and the optional "continue" path segment (`/adapters/create/:repoUUID/:continue?`);
 * - the `adapterVersionID` search param, which is the parent version tag.
 *
 * It loads the base models and the adapter repo, plus the parent version when a valid tag is
 * present, then renders `CreateAdapter` inside a `ConfigProvider`. A successful submission
 * invalidates the adapter-repos query and navigates to the new version's page.
 */
const CreateAdapterView = () => {
    // URL state:
    const navigate = useNavigate();
    const match = useMatch("/adapters/create/:repoUUID/:continue?");
    const adapterRepoUUID = match?.params?.repoUUID ?? "";
    const continueMode = Boolean(match?.params?.continue);
    const [searchParams] = useSearchParams();
    // Explicit radix: version tags are decimal, so input like "0x10" must not be parsed as hex.
    const parsedAdapterVersionTag = parseInt(searchParams.get("adapterVersionID") ?? "", 10);

    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    // Query state:
    const queryClient = useQueryClient();
    const {
        data: baseModels,
        isLoading: isLoadingBaseModels,
        error: baseModelsError,
    } = useBaseModelsQuery({
        refetchOnWindowFocus: false,
    });
    const {
        data: adapterRepo,
        isLoading: isLoadingAdapterRepo,
        error: adapterRepoError,
    } = useAdapterRepoQuery(adapterRepoUUID);
    const {
        data: parentAdapterVersion,
        isLoading: isLoadingParentAdapterVersion,
        error: parentAdapterVersionError,
    } = useAdapterVersionQuery(adapterRepoUUID, parsedAdapterVersionTag, {
        // Only fetch a parent version when both the repo and a numeric version tag are known.
        enabled: Boolean(adapterRepoUUID) && !Number.isNaN(parsedAdapterVersionTag),
        refetchOnWindowFocus: false,
    });

    const {
        mutate: mutationFn,
        isPending: mutationIsPending,
        error: mutationError,
    } = useMutation({
        mutationFn: (config: createFinetuningJobRequest) => createFinetuningJob(config, auth0TokenOptions),
        onSuccess: (data) => {
            queryClient.invalidateQueries({ queryKey: GET_ADAPTER_REPOS_QUERY_KEY() });
            navigate(`/adapters/repo/${adapterRepoUUID}/version/${data.targetVersionTag}`);
        },
    });

    // Derived state:
    const isLoading = isLoadingAdapterRepo || isLoadingBaseModels || isLoadingParentAdapterVersion;
    // Show the first error found, in this priority order.
    const errorMessage =
        getErrorMessage(adapterRepoError) ||
        getErrorMessage(baseModelsError) ||
        getErrorMessage(parentAdapterVersionError) ||
        getErrorMessage(mutationError);

    // TODO: do we show some sort of disabled view when there's an error? Look at comps.

    return (
        <ConfigProvider>
            <CreateAdapter
                adapterRepo={adapterRepo}
                parentAdapterVersion={parentAdapterVersion}
                baseModels={baseModels}
                mutationFn={mutationFn}
                mutationIsPending={mutationIsPending}
                loading={isLoading}
                errorMessage={errorMessage}
                continueMode={continueMode}
            />
        </ConfigProvider>
    );
};

export default CreateAdapterView;
