import urlJoin from "url-join";
import { baseModel, trainingParams } from "../../../api_generated";
import { sortBaseModels } from "../../../deployments/misc/dropdown-utils";
import { getDocsHome } from "../../../utils/api";
import { SEMANTIC_GREY_DISABLED } from "../../../utils/colors";
import { AdapterConfig } from "./schema";

/**
 * Collects human-readable validation messages for the model creation form.
 *
 * @param invalidFields - Map of field name to an error object exposing `errorMessages: string[]`
 *   (shape assumed from usage below — TODO confirm against the form validator).
 * @param connection - Currently selected connection, if any.
 * @param dataset - Currently selected dataset, if any.
 * @returns Messages in display order: connection/dataset selection first, then one message per invalid field.
 */
export const getModelValidationErrors = (invalidFields: any, connection?: Connection, dataset?: Dataset) => {
    const errors: string[] = [];

    // A missing connection already implies a missing dataset, so only one selection message is emitted.
    if (!connection) {
        errors.push(`Must select a connection and dataset.`);
    } else if (!dataset) {
        errors.push(`Must select a dataset.`);
    }

    if (invalidFields) {
        for (const field of Object.keys(invalidFields)) {
            // The base model field gets a friendlier fixed message instead of the raw validator output.
            errors.push(
                field === "base_model"
                    ? `Must specify a Base Model.`
                    : `${field} ${invalidFields[field].errorMessages.join(", ")}.`,
            );
        }
    }

    return errors;
};

/**
 * Finds the base model whose `name` equals `llmValue`.
 *
 * @param llmValue - Model name to look up.
 * @param baseModels - Candidate models; any non-array value is treated as an empty list.
 * @returns The first matching model, or `undefined` when nothing matches.
 */
export const getSelectedLLM = (llmValue?: string, baseModels?: baseModel[]) => {
    if (!Array.isArray(baseModels)) {
        return undefined;
    }
    return baseModels.find((candidate) => candidate.name === llmValue);
};

/**
 * Builds dropdown options for base-model selection.
 *
 * Models are ordered by `sortBaseModels`. If `config.base_model` names a model that is not in the
 * list, an extra "Other (<name>)" option is appended so the current selection can still be shown.
 *
 * @param llmValue - Not read by this function; kept so existing callers' argument positions still work.
 * @param baseModels - Available base models; any non-array value is treated as an empty list.
 * @param config - Adapter config whose `base_model` may reference a model missing from `baseModels`.
 * @returns Dropdown option objects (`key`, `text`, `value`, `disabled`).
 */
export const getLLMDropdownOptions = (llmValue?: string, baseModels?: baseModel[], config?: AdapterConfig) => {
    const models = sortBaseModels(Array.isArray(baseModels) ? baseModels : []) ?? [];

    const llmOptions = models.map((model) => ({
        key: model.name,
        text: model.name,
        value: model.name,
        disabled: false,
    }));

    const configuredModel = config?.base_model;
    // The truthiness check already rules out "" (and undefined), so no separate empty-string test is needed.
    if (configuredModel && !llmOptions.some((option) => option.value === configuredModel)) {
        const otherLabel = `Other (${configuredModel})`;
        llmOptions.push({
            key: otherLabel,
            text: otherLabel,
            value: configuredModel,
            disabled: false,
        });
    }
    return llmOptions;
};

/**
 * Label used as a dropdown option trigger: a bold title, a line break, then the description.
 * When `disabled` is true, the text is greyed out (`SEMANTIC_GREY_DISABLED`) and rendered at 70% opacity.
 */
const AdapterTypeLabel = ({
    title,
    description,
    disabled,
}: {
    title: string;
    description: string | JSX.Element;
    disabled?: boolean;
}) => {
    const labelStyle = {
        marginLeft: "1rem",
        color: disabled ? SEMANTIC_GREY_DISABLED : undefined,
        opacity: disabled ? 0.7 : 1.0,
    };

    return (
        <div style={labelStyle}>
            <b>{title}</b>
            <br />
            {description}
        </div>
    );
};

/** "Learn more." link to the Turbo LoRA section of the adapter docs, opened in a new tab. */
const TurboLoraDocsLink = () => {
    const docsUrl = urlJoin(getDocsHome(), "user-guide/fine-tuning/adapters#turbo-lora-");
    return (
        <a href={docsUrl} target="_blank" rel="noreferrer">
            Learn more.
        </a>
    );
};

