import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { AxiosInstance, AxiosResponse } from "axios";
import React, { CSSProperties, useContext, useEffect, useRef, useState } from "react";
import { Link, useNavigate } from "react-router-dom";
import { useRecoilState } from "recoil";
import {
    Breadcrumb,
    Button,
    Divider,
    DropdownItemProps,
    DropdownProps,
    Form,
    Grid,
    Header,
    Icon,
    Loader,
    Menu,
    MenuItem,
    Message,
    Modal,
    Popup,
    Radio,
    Segment,
    TabProps,
} from "semantic-ui-react";
import { tier } from "../../../api_generated";
import Chip from "../../../components/Chip";
import Dropdown from "../../../components/Dropdown";
import EditDiv from "../../../components/EditDiv";
import { SubscriptionButton } from "../../../components/GlobalHeader/SubscriptionButtons";
import TextArea from "../../../components/TextArea";
import FieldWrapper from "../../../components/forms/json-schema-form/fields/FieldWrapper";
import { Auth0TokenOptions, useAuth0TokenOptions } from "../../../data";
import ConnectionSelector from "../../../data/ConnectionSelector";
import DatasetSelector from "../../../data/DatasetSelector";
import { defaultDatasetPreviewParams } from "../../../data/data";
import DatasetPreviewer from "../../../data/datasets/DatasetPreviewer";
import { commonGetDatasetPreviewQueryOptions, prefetchDatasetPreviewQuery } from "../../../data/query";
import { useEnginesQuery } from "../../../engines/query";
import { track } from "../../../metrics/june";
import metrics from "../../../metrics/metrics";
import { GET_USER_CREDITS_QUERY_KEY } from "../../../query";
import { USER_STATE } from "../../../state/global";
import { LLMTemplates } from "../../../types/model/llmTemplates";
import { ModelTypes } from "../../../types/model/modelTypes";
import { createV1APIServer, getDocsHome, redirectIfSessionInvalid } from "../../../utils/api";
import { SEMANTIC_BLUE, SEMANTIC_DARK_YELLOW, SEMANTIC_GREY, SEMANTIC_GREY_DISABLED } from "../../../utils/colors";
import { renderValidationErrorList } from "../../../utils/config";
import { getErrorMessage } from "../../../utils/errors";
import { FeatureFlagsContext } from "../../../utils/feature-flags";
import { isKratosUserContext } from "../../../utils/kratos";
import { detectDatasetConfig } from "../../data";
import { engineCanTrainModels, getLatestModel, getModelType, isLLMModel } from "../../util";
import AdvancedConfigManager from "../AdvancedConfigManager";
import ComputeConfigManager from "../ComputeConfigManager";
import FieldsTable from "../FieldsTable/FieldsTable";
import ModelTabPanel from "../ModelTabPanel";
import SelectDefaults from "../SelectDefaults";
import qlora_4bit from "../llms/qlora_4bit.json";
import {
    useConfigState,
    useDispatch,
    useIsDecisionTree,
    useIsHyperoptEnabled,
    useIsLLM,
    useModelConversionLoss,
} from "../store";
import { ECDConfirmationModal, GBMConfirmationModal, LLMConfirmationModal } from "./ModelConfirmationModal";
import { ReactComponent as LLMIcon } from "./llm-icon.svg";
import { ReactComponent as SupervisedMLIcon } from "./supervised-ml-icon.svg";
import { DisablingMessages } from "./util";

import "./CreateModelView.css";

/**
 * Breadcrumb section that links back to the model repository page (`/models/repo/<id>`),
 * labelled with the parent model's name. Clicking it records a "Model.View.Navigate" metric.
 */
const Viewcrumb = (props: { modelRepo: ModelRepo; parentModelName: string; parentModelTag: number }) => {
    const { modelRepo, parentModelName } = props;
    const repoPath = "/models/repo/" + modelRepo?.id;
    const trackNavigation = () => metrics.capture("Model.View.Navigate", { method: "breadcrumb" });

    return (
        <Breadcrumb.Section>
            <Link to={repoPath} className={metrics.BLOCK_AUTO_CAPTURE} onClick={trackNavigation}>
                <Icon name={"folder outline"} />
                {parentModelName}
            </Link>
        </Breadcrumb.Section>
    );
};

/**
 * Breadcrumb trail for the create-model page: "Models" › [repository] › <active step>.
 *
 * The repository crumb is only shown when `modelRepo.id` is truthy. The active crumb reads:
 * - "Fork" when `modelRepoID` is set and `editType` is "fork";
 * - "New Model Version" when `modelRepoID` is set otherwise;
 * - "New Model Version (from vN)" when `modelRepoID` is null/undefined and `parentModelTag` is truthy,
 *   or plain "New Model Version" if it is not.
 */
const Breadcrumbs = (props: {
    editType?: string;
    modelRepoID?: number;
    modelRepo?: ModelRepo;
    parentModelName: string;
    parentModelTag: number;
}) => {
    const { editType, modelRepoID, modelRepo, parentModelName, parentModelTag } = props;

    let activeLabel: string;
    if (modelRepoID != null) {
        activeLabel = editType === "fork" ? "Fork" : "New Model Version";
    } else {
        const parentSuffix = parentModelTag ? ` (from v${parentModelTag})` : "";
        activeLabel = "New Model Version" + parentSuffix;
    }

    const trackModelsNavigation = () => metrics.capture("Model.Navigate", { method: "breadcrumb" });

    return (
        <Breadcrumb>
            <Breadcrumb.Section>
                <Link to="/models" className={metrics.BLOCK_AUTO_CAPTURE} onClick={trackModelsNavigation}>
                    Models
                </Link>
            </Breadcrumb.Section>
            <Breadcrumb.Divider />
            {modelRepo?.id && (
                <>
                    <Viewcrumb
                        modelRepo={modelRepo}
                        parentModelName={parentModelName}
                        parentModelTag={parentModelTag}
                    />
                    <Breadcrumb.Divider />
                </>
            )}
            <Breadcrumb.Section active>{activeLabel}</Breadcrumb.Section>
        </Breadcrumb>
    );
};

// SVG path data (in paint order) for the "train" glyph shown on the right-hand side of the train button.
// The glyph is drawn in a 24x21 viewBox, every path filled white.
const TRAIN_ICON_PATHS = [
    "M9.65484 8.99984C10.4551 8.83955 10.9765 8.05157 10.8142 7.24203C10.6513 6.43096 9.86819 5.90609 9.06709 6.06658C8.26481 6.22727 7.74616 7.01228 7.90871 7.82377C8.07041 8.63284 8.8524 9.15991 9.65484 8.99984",
    "M19.6018 0L20.2972 1.76881L18.5284 2.46419L17.833 0.695378L19.6018 0Z",
    "M20.9924 3.5368L21.6878 5.30602L19.919 6.00125L19.2236 4.23203L20.9924 3.5368Z",
    "M22.3835 7.07465L23.0787 8.84345L21.3095 9.53883L20.6143 7.77002L22.3835 7.07465Z",
    "M16.0647 1.3894L16.7597 3.15806L14.9909 3.85308L14.2959 2.08442L16.0647 1.3894Z",
    "M17.4548 4.92773L18.1501 6.69716L16.3809 7.39254L15.6855 5.62311L17.4548 4.92773Z",
    "M18.8444 8.46606L19.54 10.2355L17.7708 10.9311L17.0752 9.16165L18.8444 8.46606Z",
    "M18.5278 2.46307L19.2234 4.23173L17.4544 4.92752L16.7588 3.15886L18.5278 2.46307Z",
    "M19.9184 6.00092L20.6136 7.76993L18.8446 8.46515L18.1494 6.69614L19.9184 6.00092Z",
    "M14.9829 3.85718L15.6785 5.62599L13.9095 6.32177L13.2139 4.55297L14.9829 3.85718Z",
    "M16.3725 7.396L17.0673 9.16521L15.2983 9.86003L14.6035 8.09081L16.3725 7.396Z",
    "M9.30656 19.8823L10.722 16.4623C10.9062 16.0176 10.7206 15.5056 10.294 15.2814L7.81249 13.9806L9.95714 11.89L13.5864 14.4836L12.2321 15.0099L12.4949 15.6866L24 11.2143L23.7371 10.5375L21.9721 11.2241L21.3094 9.53924L19.5404 10.2348L20.1998 11.9127L18.4199 12.6044L17.7629 10.9336L15.9943 11.6287L16.6487 13.2932L15.1894 13.8605L9.22551 9.42652C8.85933 9.18307 8.4835 9.13145 8.02456 9.33372L2.5867 11.7588C2.20285 11.9297 2.03073 12.3792 2.20164 12.7631C2.32778 13.0463 2.60558 13.214 2.89665 13.214C3.0001 13.214 3.10515 13.1929 3.20619 13.1479L5.8701 11.9598L0.21318 18.7532C0.194499 18.7759 0.179033 18.7968 0.165173 18.8173C0.152317 18.8357 0.142676 18.8554 0.131629 18.8747C0.129419 18.8795 0.126005 18.8844 0.123795 18.8888C-0.0764673 19.2393 -0.0354941 19.6922 0.259379 19.9966C0.441968 20.1854 0.685213 20.2802 0.928685 20.2802C1.16189 20.2802 1.3955 20.193 1.5765 20.018L5.83063 15.8995L8.60458 16.7089L7.58617 19.1696C7.44174 19.5183 7.52671 19.9034 7.76876 20.1617H6.08649V20.6923H16.1842V20.1617L9.12346 20.1621C9.19859 20.0821 9.26246 19.9893 9.30665 19.8821L9.30656 19.8823Z",
];

