import { baseModel } from "../../../api_generated";
import { sortBaseModels } from "../../../deployments/misc/dropdown-utils";

/**
 * Collects user-facing validation messages for the model form.
 *
 * Messages are produced in this order:
 *  1. a connection/dataset message when either selection is missing;
 *  2. one message per key of `invalidFields` (in `Object.keys` order). The
 *     `base_model` key gets a fixed "Must specify a Base Model." message; every
 *     other key is rendered as "<field> <errorMessages joined by ', '>.".
 *
 * NOTE(review): assumes every non-`base_model` entry exposes an
 * `errorMessages` string array — an entry without it will throw a TypeError.
 *
 * @param invalidFields Map of field name -> error info; falsy means "no field errors".
 * @param connection The selected connection, if any.
 * @param dataset The selected dataset, if any.
 * @returns A new array of messages (empty when everything is valid).
 */
export const getModelValidationErrors = (invalidFields: any, connection?: Connection, dataset?: Dataset) => {
    const errors: string[] = [];

    if (connection) {
        if (!dataset) {
            errors.push(`Must select a dataset.`);
        }
    } else {
        errors.push(`Must select a connection and dataset.`);
    }

    if (!invalidFields) {
        return errors;
    }

    for (const fieldName of Object.keys(invalidFields)) {
        errors.push(
            fieldName === "base_model"
                ? `Must specify a Base Model.`
                : `${fieldName} ${invalidFields[fieldName].errorMessages.join(", ")}.`,
        );
    }

    return errors;
};

/**
 * Looks up a base model by exact name.
 *
 * @param llmValue The model name to match against `baseModel.name`.
 * @param baseModels Candidate models; anything that is not an array is treated as empty.
 * @returns The first model whose name equals `llmValue`, or `undefined` if none does.
 */
export const getSelectedLLM = (llmValue: string, baseModels?: baseModel[]): baseModel | undefined => {
    const candidates = Array.isArray(baseModels) ? baseModels : [];
    const [firstMatch] = candidates.filter((model) => model.name === llmValue);
    return firstMatch;
};

/**
 * Builds dropdown options for the base-model (LLM) selector.
 *
 * Options come from `baseModels` (a non-array is treated as empty), ordered by
 * `sortBaseModels`. Each option uses the model name as its key, label and value.
 * If `config.base_model` is set to a name that is not among those options, an
 * extra `Other (<name>)` option whose value is that name is appended. This is
 * presumably so a previously configured model stays selectable.
 *
 * @param llmValue Not read by this function; kept so existing callers still compile.
 * @param baseModels Available base models.
 * @param config Optional model config; only `base_model` is read.
 * @returns A new array of dropdown options.
 */
export const getLLMDropdownOptions = (llmValue: string, baseModels?: baseModel[], config?: any) => {
    const availableModels = Array.isArray(baseModels) ? baseModels : [];
    const sortedModels = sortBaseModels(availableModels) ?? [];

    const llmOptions = sortedModels.map((model) => ({
        key: model.name,
        text: model.name,
        value: model.name,
        disabled: false,
    }));

    // The truthiness check already excludes "" (as well as null/undefined),
    // so no separate empty-string comparison is needed.
    const configuredModel = config?.base_model;
    if (configuredModel && !llmOptions.some((option) => option.value === configuredModel)) {
        const otherLabel = `Other (${configuredModel})`;
        llmOptions.push({
            key: otherLabel,
            text: otherLabel,
            value: configuredModel,
            disabled: false,
        });
    }
    return llmOptions;
};

/**
 * Returns the selectable adapter types for fine-tuning.
 *
 * A new array is built on every call, so callers may mutate the result freely.
 *
 * @param config Currently unused by this function.
 * @returns Dropdown options for the adapter type selector.
 */
export const getAdapterDropdownOptions = (config?: any) => {
    const adapterOptions = [
        { key: "lora", text: "LoRA", value: "lora" },
        { key: "turbo_lora", text: "Turbo LoRA", value: "turbo_lora" },
        { key: "turbo", text: "[New] Turbo", value: "turbo" },
    ];

    return adapterOptions;
};

/**
 * Returns the selectable fine-tuning task types.
 *
 * A new array is built on every call, so callers may mutate the result freely.
 *
 * @param config Currently unused by this function.
 * @returns Dropdown options for the task selector.
 */
export const getTaskDropdownOptions = (config?: any) => [
    { key: "instruction_tuning", text: "Instruction Tuning (default)", value: "instruction_tuning" },
    { key: "completion", text: "Completion (beta)", value: "completion" },
];

/**
 * Returns dropdown options for the adapter rank selector.
 *
 * The preset ranks are 16 (labelled as the default), 8, 32 and 64. When
 * `config.rank` is truthy and strictly (`===`) unequal to every preset, an extra
 * `Other (<rank>)` option carrying that rank is appended.
 *
 * NOTE(review): a string rank such as "16" will not `===`-match the numeric
 * presets and would be appended as a duplicate — confirm `config.rank` is always a number.
 *
 * @param config Optional config; only `rank` is read.
 * @returns A new array of rank options.
 */
export const getAdapterRankDropdownOptions = (config?: any) => {
    const options = [
        { key: "16", text: "16 (default)", value: 16 },
        { key: "8", text: "8", value: 8 },
        { key: "32", text: "32", value: 32 },
        { key: "64", text: "64", value: 64 },
    ];

    const rank = config?.rank;
    if (!rank || options.some((preset) => preset.value === rank)) {
        return options;
    }

    const rankLabel = rank.toString();
    options.push({
        key: rankLabel,
        text: `Other (${rankLabel})`,
        value: rank,
    });
    return options;
};

/** Maps a target-module name to a dropdown option using the name as key, label and value. */
const toTargetModuleOption = (moduleName: string) => ({
    key: moduleName,
    text: moduleName,
    value: moduleName,
    disabled: false,
});

/**
 * Returns dropdown options for the target-modules selector.
 *
 * With a `baseModel`, the options are that model's
 * `trainingParams.targetModules.allValues`, and `target_modules` is ignored.
 * Without one, the options are built from `target_modules`, or are empty when
 * that is not supplied either.
 *
 * @param baseModel The selected base model, if any.
 * @param target_modules Fallback module names used when no base model is selected.
 * @returns A new array of dropdown options.
 */
export const getTargetModulesOptions = (baseModel?: baseModel, target_modules?: string[]) => {
    if (!baseModel) {
        return target_modules ? target_modules.map(toTargetModuleOption) : [];
    }

    return baseModel.trainingParams.targetModules.allValues.map(toTargetModuleOption);
};

/**
 * Returns the default target modules for the selected base model.
 *
 * @param baseModel The selected base model, if any.
 * @returns `baseModel.trainingParams.targetModules.defaultValue`, or a new empty
 *          array when no base model is selected.
 */
export const getTargetModulesDefaultValue = (baseModel?: baseModel) =>
    baseModel ? baseModel.trainingParams.targetModules.defaultValue : [];