/** "Learn more." link to the Turbo section of the adapter docs, opened in a new tab. */
const TurboDocsLink = () => {
    const docsUrl = urlJoin(getDocsHome(), "user-guide/fine-tuning/adapters#turbo-new");
    return (
        <a href={docsUrl} target="_blank" rel="noreferrer">
            Learn more.
        </a>
    );
};

// TODO: Ideally the adapter types should be promoted to an enum in the OpenAPI spec:
export type adapterTypes = trainingParams["adapterType"]["allValues"][0];
type taskTypes = trainingParams["taskType"]["allValues"][0];

/**
 * Maps the adapter type being continued from to the adapter types that training may continue into.
 *
 * Values are typed `readonly` because they are never mutated. The previous `as const` tuples were
 * readonly too, and a readonly tuple is not assignable to the mutable `adapterTypes[]`. Contextual
 * typing against `adapterTypes` still rejects misspelled literals.
 */
const continueTrainingFromAdapterTypeMap: Record<string, readonly adapterTypes[]> = {
    lora: ["lora", "turbo"], // TODO: Is this correct?
    turbo_lora: ["turbo_lora"],
    turbo: ["lora", "turbo"],
};

/**
 * Whether training may continue from an adapter of type `continueFromAdapterType` into `adapterType`.
 *
 * @param continueFromAdapterType - Type of the existing adapter being continued from, if any.
 * @param adapterType - Candidate adapter type for the new training run, if any.
 * @returns `true` when either argument is undefined (nothing to restrict on). Otherwise, whether
 *   `adapterType` is listed for `continueFromAdapterType` in `continueTrainingFromAdapterTypeMap`.
 *   A source type with no entry in the map yields `false` instead of throwing.
 */
const adapterTypeSupportsContinuedTraining = (continueFromAdapterType?: adapterTypes, adapterType?: adapterTypes) => {
    if (continueFromAdapterType === undefined || adapterType === undefined) {
        return true;
    }
    // Own-property lookup: because the map is typed `Record<string, …>`, a source type missing from it
    // would index to `undefined` (or to an Object.prototype member such as `toString`) and crash on `.includes`.
    const allowedTargets = Object.prototype.hasOwnProperty.call(continueTrainingFromAdapterTypeMap, continueFromAdapterType)
        ? continueTrainingFromAdapterTypeMap[continueFromAdapterType]
        : undefined;
    return allowedTargets?.includes(adapterType) ?? false;
};

/**
 * Builds the adapter-type dropdown options.
 *
 * LoRA is always listed. When `continueFromAdapterType` is undefined, Turbo LoRA is offered as the
 * second option; otherwise Turbo is offered instead. An option is disabled when `readonly` is true,
 * when `adapterTypeSupportsContinuedTraining` rejects it, or when `baseModel` provides a list of
 * supported adapter types that does not include it.
 *
 * @param continueFromAdapterType - Type of the adapter being continued from, if any.
 * @param baseModel - Selected base model; when present, `trainingParams.adapterType.allValues` restricts the options.
 * @param readonly - When true, every option is disabled.
 * @returns Dropdown option objects, each with a rich `trigger` label.
 */