/**
 * Primary train button for the create-model flow.
 *
 * The label reads "Train with Hyperopt" when `useIsHyperoptEnabled()` is true, otherwise "Train".
 * The button shows a loading state while `uploading`, and is disabled while the form is `invalid`,
 * `validating`, or `uploading`.
 */
const CreateModelButton = (props: {
    onSubmit: () => void;
    invalid: boolean;
    uploading: boolean;
    validating: boolean;
}) => {
    const hyperoptEnabled = useIsHyperoptEnabled();
    const { onSubmit, invalid, uploading, validating } = props;

    const trainIcon = (
        <i className="icon">
            <svg
                style={{ width: "24px", height: "100%" }}
                viewBox="0 0 24 21"
                fill="none"
                xmlns="http://www.w3.org/2000/svg"
            >
                {TRAIN_ICON_PATHS.map((pathData) => (
                    <path key={pathData} d={pathData} fill="white" />
                ))}
            </svg>
        </i>
    );

    return (
        <Button
            primary
            content={hyperoptEnabled ? "Train with Hyperopt" : "Train"}
            labelPosition="right"
            icon={trainIcon}
            onClick={onSubmit}
            loading={uploading}
            disabled={invalid || validating || uploading}
        />
    );
};

/**
 * Radio group for choosing how to set up training:
 * - "suggest": train a batch of suggested baseline models;
 * - "custom": build and train a single custom model.
 * The current choice lives in the parent and is read/written through `setup` / `setSetup`.
 */
const SetupTypeSelector = (props: { setup: string; setSetup: React.Dispatch<React.SetStateAction<string>> }) => {
    const { setup, setSetup } = props;

    const titleStyle: CSSProperties = { marginRight: "0.5174rem" };
    const hintStyle: CSSProperties = { color: SEMANTIC_GREY_DISABLED, fontSize: `${12 / 14}rem` };

    const suggestLabel = (
        <label style={{ display: "flex" }}>
            <div>
                <strong style={titleStyle}>Explore suggested models</strong>
                <Chip
                    color={SEMANTIC_BLUE}
                    opacity={0.1}
                    text="Recommended for initial exploration"
                    size="small"
                    width={"15.745rem"}
                />
                <br />
                <span style={hintStyle}>
                    Train 4-10 models with tried and tested configurations to get a set of baseline
                    models for your dataset.
                </span>
                <br />
                <span style={hintStyle}>
                    Note: This option will take more time and resources. Depending on the dataset, may
                    include neural networks, gradient boosted trees, and linear models.
                </span>
            </div>
        </label>
    );

    const customLabel = (
        <label style={{ display: "flex" }}>
            <div>
                <strong style={titleStyle}>Build a custom model</strong>
                <br />
                <span style={{ ...hintStyle, lineHeight: `${24 / 14}rem` }}>
                    Train a single model. Use out-of-the-box defaults or tweak parameters and
                    architecture to your liking.
                </span>
            </div>
        </label>
    );

    return (
        <Form className="select-model-setup">
            <Form.Field>
                <Radio
                    label={suggestLabel}
                    name="setupType"
                    value="suggest"
                    checked={setup === "suggest"}
                    onChange={() => setSetup("suggest")}
                />
            </Form.Field>
            <Form.Field>
                <Radio
                    label={customLabel}
                    name="setupType"
                    value="custom"
                    checked={setup === "custom"}
                    onChange={() => setSetup("custom")}
                />
            </Form.Field>
        </Form>
    );
};
/**
 * Radio group for picking the kind of model to train: an LLM fine-tune (the radio is currently rendered
 * `disabled`) or a custom predictive model (checked for both NEURAL_NETWORK and DECISION_TREE).
 *
 * Selecting a type first consults `useModelConversionLoss(...)` for that type. If it reports
 * `isMissingFields`, the matching confirmation modal is opened. Otherwise the switch is dispatched
 * straight to the config store.
 */
const ModelTypeSelector = (props: { modelType: ModelTypes }) => {
    const { modelType } = props;
    // Hook order must stay fixed across renders.
    const llmConversionLoss = useModelConversionLoss(ModelTypes.LARGE_LANGUAGE_MODEL);
    const ecdConversionLoss = useModelConversionLoss(ModelTypes.NEURAL_NETWORK);
    const dispatch = useDispatch();
    const [showGBMConfirmationModal, setShowGBMConfirmationModal] = useState(false);
    const [showECDConfirmationModal, setShowECDConfirmationModal] = useState(false);
    const [showLLMConfirmationModal, setShowLLMConfirmationModal] = useState(false);

    const selectLLM = () => {
        if (!llmConversionLoss.isMissingFields) {
            dispatch({ type: "APPLY_TEMPLATE", template: LLMTemplates.QLORA_4BIT });
            return;
        }
        setShowLLMConfirmationModal(true);
    };

    // NOTE(MLX-1207): Disable text output features for ECD.
    // In case a user sets a text output feature in the LLM modal, and then they switch back to
    // the ECD option, we may need to clear the target form field if there are unsupported
    // feature types.
    const selectCustomModel = () => {
        if (!ecdConversionLoss.isMissingFields) {
            dispatch({ type: "UPDATE_MODEL_TYPE", modelType: ModelTypes.NEURAL_NETWORK });
            return;
        }
        setShowECDConfirmationModal(true);
    };

    const isLLMSelected = modelType === ModelTypes.LARGE_LANGUAGE_MODEL;
    const isCustomModelSelected = modelType === ModelTypes.NEURAL_NETWORK || modelType === ModelTypes.DECISION_TREE;

    const titleStyle: CSSProperties = { marginRight: "0.5174rem" };
    const iconStyle: CSSProperties = { display: "block", marginRight: "0.89rem" };
    const hintStyle: CSSProperties = {
        color: SEMANTIC_GREY_DISABLED,
        fontSize: `${12 / 14}rem`,
        lineHeight: `${24 / 14}rem`,
    };

    const llmLabel = (
        <label style={{ display: "flex", flexWrap: "wrap" }}>
            <LLMIcon style={iconStyle} />
            <div>
                <strong style={titleStyle}>Fine-Tune a Large Language Model</strong>
                <br />
                <span style={{ ...hintStyle, marginRight: "0.5174rem" }}>
                    Fine-tune an LLM for a text generation task.
                </span>
            </div>
        </label>
    );

    const customModelLabel = (
        <label style={{ display: "flex", flexWrap: "wrap" }}>
            <SupervisedMLIcon style={iconStyle} />
            <div>
                <strong style={titleStyle}>Train a Custom Model</strong>
                <br />
                <span style={hintStyle}>Train a custom deep learning model for a predictive task.</span>
                <br />
            </div>
        </label>
    );

    return (
        <>
            <GBMConfirmationModal
                open={showGBMConfirmationModal}
                setShowConfirmationModal={setShowGBMConfirmationModal}
            />
            <ECDConfirmationModal
                open={showECDConfirmationModal}
                setShowConfirmationModal={setShowECDConfirmationModal}
            />
            <LLMConfirmationModal
                open={showLLMConfirmationModal}
                setShowConfirmationModal={setShowLLMConfirmationModal}
            />
            <Form className="select-model-type">
                <Form.Field style={{ marginBottom: `${16 / 14}rem` }}>
                    <Radio
                        disabled
                        label={llmLabel}
                        name="modelType"
                        value={ModelTypes.LARGE_LANGUAGE_MODEL}
                        checked={isLLMSelected}
                        onChange={selectLLM}
                    />
                </Form.Field>
                <Form.Field>
                    <Radio
                        label={customModelLabel}
                        name="modelType"
                        value={ModelTypes.NEURAL_NETWORK}
                        checked={isCustomModelSelected}
                        onChange={selectCustomModel}
                    />
                </Form.Field>
            </Form>
        </>
    );
};

// Returns `schema.properties.output_features.items.properties.type.enum` (the list of allowed output feature
// types), or `undefined` if any step of that path is missing.
// TODO: harden type — `schema` is currently untyped.
const getSupportedTargetTypes = (schema: any) => {
    const outputFeatureItem = schema?.properties?.output_features?.items;
    const typeProperty = outputFeatureItem?.properties?.type;
    return typeProperty?.enum;
};

/**
 * Multi-select for choosing which dataset fields are the model's targets (output features).
 *
 * - Options are the config's `fields`, keyed by their index.
 * - A field is disabled when its type is not in the schema's supported output feature types.
 *   For non-LLM models, "text" fields are also disabled with a hint to use LLMs.
 * - Selecting a field sets its `mode` to "output"; deselecting sets it back to "input".
 * - If a newly selected field is `excluded`, the update is held in `enableFeature` and a modal asks the user to
 *   activate the field first. "Activate" dispatches the pending update with `excluded: false`.
 */
const Targets = () => {
    const { fields, schema } = useConfigState();
    // Force react-tracked to re-render this component when fields change
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    const fieldsJSON = JSON.stringify(fields); // DO NOT DELETE!
    const isDecisionTree = useIsDecisionTree();
    const isLLM = useIsLLM();
    const dispatch = useDispatch();
    // Pending `{ index, value }` UPDATE_FIELD payload awaiting confirmation in the "activate feature" modal.
    const [enableFeature, setEnableFeature] = useState<any>();

    const renderLabel = (option: DropdownItemProps) => ({
        color: "blue",
        content: option.text,
    });

    // getSupportedTargetTypes returns undefined when the schema (or its output feature type enum) is not
    // available. Fall back to an empty list so rendering doesn't crash on `.includes` of undefined.
    const supportedOutputFeatureTypes: string[] = getSupportedTargetTypes(schema) ?? [];

    const fieldOptions = fields?.map((field, index) => {
        const fieldType = field.config.type.toLowerCase();
        const isSupportedTargetType = supportedOutputFeatureTypes.includes(fieldType);
        // NOTE(MLX-1207): Disable text output features for ECD.
        // This is not strictly necessary, but provides a slightly more customized disable message.
        // LLMs are the only model type that support text output features.
        if (!isLLM && fieldType === "text") {
            return {
                text: field.name + ` (use LLMs for text output features)`,
                value: index,
                key: index,
                disabled: true,
            };
        }

        return {
            text: field.name + (isSupportedTargetType ? "" : ` (data type not supported)`),
            value: index,
            key: index,
            disabled: !isSupportedTargetType,
        };
    });

    // Indexes of the fields currently marked as outputs; these are the dropdown's selected values.
    const selectedFields: number[] = [];
    fields?.forEach((field, index) => {
        if (field.mode === "output") {
            selectedFields.push(index);
        }
    });

    const handleTargetChange = (event: React.SyntheticEvent<HTMLElement, Event>, data: DropdownProps) => {
        // data.value holds the index of every selected option in `fieldOptions` (which equals the field's
        // index in `fields`), e.g. [0, 5]. A Set gives O(1) membership checks below.
        const outputFieldIndexes = new Set<number>(data.value as number[]);

        fields?.forEach((field, index) => {
            const isSelected = outputFieldIndexes.has(index);

            if (isSelected && field.mode !== "output") {
                const newField = { ...field, mode: "output" as const };

                // Pop a modal asking the user to enable the feature before
                // adding it as a target:
                if (field.excluded) {
                    setEnableFeature({
                        index,
                        value: {
                            ...newField,
                            excluded: false,
                        },
                    });
                } else {
                    dispatch({ type: "UPDATE_FIELD", index, value: newField });
                }
            } else if (!isSelected && field.mode === "output") {
                dispatch({ type: "UPDATE_FIELD", index, value: { ...field, mode: "input" as const } });
            }
        });
    };

    return (
        <>
            <Modal open={Boolean(enableFeature)}>
                <Modal.Content>
                    <Modal.Description>
                        <p>
                            This output feature is currently excluded. To include it as a target for your model,
                            activate the feature.
                        </p>
                    </Modal.Description>
                </Modal.Content>
                <Modal.Actions>
                    <Button content="Never mind" onClick={() => setEnableFeature(undefined)} />
                    <Button
                        content="Activate"
                        labelPosition="right"
                        icon="checkmark"
                        onClick={() => {
                            setEnableFeature(undefined);
                            dispatch({ type: "UPDATE_FIELD", ...enableFeature });
                        }}
                        positive
                    />
                </Modal.Actions>
            </Modal>
            <Form.Select
                className={metrics.BLOCK_AUTO_CAPTURE}
                label={isDecisionTree ? "Target" : "Target(s)"}
                placeholder="Target"
                fluid
                multiple
                search
                selection
                options={fieldOptions}
                value={selectedFields}
                renderLabel={renderLabel}
                onChange={handleTargetChange}
            />
        </>
    );
};

/**
 * Compute the version number the next model in a repository will get.
 *
 * Uses `modelsInRepoCount + 1` when the count is truthy. Otherwise it falls back to the latest model's
 * `repoVersion + 1`, and then to 1 when neither is available.
 */
const getNextModelVersion = (modelsInRepoCount?: number, modelRepo?: ModelRepo) => {
    if (!modelsInRepoCount) {
        const latestVersion = getLatestModel(modelRepo?.models)?.repoVersion;
        return latestVersion ? Number(latestVersion) + 1 : 1;
    }
    return Number(modelsInRepoCount) + 1;
};

/**
 * Page header for the create-model view.
 *
 * - Returns null when `modelName` is empty/undefined.
 * - When forking (`editType === "fork"`), renders "Forking <parent> → <name>".
 * - Otherwise renders "<name> Version <N>", with N from `getNextModelVersion`.
 */
const getModelHeader = (
    modelName?: string,
    editType?: string,
    modelsInRepoCount?: number,
    modelRepo?: ModelRepo,
    parentModelName?: string,
) => {
    if (!modelName) {
        return null;
    }

    if (editType !== "fork") {
        const nextVersion = getNextModelVersion(modelsInRepoCount, modelRepo);
        return (
            <>
                {modelName}
                &nbsp;
                <span style={{ color: SEMANTIC_GREY, fontWeight: "normal" }}>Version {nextVersion}</span>
            </>
        );
    }

    return (
        <span>
            Forking <span style={{ color: SEMANTIC_GREY }}>{parentModelName}</span> → {modelName}
        </span>
    );
};

/**
 * Footer navigation for the create-model form.
 *
 * Step 1:
 * - LLM configs (`isLLMModel(config)`): a train button wired to `onSubmit`.
 * - `setup === "custom"`: a "Next" button that advances to step 2.
 * - otherwise (suggested configs): a train button wired to `onSubmitMultiTrain`.
 * Step 2: a "Back" button (returns to step 1) plus a train button wired to `onSubmit`.
 *
 * When `invalid` is true, the primary action is rendered disabled and wrapped in a Popup showing `errorList`.
 */
const CreateModelNavigation = (props: {
    config?: CreateModelConfig;
    step: number;
    setup: string;
    onSubmit: () => void;
    uploading: boolean;
    validating: boolean;
    invalid: boolean;
    errorList: JSX.Element;
    setStep: React.Dispatch<React.SetStateAction<number>>;
    onSubmitMultiTrain: () => void;
}) => {
    const { config, invalid, step, setup, onSubmit, uploading, validating, errorList, setStep, onSubmitMultiTrain } =
        props;

    // When the form is invalid, show the validation errors in a Popup around `control`. The Popup's trigger is an
    // inline-block wrapper div rather than the control itself — presumably because a disabled button does not
    // emit the hover events the Popup needs.
    const withErrorPopup = (control: JSX.Element) =>
        invalid ? (
            <Popup
                className="transition-scale"
                content={errorList}
                position={"right center"}
                trigger={<div style={{ display: "inline-block" }}>{control}</div>}
            />
        ) : (
            control
        );

    // CreateModelButton disables itself while invalid, validating, or uploading.
    const renderTrainButton = (handleSubmit: () => void) => (
        <CreateModelButton
            onSubmit={handleSubmit}
            invalid={invalid}
            uploading={uploading}
            validating={validating}
        />
    );

    if (step === 1) {
        if (isLLMModel(config)) {
            return <div>{withErrorPopup(renderTrainButton(onSubmit))}</div>;
        }

        if (setup === "custom") {
            const nextButton = invalid ? (
                <Button primary disabled>
                    Next
                </Button>
            ) : (
                <Button primary onClick={() => setStep(2)}>
                    Next
                </Button>
            );
            return <div>{withErrorPopup(nextButton)}</div>;
        }

        // Suggested configs: train several models at once. Both the enabled and the disabled (invalid) button
        // are wired to the multi-train handler.
        return <div>{withErrorPopup(renderTrainButton(onSubmitMultiTrain))}</div>;
    }

    // Step 2
    return (
        <div>
            <Button onClick={() => setStep(1)}>Back</Button>
            {withErrorPopup(renderTrainButton(onSubmit))}
        </div>
    );
};