export const getAdapterDropdownOptions = (
    continueFromAdapterType?: adapterTypes,
    baseModel?: baseModel,
    readonly?: boolean,
) => {
    const baseModelSupportedAdapterTypes = baseModel?.trainingParams.adapterType.allValues;

    // True when `adapterType` is compatible with both the adapter being continued from and the base model.
    const isAdapterTypeAvailable = (adapterType: adapterTypes) => {
        if (!adapterTypeSupportsContinuedTraining(continueFromAdapterType, adapterType)) {
            return false;
        }
        if (baseModelSupportedAdapterTypes && !baseModelSupportedAdapterTypes.includes(adapterType)) {
            return false;
        }
        return true;
    };
    // Evaluated once per option so the label's greyed-out styling and the option's `disabled` flag always agree.
    const isDisabled = (adapterType: adapterTypes) => readonly || !isAdapterTypeAvailable(adapterType);

    const loraDisabled = isDisabled("lora");
    const adapterOptions = [
        {
            key: "lora",
            text: "LoRA",
            value: "lora",
            trigger: (
                <AdapterTypeLabel
                    title="LoRA"
                    description="Efficient method for training LLMs that introduces a small subset of task-specific model parameters alongside the original model parameters."
                    disabled={loraDisabled}
                />
            ),
            disabled: loraDisabled,
        },
    ];

    if (continueFromAdapterType === undefined) {
        const turboLoraDisabled = isDisabled("turbo_lora");
        adapterOptions.push({
            key: "turbo_lora",
            text: "Turbo LoRA",
            value: "turbo_lora",
            trigger: (
                <AdapterTypeLabel
                    title="Turbo LoRA"
                    description={
                        <span>
                            LoRA adapter with a custom speculator. Improve inference speed via our proprietary
                            fine-tuning method that builds on LoRA. Training jobs will take longer to train and are
                            priced at 2 times the standard fine-tuning pricing. <TurboLoraDocsLink />
                        </span>
                    }
                    disabled={turboLoraDisabled}
                />
            ),
            disabled: turboLoraDisabled,
        });
    } else {
        const turboDisabled = isDisabled("turbo");
        adapterOptions.push({
            key: "turbo",
            text: "[New] Turbo",
            value: "turbo",
            trigger: (
                <AdapterTypeLabel
                    title="Turbo"
                    description={
                        <>
                            Custom speculator for an existing adapter or base model. Improve inference speed without
                            affecting the output. <TurboDocsLink />
                        </>
                    }
                    disabled={turboDisabled}
                />
            ),
            disabled: turboDisabled,
        });
    }

    return adapterOptions;
};

/**
 * Task-type dropdown options, restricted to the task types the base model supports.
 *
 * @param baseModel - Selected base model; when undefined, every task option is returned.
 * @returns Options whose `value` appears in `baseModel.trainingParams.taskType.allValues`.
 */
export const getTaskDropdownOptions = (baseModel?: baseModel) => {
    const allTaskOptions = [
        { key: "instruction_tuning", text: "Instruction Tuning (default)", value: "instruction_tuning" },
        { key: "completion", text: "Completion (beta)", value: "completion" },
    ];
    if (baseModel === undefined) {
        return allTaskOptions;
    }
    // Widening (not asserting) to `readonly string[]` lets `includes` take the option's plain string value.
    const supportedTaskTypes: readonly string[] = baseModel.trainingParams.taskType.allValues;
    return allTaskOptions.filter((option) => supportedTaskTypes.includes(option.value));
};

/**
 * Adapter-rank dropdown options: the standard ranks 16 (default), 8, 32 and 64, plus an
 * "Other (<rank>)" entry when `config.rank` is set to a value outside that list.
 *
 * @param config - Adapter config whose `rank` may be a non-standard value.
 * @returns Dropdown option objects (`key`, `text`, `value`).
 */
export const getAdapterRankDropdownOptions = (config?: AdapterConfig) => {
    const standardRanks = [16, 8, 32, 64];
    const rankOptions = standardRanks.map((rank) => ({
        key: String(rank),
        text: rank === 16 ? "16 (default)" : String(rank),
        value: rank,
    }));

    const customRank = config?.rank;
    if (customRank && !standardRanks.includes(customRank)) {
        rankOptions.push({
            key: customRank.toString(),
            text: `Other (${customRank.toString()})`,
            value: customRank,
        });
    }
    return rankOptions;
};

/**
 * Target-module dropdown options.
 *
 * @param baseModel - When provided, options come from `baseModel.trainingParams.targetModules.allValues`.
 * @param target_modules - Fallback module names, used only when `baseModel` is not provided.
 * @returns One enabled option per module name; an empty array when neither source is available.
 */
export const getTargetModulesOptions = (baseModel?: baseModel, target_modules?: string[]) => {
    const toOption = (moduleName: string) => ({
        key: moduleName,
        text: moduleName,
        value: moduleName,
        disabled: false,
    });

    if (baseModel) {
        return baseModel.trainingParams.targetModules.allValues.map((moduleName) => toOption(moduleName));
    }
    return target_modules ? target_modules.map((moduleName) => toOption(moduleName)) : [];
};

/**
 * Default target modules for the selected base model.
 *
 * @param baseModel - Selected base model, if any.
 * @returns `baseModel.trainingParams.targetModules.defaultValue`, or an empty array when no model is selected.
 */
export const getTargetModulesDefaultValue = (baseModel?: baseModel) =>
    baseModel ? baseModel.trainingParams.targetModules.defaultValue : [];