/**
 * Collect the blocking editor errors for the create-model form.
 *
 * @returns One message per missing prerequisite, in engine → connection → dataset order.
 *          The list is empty when all three are present (truthy).
 */
const hasEditorErrors = (engine?: Engine, connection?: Connection, dataset?: Dataset) => {
    const requirements: [unknown, string][] = [
        [
            engine,
            "No training engines are available. Ask your admin to create a training engine on the engines page.",
        ],
        [connection, "Must select a connection"],
        [dataset, "Must select a dataset"],
    ];

    return requirements.filter(([present]) => !present).map(([, message]) => message);
};

/**
 * TODO: DO NOT COPY. This is an anti-pattern! It has been temporarily employed to introduce react-query to the CMV
 * component. This component needs to be refactored significantly to support the correct pattern - which is to use
 * react-query's callbacks to handle errors and other side effects that happen during the API call's lifecycle.
 *
 * POSTs a "ludwig" training request to `models/train` with the dataset, config, repo, parent model and engine IDs.
 *
 * Error handling:
 * - A response whose body carries a non-empty `errorMessage` is treated as a failure. That message is passed to
 *   `setErrorMessage` and thrown.
 * - Every rejection of the request chain (including the one just thrown, which is why `setErrorMessage` runs twice
 *   in that case) is converted via `getErrorMessage`. The result is passed to `setErrorMessage` and to
 *   `redirectIfSessionInvalid`, then rethrown as a new `Error`.
 * - A failure in `createV1APIServer` itself propagates without this handling.
 *
 * @returns A promise resolving to "" on success.
 */
const postTrainModelConfig = async (props: {
    auth0TokenOptions: Auth0TokenOptions;
    config?: CreateModelConfig;
    description?: string;
    dataset?: Dataset;
    modelRepo?: ModelRepo;
    parentModelID?: number;
    engine?: Engine;
    setErrorMessage: React.Dispatch<React.SetStateAction<string | null>>;
}) => {
    const apiServer = await createV1APIServer(props.auth0TokenOptions);

    const requestBody = {
        modelType: "ludwig",
        description: props.description,
        datasetID: props.dataset?.id,
        config: props.config,
        repoID: props.modelRepo?.id,
        parentID: props.parentModelID,
        engineID: props.engine?.id,
    };
    const requestConfig = {
        headers: {
            "Content-Type": "application/x-www-form-urlencoded",
        },
    };

    return apiServer
        .post("models/train", requestBody, requestConfig)
        .then((response) => {
            const bodyError = response.data.errorMessage || "";
            if (!bodyError) {
                return bodyError;
            }
            props.setErrorMessage(bodyError);
            throw new Error(bodyError);
        })
        .catch((error) => {
            const message = getErrorMessage(error) ?? "";
            props.setErrorMessage(message);
            redirectIfSessionInvalid(message);
            throw new Error(message);
        });
};

/**
 * Text input for the new model version's description.
 *
 * The label mentions the upcoming version number when `nextVersion` is truthy. The input is editable only
 * when `setup` is "custom", and each change is forwarded to `setModelDescription`.
 */
const ModelDescriptionInput = (props: {
    modelDescription: string;
    nextVersion?: number;
    setModelDescription: React.Dispatch<React.SetStateAction<string>>;
    setup: string;
    style?: CSSProperties;
}) => {
    const { modelDescription, nextVersion, setModelDescription, setup, style } = props;
    const heading = nextVersion ? `Description for Version ${nextVersion}` : "Description";

    return (
        <Form.Input
            name={"description"}
            label={
                <label>
                    {heading}
                    <br />
                    <span style={{ color: SEMANTIC_GREY, fontWeight: "normal" }}>
                        What you intend to try in this experiment (ex. “first model” or “learning rate auto”)
                    </span>
                </label>
            }
            style={style || {}}
            placeholder="Description"
            disabled={setup !== "custom"}
            onChange={(event) => setModelDescription(event.target.value)}
            value={modelDescription}
        />
    );
};

/**
 * Builds the dropdown options for the base-LLM selector.
 *
 * Returns the curated list of supported base models. When `config.base_model` is set to a model that
 * is not in the curated list, an extra "Other (<model>)" option is appended so that the current
 * selection can still be displayed.
 *
 * @param llmValue Not read. It is kept only so existing call sites keep compiling; reassigning it here
 *   would never be visible to the caller.
 * @param config Current model config; only `base_model` is read.
 */
const getLLMDropdownOptions = (llmValue: string, config?: CreateModelConfig) => {
    // Display key -> model identifier sent to the backend.
    const curatedModels = [
        { key: "mixtral-8x7b-instruct", value: "mistralai/Mixtral-8x7B-Instruct-v0.1" },
        { key: "mistral-7b", value: "mistralai/Mistral-7B-v0.1" },
        { key: "mistral-7b-instruct", value: "mistralai/Mistral-7B-Instruct-v0.1" },
        { key: "mistral-7b-instruct-v2", value: "mistralai/Mistral-7B-Instruct-v0.2" },
        { key: "yarn-mistral-7b-128k", value: "NousResearch/Yarn-Mistral-7b-128k" },
        { key: "zephyr-7b-beta", value: "HuggingFaceH4/zephyr-7b-beta" },
        { key: "llama-3-8b-instruct", value: "meta-llama/Meta-Llama-3-8B-Instruct" },
        { key: "llama-3-70b-instruct", value: "meta-llama/Meta-Llama-3-70B-Instruct" },
        { key: "llama-2-7b", value: "llama-2-7b" },
        { key: "llama-2-7b-chat", value: "llama-2-7b-chat" },
        { key: "codellama-13b-instruct", value: "codellama/CodeLlama-13b-Instruct-hf" },
        { key: "llama-2-13b", value: "llama-2-13b" },
        { key: "llama-2-13b-chat", value: "llama-2-13b-chat" },
        { key: "llama-2-70b", value: "llama-2-70b" },
        { key: "llama-2-70b-chat", value: "llama-2-70b-chat" },
        { key: "gemma-2b", value: "google/gemma-2b" },
        { key: "gemma-2b-instruct", value: "google/gemma-2b-it" },
        { key: "gemma-7b", value: "google/gemma-7b" },
        { key: "gemma-7b-instruct", value: "google/gemma-7b-it" },
    ];

    const configuredModel = config?.base_model;
    const isUncurated = configuredModel && !curatedModels.some((model) => model.value === configuredModel);
    if (isUncurated) {
        curatedModels.push({ key: `Other (${configuredModel})`, value: configuredModel });
    }

    return curatedModels.map(({ key, value }) => ({ key, text: key, value, disabled: false }));
};

/**
 * Builds the dropdown options for the LoRA adapter rank selector.
 *
 * The preset ranks are 8 (labelled as recommended), 16, 32 and 64. If the config already carries a
 * non-preset rank, an "Other (<rank>)" option is appended so that the value remains selectable.
 */
const getAdapterRDropdownOptions = (config?: CreateModelConfig) => {
    const presetRanks = [8, 16, 32, 64];
    const options = presetRanks.map((rank) => ({
        key: rank.toString(),
        text: rank === 8 ? "8 (recommended)" : rank.toString(),
        value: rank,
    }));

    const configuredRank = config?.adapter?.r;
    if (configuredRank && !presetRanks.includes(configuredRank)) {
        options.push({
            key: configuredRank.toString(),
            text: `Other (${configuredRank.toString()})`,
            value: configuredRank,
        });
    }
    return options;
};

/**
 * "Parameters" tab of the LLM fine-tuning form.
 *
 * Edits the base model, output feature, prompt template, adapter rank and version description.
 * Config values are read from `useConfigState()`, and every edit is written back through `dispatch`.
 */
const LLMParametersForm = (props: {
    nextVersion?: number;
    modelDescription: string;
    setModelDescription: React.Dispatch<React.SetStateAction<string>>;
}) => {
    const [user] = useRecoilState(USER_STATE);
    const dispatch = useDispatch();
    const { config, fields, promptTemplateFields } = useConfigState();

    // See https://docs.staging.predibase.com/user-guide/training/supported_models/
    const llmValue = config?.base_model ?? "";
    const llmOptions = getLLMDropdownOptions(llmValue, config);
    // getLLMDropdownOptions appends an "Other (<model>)" option whose value is the configured base model
    // whenever that model is not in the curated list. The warning below is keyed off that option.
    // The previous check, `llmValue.startsWith("Other")`, looked at the raw base model string (the helper
    // cannot rewrite a string argument), so it never detected custom models.
    const isCustomBaseModel = llmOptions.some(
        (option) => option.value === llmValue && option.key.startsWith("Other"),
    );

    // Default to an adapter rank of 8:
    const adapterRank = config?.adapter?.r ?? 8;

    const outputFeatureOptions = fields.map((field) => ({ text: field.name, value: field.name }));

    return (
        <Form>
            <Grid>
                <Grid.Row columns={2} style={{ paddingBottom: 0, marginTop: `${16 / 14}rem` }}>
                    <Grid.Column>
                        <FieldWrapper
                            title="Large Language Model to Fine-tune"
                            description="Pick from a set of popular LLMs of different sizes across a variety of architecture types."
                            path="base_model"
                            schema={{}}
                            style={{ display: "inherit" }}
                        >
                            <Dropdown
                                path="base_model"
                                title="Large Language Model to Fine-tune"
                                error={false}
                                multiple={false}
                                fluid={true}
                                options={llmOptions}
                                value={llmValue}
                                defaultValue=""
                                setConfig={() => {}}
                                setLocalState={(path, value) => {
                                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: path, value: value });
                                }}
                            />
                        </FieldWrapper>
                    </Grid.Column>
                    <Grid.Column>
                        {isCustomBaseModel && (
                            <Message warning style={{ display: "block" }}>
                                <strong>Warning</strong>: Predibase provides best-effort support for any Huggingface
                                text model. You may encounter training failures when specifying a custom model off our
                                well-tested path.
                            </Message>
                        )}
                    </Grid.Column>
                </Grid.Row>
                <Grid.Row columns={2} style={{ paddingTop: 0, paddingBottom: 0 }}>
                    <Grid.Column>
                        <FieldWrapper
                            title="Output Feature"
                            description="The column from your dataset you want your fine-tuned LLM to generate."
                            // Was copy-pasted as "base_model"; this wrapper belongs to the output feature dropdown.
                            path="output_feature"
                            schema={{}}
                            style={{ display: "inherit" }}
                        >
                            <Dropdown
                                path="output_feature"
                                title="Output Feature"
                                error={false}
                                multiple={false}
                                fluid={true}
                                options={outputFeatureOptions}
                                value={config?.output_features?.[0]?.name ?? ""}
                                defaultValue={config?.output_features?.[0]?.name}
                                setConfig={() => {}}
                                setLocalState={(path, value) => {
                                    // TODO: Move this to a reducer
                                    const newOutputFeature = fields?.find((field) => field.name === value);
                                    if (!newOutputFeature || !newOutputFeature?.config) {
                                        return;
                                    }

                                    // Update existing output feature to input mode so it shows up in available features
                                    // for Prompt Template:
                                    const currentOutputFeatureIndex = fields.findIndex(
                                        (field) => field.name === config?.output_features?.[0]?.name,
                                    );
                                    if (currentOutputFeatureIndex >= 0) {
                                        dispatch({
                                            type: "UPDATE_FIELD",
                                            index: currentOutputFeatureIndex,
                                            value: {
                                                ...fields[currentOutputFeatureIndex],
                                                mode: "input",
                                            },
                                        });
                                    }

                                    // Replace the output feature with the selected column, forced to a text output.
                                    const updatedConfig = {
                                        ...config,
                                        output_features: [
                                            {
                                                ...newOutputFeature.config,
                                                type: "text",
                                                preprocessing: {
                                                    max_sequence_length: null,
                                                },
                                            },
                                        ],
                                    } as CreateModelConfig;
                                    dispatch({ type: "UPDATE_CONFIG", config: updatedConfig, isDirty: false });
                                }}
                            />
                        </FieldWrapper>
                    </Grid.Column>
                    <Grid.Column></Grid.Column>
                </Grid.Row>
                <Grid.Row columns={2} style={{ paddingTop: 0 }}>
                    <Grid.Column>
                        <FieldWrapper title="Prompt Template" path="prompt.template" isTextArea={true} schema={{}}>
                            <TextArea
                                path="prompt.template"
                                title="Prompt Template"
                                error={false}
                                value={config?.prompt?.template ?? ""}
                                defaultValue={""}
                                setConfig={() => {}}
                                setLocalState={(path, value) => {
                                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: path, value: value });
                                }}
                                schemaPath="prompt.template"
                                style={{ minHeight: `${160 / 14}rem`, width: `100%` }}
                            />
                        </FieldWrapper>
                        <FieldWrapper
                            title="Adapter Rank"
                            description="Increasing the adapter rank increases the capacity of your fine-tuned model. Higher model capacity may improve training performance, but may also increase GPU memory requirements and training duration."
                            path="adapter.r"
                            schema={{}}
                            style={{ display: "inherit" }}
                        >
                            <Dropdown
                                path="adapter.r"
                                title="Adapter Rank"
                                error={false}
                                multiple={false}
                                fluid={true}
                                options={getAdapterRDropdownOptions(config)}
                                value={adapterRank}
                                defaultValue=""
                                setConfig={() => {}}
                                setLocalState={(path, value) => {
                                    dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: path, value: value });
                                }}
                            />
                        </FieldWrapper>
                        <ModelDescriptionInput
                            nextVersion={props.nextVersion}
                            modelDescription={props.modelDescription}
                            setModelDescription={props.setModelDescription}
                            setup="custom"
                            style={{ width: `100%` }}
                        />
                    </Grid.Column>
                    <Grid.Column>
                        <Segment placeholder padded>
                            <Header as="h3" size="small">
                                About the Prompt Template
                            </Header>
                            <p>
                                Prompts are sent to the LLM to generate the desired response. Columns from your dataset
                                are incorporated into the prompt via the prompt template.
                            </p>
                            <strong>Required:</strong>
                            <ul>
                                <li>
                                    Include one (or more) columns from your dataset as a variable surrounded in curly
                                    brackets &#123;&#125; to indicate where to insert the input feature. Multiple
                                    columns can be inserted, e.g.: The &#123;color&#125; &#123;animal&#125; jumped over
                                    the &#123;size&#125; &#123;object&#125;, where every term in curly brackets is a
                                    column in the dataset.
                                </li>
                            </ul>
                            <strong>Additional tips:</strong>
                            <ul>
                                <li>
                                    Provide necessary boilerplate needed to make the LLM respond in the correct way (for
                                    example, with a response to a question rather than a continuation of the input
                                    sequence).
                                </li>
                                <li>
                                    Provide additional context to the model that can help it understand the task, or
                                    provide restrictions to prevent hallucinations.
                                </li>
                                <li>
                                    You can combine multiple columns from a dataset into a single text input feature
                                    (see TabLLM).
                                </li>
                            </ul>

                            <p>
                                <strong>Possible input columns from your dataset:</strong>{" "}
                                {/* Columns not already referenced by the template and not the output column. */}
                                {fields
                                    .filter(
                                        (field) =>
                                            !promptTemplateFields.includes(field.name) && field.mode !== "output",
                                    )
                                    .map((field) => field.name)
                                    .join(", ")}
                            </p>

                            {
                                // eslint-disable-next-line react/jsx-no-target-blank
                                <a
                                    href={`${getDocsHome()}/user-guide/fine-tuning/prompt_templates/`}
                                    target="_blank"
                                    rel="noopener"
                                    onClick={() => {
                                        if (user) {
                                            track(user, "docs", {
                                                url: getDocsHome(),
                                                clickSource: "environments-view",
                                            });
                                        }
                                    }}
                                >
                                    See prompt template examples <FontAwesomeIcon icon="arrow-up-right-from-square" />
                                </a>
                            }
                        </Segment>
                    </Grid.Column>
                </Grid.Row>
            </Grid>
            <Divider hidden />
        </Form>
    );
};

/**
 * Tabbed first page of the LLM model-creation flow.
 *
 * Has four tabs: parameters, dataset preview, advanced config and compute. The "parameters" tab is
 * shown initially, and only the active tab's panel is rendered.
 */
const LLMFirstPagePanel = (props: {
    dataset?: Dataset;
    engine?: Engine;
    setEngine: React.Dispatch<React.SetStateAction<Engine | undefined>>;
    nextVersion?: number;
    modelDescription: string;
    setModelDescription: React.Dispatch<React.SetStateAction<string>>;
    setup: string;
}) => {
    const [activeItem, setActiveItem] = useState("parameters");
    const handleMenuItemClick = (_event: React.SyntheticEvent<HTMLElement, Event>, { name }: TabProps) => {
        setActiveItem(name);
    };

    // Tab identifier -> menu label, in display order.
    const tabs = [
        { name: "parameters", content: "Parameters" },
        { name: "dataset", content: "Dataset Preview" },
        { name: "advanced", content: "Config (Advanced)" },
        { name: "compute", content: "Compute" },
    ];

    const renderActivePanel = () => {
        switch (activeItem) {
            case "parameters":
                return (
                    <LLMParametersForm
                        nextVersion={props.nextVersion}
                        modelDescription={props.modelDescription}
                        setModelDescription={props.setModelDescription}
                    />
                );
            case "dataset":
                return <DatasetPreviewer dataset={props.dataset} />;
            case "advanced":
                return <AdvancedConfigManager />;
            case "compute":
                return (
                    <ComputeConfigManager dataset={props.dataset} engine={props.engine} setEngine={props.setEngine} />
                );
            default:
                return null;
        }
    };

    return (
        <>
            <Header as="h2" size="medium">
                Features
            </Header>
            <Menu className={metrics.BLOCK_AUTO_CAPTURE} pointing secondary>
                {tabs.map(({ name, content }) => (
                    <MenuItem
                        key={name}
                        content={content}
                        name={name}
                        active={activeItem === name}
                        onClick={handleMenuItemClick}
                    />
                ))}
            </Menu>
            {renderActivePanel()}
        </>
    );
};

/**
 * Tabbed first page of the neural-network model-creation flow.
 *
 * Has two tabs: feature selection and dataset preview. The feature-selection tab is shown initially.
 */
const NeuralNetworkFirstPagePanel = (props: { dataset?: Dataset; loading: boolean }) => {
    const [activeItem, setActiveItem] = useState("fields");
    const handleMenuItemClick = (_event: React.SyntheticEvent<HTMLElement, Event>, { name }: TabProps) => {
        setActiveItem(name);
    };

    // Tab identifier -> menu label, in display order.
    const tabs = [
        { name: "fields", content: "Feature Selection" },
        { name: "dataset", content: "Dataset Preview" },
    ];

    const renderActivePanel = () => {
        if (activeItem === "fields") {
            return (
                <div>
                    <Divider hidden style={{ margin: "0.85715rem 0" }} />
                    <FieldsTable loading={props.loading} />
                </div>
            );
        }
        if (activeItem === "dataset") {
            return <DatasetPreviewer dataset={props.dataset} />;
        }
        return null;
    };

    return (
        <>
            <Header as="h2" size="medium">
                Features
            </Header>
            <Menu className={metrics.BLOCK_AUTO_CAPTURE} pointing secondary>
                {tabs.map(({ name, content }) => (
                    <MenuItem
                        key={name}
                        content={content}
                        name={name}
                        active={activeItem === name}
                        onClick={handleMenuItemClick}
                    />
                ))}
            </Menu>
            {renderActivePanel()}
        </>
    );
};

/**
 * Maps a model type identifier to its human-readable label.
 *
 * @returns The display label, or `undefined` when the type is not one of the known `ModelTypes`.
 */
const displayModelType = (modelType: string) => {
    if (modelType === ModelTypes.NEURAL_NETWORK) {
        return "Neural Network";
    }
    if (modelType === ModelTypes.DECISION_TREE) {
        return "Gradient Boosted Trees";
    }
    if (modelType === ModelTypes.LARGE_LANGUAGE_MODEL) {
        return "Large Language Model";
    }
    // Unrecognized model types have no display label.
    return undefined;
};

const CreateModelView = (props: {
    modelID?: number;
    modelRepoID?: number;
    editType?: string; //@TODO make a proper enum of "Fork", "Train" or "Create".
}) => {
    const [userContext] = useRecoilState(USER_STATE);
    let subscriptionTier;
    let isTrial = true;
    if (userContext) {
        const isKratosContext = isKratosUserContext(userContext);
        subscriptionTier = isKratosContext ? userContext.tenant.subscription.tier : userContext?.tenant.tier;
        isTrial = subscriptionTier === tier.FREE;
    }

    // Auth0 state:
    const auth0TokenOptions = useAuth0TokenOptions();

    // TODO: Tried to refactor all apiServer usages, but there is too much context split across multiple requests
    // and not enough time, so I'm resorting to this variable for those remaining bits.
    const [apiServer, setAPIServer] = useState<AxiosInstance | null>(null);
    useEffect(() => {
        const getAPIServer = async () => {
            const v1APIServer = await createV1APIServer(auth0TokenOptions);
            // NOTE: Whoever wrote the axios typings is a moron because the return type of axios.create is not
            // AxiosInstance -- it's a wrap function. And React will see that and treat it as a callback that
            // setState should directly call. FML.
            // See: [1], [2]:
            // [1]: https://github.com/axios/axios/issues/4365
            // [2]: https://stackoverflow.com/questions/64427195/calling-setstate-will-execute-the-function-value-instead-of-passing-it
            setAPIServer(() => v1APIServer);
        };
        getAPIServer();
    }, []);

    const { config, invalidFields, modelErrors } = useConfigState();
    const dispatch = useDispatch();
    const [setup, setSetup] = useState<string>("custom");

    const navigate = useNavigate();
    const { featureFlags } = useContext(FeatureFlagsContext);
    const [loading, setLoading] = useState(true);
    const [validating, setValidating] = useState(false);
    const [uploading, setUploading] = useState(false);
    const [step, setStep] = useState(1);
    const [errorMessage, setErrorMessage] = useState<string | null>("");

    // @TODO Should all of this be in another reducer?
    const [modelRepo, setModelRepo] = useState<ModelRepo>();
    const [modelsInRepoCount, setModelsInRepoCount] = useState<number>();
    const [modelName, setModelName] = useState<string>("");
    const [modelDescription, setModelDescription] = useState<string>("");
    const [parentModelName, setParentModelName] = useState<string>("");
    const [parentModelID, setParentModelID] = useState<number>();
    const [parentModelVersion, setParentModelVersion] = useState<number>(0);
    const modelRepoConfig = useRef<CreateModelConfig>();

    const [engine, setEngine] = useState<Engine>();
    const [connection, setConnection] = useState<Connection>();
    const [dataset, setDataset] = useState<Dataset>();

    const nextVersion = getNextModelVersion(modelsInRepoCount, modelRepo);

    const queryClient = useQueryClient();
    const { mutateAsync: mutatePostTrainModelConfig } = useMutation({
        mutationFn: postTrainModelConfig,
    });

    // Prefetch the dataset preview so that it's available when the user clicks the "Preview" button.
    if (dataset?.id && dataset?.status === "connected") {
        prefetchDatasetPreviewQuery(
            queryClient,
            dataset.id,
            defaultDatasetPreviewParams,
            commonGetDatasetPreviewQueryOptions(dataset),
            auth0TokenOptions,
        );
    }

    const isInvalid = () => {
        // Can't train models when plan expires
        if (userContext?.isExpired) {
            return true;
        }
        // Generic block for config errors
        if (modelErrors.length > 0) {
            return true;
        }
        // Can't train a model with unsupported fields
        if (Object.keys(invalidFields).length > 0) {
            return true;
        }
        // Can't train a model without an engine
        if (!Boolean(engine)) {
            return true;
        }
        // Can't train a model without a connection
        if (!Boolean(connection)) {
            return true;
        }

        // Can't train a model without a dataset
        return !Boolean(dataset);
    };

    const getErrorList = () => {
        if (userContext?.isExpired) {
            return renderValidationErrorList(
                [isTrial ? DisablingMessages.disabledTrialMessage : DisablingMessages.disabledPlanMessageModel],
                invalidFields,
            );
        }

        return renderValidationErrorList([...editorErrors], invalidFields);
    };

    // TODO: Refactor this component to use react-query's useQuery hook results
    const { data: enginesResponse } = useEnginesQuery({
        staleTime: 0,
        refetchOnWindowFocus: false,
    });
    // Use current engine if it's a training engine, otherwise use the first available training engine:
    useEffect(() => {
        if (enginesResponse === undefined) {
            return;
        }
        const trainingEngine = enginesResponse.currentEngines.train;
        if (engineCanTrainModels(trainingEngine)) {
            setEngine(trainingEngine);
        } else {
            setEngine(enginesResponse.engines.find((x) => engineCanTrainModels(x, false)));
        }
    }, [enginesResponse]);

    // In order to communicate to React Tracked that this component always
    // needs to be re-rendered when the config object changes, we create this
    // unused variable that stringifies the entire object.
    // The bug we were running into was that when submitting the config to the
    // server, we were submitting a stale version of the config that did not get
    // updated until the component was re-rendered.
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    const configJSON = JSON.stringify(config); // DO NOT DELETE!
    const editorErrors = [...hasEditorErrors(engine, connection, dataset), ...modelErrors];
    const invalid = isInvalid();
    const errorList = getErrorList();

    const onSelectConnection = (connection: Connection) => {
        setConnection(connection);
        setDataset(undefined);
    };

    const selectDataset = (dataset: Dataset) => {
        setDataset(dataset);
    };

    const { mutate: mutateDetectDatasetConfig } = useMutation({
        mutationFn: () => detectDatasetConfig(dataset!.id, auth0TokenOptions),
        onMutate: () => {
            setLoading(true);
        },
        onSuccess: (data) => {
            const { config: defaultConfig, fields } = data;
            // Use the default LLM config for LLM models:
            let updatedConfig = isLLMModel(config) ? qlora_4bit : defaultConfig;
            // When loading config from the model repo, update the
            // config to represent the user's latest config:
            if (modelRepoConfig.current) {
                updatedConfig = modelRepoConfig.current;
                modelRepoConfig.current = undefined;
            }

            setErrorMessage(data?.errorMessage || "");
            dispatch({
                type: "INIT",
                config: updatedConfig,
                defaultConfig,
                fields,
                featureFlags,
            });
        },
        onError: (error) => {
            const errorMessage = getErrorMessage(error);
            setErrorMessage(errorMessage);
            dispatch({
                type: "INIT",
                config: undefined,
                defaultConfig: undefined,
                fields: [],
                featureFlags,
            });
        },
        onSettled: () => {
            setLoading(false);
        },
    });
    const detectConfig = () => {
        if (!dataset?.id) {
            // @TODO: Set an error message?
            return;
        }
        mutateDetectDatasetConfig();
    };

    // Keep store in sync with Feature Flags Context
    useEffect(() => {
        dispatch({
            type: "UPDATE_FEATURE_FLAGS",
            featureFlags,
        });
    }, [featureFlags]); // eslint-disable-line react-hooks/exhaustive-deps

    useEffect(() => {
        if (!connection || !dataset) {
            return;
        }

        detectConfig();
    }, [dataset]); // eslint-disable-line react-hooks/exhaustive-deps

    useEffect(() => {
        const isVersion = props.modelID;
        setLoading(true);

        if (isVersion) {
            apiServer
                ?.get(`models/version/${props.modelID}?withRuns=true`)
                .then((res: AxiosResponse<GetModelWithRunsResponse>) => {
                    if (!res.data.modelVersion) {
                        return;
                    }
                    const model = res.data.modelVersion;
                    setModelRepo(model.repo);
                    setModelName(model?.repo?.modelName ?? "");
                    setParentModelName(model?.repo?.modelName ?? "");
                    setParentModelID(props.modelID);
                    setParentModelVersion(model.repoVersion);
                    setModelDescription(`Trained from #${model.repoVersion}`);
                    setModelsInRepoCount(res.data.modelsInRepoCount);

                    if (model.dataset) {
                        setConnection(model.dataset.connection);
                        setDataset(model.dataset);
                        // Save this so that Detect Config uses previous
                        // config instead of default:
                        modelRepoConfig.current = model.config;
                    }
                    setErrorMessage(res.data.errorMessage || "");
                })
                .catch((error) => {
                    const errorMsg = getErrorMessage(error) ?? "";
                    redirectIfSessionInvalid(errorMsg);
                    setErrorMessage(errorMsg);
                })
                .finally(() => {
                    setLoading(false);
                });

            return;
        }

        apiServer
            ?.get(`models/repo/${props.modelRepoID}`, { params: { withVersions: true } })
            .then((res: AxiosResponse<GetModelRepoWithVersionsResponse>) => {
                if (!res.data.modelRepo) {
                    return;
                }
                const repo = res.data.modelRepo as ModelRepo;
                setModelRepo(repo);
                setModelName(repo.modelName);
                setParentModelID(repo.parentID);
                setSetup("suggest");
                setErrorMessage(res.data.errorMessage || "");
                if (props.editType !== "fork") {
                    setParentModelName(repo.modelName);
                    if (repo.dataset) {
                        setConnection(repo.dataset.connection);
                        setDataset(repo.dataset);
                        // Save this so that Detect Config uses previous
                        // config instead of default:
                        modelRepoConfig.current = repo.latestConfig;
                    }
                }

                if (repo.parentID) {
                    apiServer
                        .get(`models/repo/${repo.parentID}`)
                        .then((res2: AxiosResponse<GetModelRepoWithVersionsResponse>) => {
                            const parentRepo = res2.data.modelRepo as ModelRepo;
                            if (props.editType === "fork") {
                                setParentModelName(parentRepo.modelName);
                                setParentModelVersion(parentRepo.repoVersion);
                            }
                            if (props.editType === "fork" || repo.dataset == null) {
                                setConnection(parentRepo.dataset?.connection);
                                // TODO: Sloppy typing (why can this be null OR undefined?):
                                setDataset(parentRepo.dataset as Dataset | undefined);
                            }
                        })
                        .catch((error2) => {
                            const errorMsg = getErrorMessage(error2);
                            setErrorMessage(errorMsg);
                        });
                }
            })
            .catch((error) => {
                const errorMsg = getErrorMessage(error) ?? "";
                redirectIfSessionInvalid(errorMsg);
                setErrorMessage(errorMsg);
            })
            .finally(() => {
                setLoading(false);
            });
        // @TODO: empty this dependency array so that useEffect only runs once on page load?
    }, [props.modelID, props.modelRepoID, apiServer]); // eslint-disable-line react-hooks/exhaustive-deps

    /**
     * Submits a single Ludwig training run to the `models/train` endpoint.
     *
     * Sets the `uploading` flag but never clears it; callers are responsible for resetting it once they are done.
     * On success, any `errorMessage` returned by the server is surfaced and the cached user credits are invalidated
     * (training consumes credits).
     *
     * @param config - config sent as the `config` field of the training request.
     * @param description - description sent alongside the config for the new model.
     * @returns the request promise, or `undefined` when no API server is available. The promise REJECTS when the
     *   request fails (after the error has been shown and the session-redirect check has run), so callers that
     *   await several runs, such as `Promise.all`, can tell a failed run from a successful one.
     */
    const trainModel = (config?: CreateModelConfig, description?: string) => {
        const modelType = "ludwig";
        setUploading(true);

        return apiServer
            ?.post(
                "models/train",
                {
                    modelType,
                    description,
                    datasetID: dataset?.id,
                    config,
                    repoID: modelRepo?.id,
                    parentID: parentModelID,
                    engineID: engine?.id,
                },
                {
                    headers: {
                        "Content-Type": "application/x-www-form-urlencoded",
                    },
                },
            )
            .then((res) => {
                const responseError = res.data.errorMessage;
                // Several runs can be in flight at once (see onSubmitMultiTrain). Only touch the banner when this
                // run actually reported something, so a later success cannot wipe an earlier run's error.
                if (responseError) {
                    setErrorMessage(responseError);
                }
                // Training uses credits; invalidate the credit cache
                // noinspection JSIgnoredPromiseFromCall
                queryClient.invalidateQueries({ queryKey: GET_USER_CREDITS_QUERY_KEY });
            })
            .catch((error) => {
                const errorMsg = getErrorMessage(error) ?? "";
                setErrorMessage(errorMsg);
                redirectIfSessionInvalid(errorMsg);
                // Propagate the failure: swallowing it made callers treat a failed run as a success and
                // navigate away, hiding the error that was just displayed.
                throw error;
            });
    };

    /**
     * Trains a single model from the current config through the `mutatePostTrainModelConfig` mutation, then
     * navigates to the model repo page once the mutation resolves.
     *
     * While the request is pending both `validating` and `uploading` are set; both are reset when it settles,
     * whether it succeeded or failed. Failures are surfaced in the error banner after the session-redirect check.
     *
     * TODO: DO NOT COPY. This is an anti-pattern! It has been temporary employed to introduce react-query to the CMV component.
     * This component needs to be refactored significantly to support the correct pattern - which is to use react-query's
     * callbacks to handle errors and other side effects that happen during the API call's lifecycle.
     */
    const onSubmit = () => {
        // TODO: this separate state is kinda weird now
        setValidating(true);
        setUploading(true);

        const trainRequest = mutatePostTrainModelConfig({
            auth0TokenOptions,
            config,
            description: modelDescription,
            dataset,
            modelRepo,
            parentModelID,
            engine,
            setErrorMessage,
        });

        return trainRequest
            .then(() => {
                navigate(`/models/repo/${modelRepo?.id}`);
            })
            .catch((err) => {
                const message = getErrorMessage(err) ?? "";
                redirectIfSessionInvalid(message);
                setErrorMessage(message);
            })
            .finally(() => {
                setUploading(false);
                setValidating(false);
            });
    };

    /**
     * Trains several models at once. It asks the server for configs derived from the current one
     * (`config/suggest`), starts one training run per suggestion via `trainModel`, and navigates to the model repo
     * page once every run has resolved.
     *
     * Requires a selected dataset and a config; otherwise an error is shown and nothing is sent. Any failure (the
     * suggest request, an unusable or empty suggestion list, or a rejected training run) is surfaced in the error
     * banner and the user stays on the page. The `uploading` flag is always cleared when the chain settles.
     *
     * TODO: Refactor to use react-query instead of axios directly.
     */
    const onSubmitMultiTrain = () => {
        if (!dataset) {
            setErrorMessage("Please select a dataset to train with.");
            return;
        }
        if (!config) {
            setErrorMessage("No config provided.");
            return;
        }
        // Without an API client the request chain below never runs, so nothing would ever reset `uploading` and
        // the UI would stay stuck in its uploading state.
        if (!apiServer) {
            setErrorMessage("Unable to connect to the server. Please refresh the page and try again.");
            return;
        }

        setUploading(true);
        void apiServer
            .post(
                "config/suggest",
                {
                    config,
                },
                {
                    headers: {
                        "Content-Type": "application/x-www-form-urlencoded",
                    },
                },
            )
            .then((res) => {
                const { configs } = res.data;
                setErrorMessage(res.data.errorMessage || "");
                // An empty list would "succeed" without training anything, so treat it like a missing list.
                if (!Array.isArray(configs) || configs.length === 0) {
                    throw new Error("Could not derive suggested configs! Please build a custom model instead.");
                }

                return Promise.all(
                    configs.map((suggestedConfig: ConfigSuggestion) =>
                        trainModel(suggestedConfig.config, suggestedConfig.description),
                    ),
                );
            })
            .then(() => {
                navigate(`/models/repo/${modelRepo?.id}`);
            })
            .catch((error) => {
                const errorMsg = getErrorMessage(error) ?? "";
                redirectIfSessionInvalid(errorMsg);
                setErrorMessage(errorMsg);
            })
            .finally(() => {
                setUploading(false);
            });
    };

    // Page layout, top to bottom:
    //  - header row: breadcrumbs plus the subscription button;
    //  - a warning banner when the user's account is expired (`userContext?.isExpired`);
    //  - the page title and, once loading has finished, the step navigation / submit controls;
    //  - an error banner whenever `errorMessage` is non-empty;
    //  - the body: a spinner while loading, the task/dataset selection form on step 1, or the config summary and
    //    tab panel on any other step.
    return (
        <div style={{ padding: "20px", height: "100vh", maxHeight: "100vh", overflowY: "scroll" }}>
            <div className={"builder-header"}>
                <Breadcrumbs
                    editType={props.editType}
                    modelRepoID={props.modelRepoID}
                    modelRepo={modelRepo}
                    parentModelName={parentModelName}
                    parentModelTag={parentModelVersion}
                />
                <SubscriptionButton isExpired={userContext?.isExpired} currentTier={subscriptionTier} />
            </div>
            {userContext?.isExpired && (
                // Expired accounts still see the builder, but are told they can no longer train.
                <Message warning color="yellow" className="disabled-warning">
                    <p style={{ color: SEMANTIC_DARK_YELLOW }}>
                        <Icon name={"warning sign"} />
                        <b>You may no longer train models.</b>{" "}
                        <a href="https://predibase.com/contact-us" style={{ fontWeight: "bold" }}>
                            Contact us
                        </a>{" "}
                        to upgrade.
                    </p>
                </Message>
            )}
            <Divider hidden />
            <div style={{ display: "flex", justifyContent: "space-between" }}>
                <Header className="header" as="h2">
                    {getModelHeader(modelName, props.editType, modelsInRepoCount, modelRepo, parentModelName)}
                </Header>
                {loading ? null : (
                    // Step navigation and the single/multi-train submit handlers are hidden until loading finishes.
                    <CreateModelNavigation
                        config={config}
                        invalid={invalid}
                        step={step}
                        setup={setup}
                        onSubmit={onSubmit}
                        uploading={uploading}
                        validating={validating}
                        errorList={errorList}
                        setStep={setStep}
                        onSubmitMultiTrain={onSubmitMultiTrain}
                    />
                )}
            </div>
            {errorMessage && (
                <Message negative>
                    <Message.Header>Error in model configuration</Message.Header>
                    <p>{errorMessage}</p>
                </Message>
            )}
            {loading ? (
                // Initial data is still loading; show only a spinner.
                <Loader active size="large" />
            ) : step === 1 ? (
                // Step 1: ML task, setup type and description (non-LLM only), connection/dataset, and targets
                // (non-LLM only), followed by a dataset-dependent panel.
                <>
                    <Message warning color="yellow" className="fine-tuning-warning">
                        <p style={{ color: SEMANTIC_DARK_YELLOW }}>
                            <b>Fine-tuning has moved!</b> Check out our new and improved LLM fine-tuning experience in{" "}
                            <Link to="/adapters">Adapters</Link>.
                        </p>
                    </Message>
                    <Header as="h2" size="medium">
                        ML Task
                    </Header>
                    <ModelTypeSelector modelType={getModelType(config)} />
                    {!isLLMModel(config) && (
                        <div
                            style={{
                                display: "block",
                                marginLeft: `${72 / 14}rem`,
                                marginTop: `${16 / 14}rem`,
                                marginBottom: `${24 / 14}rem`,
                            }}
                        >
                            <SetupTypeSelector setup={setup} setSetup={setSetup} />
                        </div>
                    )}
                    <Divider />
                    <Form>
                        {!isLLMModel(config) && (
                            <Form.Group widths="equal">
                                <ModelDescriptionInput
                                    nextVersion={nextVersion}
                                    modelDescription={modelDescription}
                                    setModelDescription={setModelDescription}
                                    setup={setup}
                                />
                                <Form.Input type="hidden" />
                            </Form.Group>
                        )}
                        <Divider hidden />
                        <Form.Group widths="equal">
                            <ConnectionSelector selectedConnection={connection} onSelectConnection={onSelectConnection} />
                            <DatasetSelector connection={connection} dataset={dataset} selectDataset={selectDataset} />
                        </Form.Group>
                        <Divider hidden />
                        {!isLLMModel(config) && (
                            <Form.Group widths="equal">
                                <Targets />
                                <Form.Input type="hidden" />
                            </Form.Group>
                        )}
                    </Form>
                    <Divider style={{ marginTop: `${24 / 14}rem` }} />
                    {!dataset ? null : dataset?.status === "connecting" ||
                      dataset?.status === "refreshing" ||
                      dataset?.status === "errored" ? (
                        // A dataset that is connecting, refreshing, or errored gets an error message instead of a panel.
                        <Message error>
                            Dataset has invalid state: {dataset?.status}. Please check the status in the datasets table
                            and try again later.
                        </Message>
                    ) : isLLMModel(config) ? (
                        // LLM configs get the LLM-specific panel (engine selection and description input).
                        <LLMFirstPagePanel
                            dataset={dataset}
                            engine={engine}
                            setEngine={setEngine}
                            nextVersion={nextVersion}
                            modelDescription={modelDescription}
                            setModelDescription={setModelDescription}
                            setup={setup}
                        />
                    ) : (
                        // NOTE(review): `loading` is always false here, because this branch only renders when the
                        // outer `loading ? ... : ...` ternary has already taken its non-loading path.
                        <NeuralNetworkFirstPagePanel dataset={dataset} loading={loading} />
                    )}
                </>
            ) : (
                // Any later step: a summary of the selections so far (editable description), neural-network
                // defaults when applicable, and the config tab panel.
                <>
                    <div style={{ display: "flex", justifyContent: "space-between" }}>
                        <dl
                            style={{
                                display: "grid",
                                gridTemplateColumns: "7.4286rem 1fr",
                                gridAutoRows: "max-content",
                            }}
                        >
                            <dt style={{ fontWeight: 400, color: SEMANTIC_GREY_DISABLED }}>Description</dt>
                            <dd style={{ minHeight: "1.575rem", minWidth: `${150 / 14}rem` }}>
                                {
                                    <EditDiv
                                        ogText={modelDescription}
                                        placeholder={"--"}
                                        setStateFunc={setModelDescription}
                                        fitted
                                        asTableCell
                                    />
                                }
                            </dd>
                            <dt style={{ fontWeight: 400, color: SEMANTIC_GREY_DISABLED }}>Engine</dt>
                            <dd>{engine?.name}</dd>
                            <dt style={{ fontWeight: 400, color: SEMANTIC_GREY_DISABLED }}>Dataset</dt>
                            <dd>{dataset?.name}</dd>
                            <dt style={{ fontWeight: 400, color: SEMANTIC_GREY_DISABLED }}>Target</dt>
                            <dd>
                                {Array.isArray(config?.output_features)
                                    ? // Comma-separated names of all output features.
                                      config?.output_features.map((feature) => feature?.name).join(", ")
                                    : ""}
                            </dd>
                            <dt style={{ fontWeight: 400, color: SEMANTIC_GREY_DISABLED }}>Model Type</dt>
                            <dd>{displayModelType(getModelType(config))}</dd>
                        </dl>
                        {getModelType(config) === ModelTypes.NEURAL_NETWORK && <SelectDefaults dataset={dataset} />}
                    </div>
                    <ModelTabPanel dataset={dataset} engine={engine} setEngine={setEngine} />
                </>
            )}
        </div>
    );
};

// The create/edit/fork model builder page component is this module's default export.
export default CreateModelView;
