import { useEffect, useState } from "react";
import { FormField, Radio, StrictDropdownProps } from "semantic-ui-react";
import { baseModel } from "../../../api_generated";
import Markdown from "../../../components/Markdown";
import { handleLocalState } from "../../../models/create/forms/utils";
import { SEMANTIC_GREY_DISABLED } from "../../../utils/colors";
import { adapterTypes } from "./utils";

/**
 * Continued-training compatibility table.
 *
 * Keyed by the adapter type of the parent (the adapter training is continued from); each value lists the
 * adapter types a new child adapter may use when continuing from a parent of that type.
 *
 * The value type is `readonly` so the table cannot be mutated at runtime and so that literal arrays
 * type-check against it (a readonly `as const` tuple is not assignable to a mutable `adapterTypes[]`).
 */
const continueTrainingFromAdapterTypeMap: Record<adapterTypes, readonly adapterTypes[]> = {
    lora: ["lora", "turbo"],
    turbo_lora: ["turbo_lora"],
    turbo: ["lora", "turbo"],
};

/**
 * Checks the continued-training table for a parent/child adapter-type pair.
 *
 * @param parentAdapterType Adapter type of the adapter being continued from.
 * @param childAdapterType Adapter type proposed for the new adapter.
 * @returns `true` when either type is undefined (nothing to compare yet); otherwise whether the child type
 *          is listed for the parent type in `continueTrainingFromAdapterTypeMap`.
 */
const parentAndChildAdapterTypesAreCompatibleForContinuedTraining = (
    parentAdapterType?: adapterTypes,
    childAdapterType?: adapterTypes,
) => {
    if (parentAdapterType !== undefined && childAdapterType !== undefined) {
        const allowedChildTypes = continueTrainingFromAdapterTypeMap[parentAdapterType];
        return allowedChildTypes.includes(childAdapterType);
    }
    // One side is still unknown, so there is no pairing to reject.
    return true;
};

/**
 * Decides whether `childAdapterType` is selectable for a new adapter.
 *
 * Returns `false` when no child type is given, when the child type may not continue from
 * `parentAdapterType`, or when the base model lists its supported adapter types and the child type is not
 * one of them. If the base model (or its list of adapter types) is absent, no base-model restriction applies.
 *
 * @param parentAdapterType Adapter type being continued from (undefined outside continued training).
 * @param childAdapterType Adapter type being considered for the new adapter.
 * @param baseModel Selected base model, if any.
 */
const canContinueTraining = (
    parentAdapterType?: adapterTypes,
    childAdapterType?: adapterTypes,
    baseModel?: baseModel,
) => {
    // Read before any early exit, matching the original evaluation order.
    const supportedByBaseModel = baseModel?.trainingParams.adapterType.allValues;

    if (childAdapterType === undefined) {
        return false;
    }

    // Short-circuit: the base-model list is only consulted once the parent/child pairing is allowed.
    return (
        parentAndChildAdapterTypesAreCompatibleForContinuedTraining(parentAdapterType, childAdapterType) &&
        (supportedByBaseModel === undefined || supportedByBaseModel.includes(childAdapterType))
    );
};

/**
 * Label shown beside an adapter-type radio button: the option title followed, when non-empty, by a
 * secondary Markdown description. When `disabled` is set the text is drawn in the disabled grey at
 * reduced opacity.
 */
const AdapterTypeLabel = ({
    title,
    description,
    disabled,
}: {
    title: string;
    description: string | JSX.Element;
    disabled?: boolean;
}) => {
    const labelStyle = {
        marginLeft: `${8 / 14}rem`,
        color: disabled ? SEMANTIC_GREY_DISABLED : undefined,
        opacity: disabled ? 0.7 : 1.0,
    };

    return (
        <div style={labelStyle}>
            {title}
            {description && <Markdown children={description} secondary={true} />}
        </div>
    );
};

/**
 * Builds the adapter-type options rendered by `AdapterTypeRadioGroup`.
 *
 * LoRA is always listed first. For a fresh training run (`parentAdapterType` undefined) the second option is
 * Turbo LoRA; when continuing training from an existing adapter it is Turbo instead. An option is disabled
 * when `readonly` is set or when `canContinueTraining` rejects it for the given parent type and base model.
 *
 * @param parentAdapterType Adapter type being continued from; only provided in continued training mode.
 * @param baseModel Selected base model, forwarded to `canContinueTraining`.
 * @param readonly When true, every option is disabled.
 * @returns Option objects with `key`, `text`, `value`, a rendered label in `trigger`, and `disabled`.
 */
const getAdapterDropdownOptions = (
    parentAdapterType?: adapterTypes, // Only provided in continued training mode
    baseModel?: baseModel,
    readonly?: boolean,
) => {
    // Evaluate availability once per option so the greyed-out label and the option's `disabled` flag
    // cannot drift apart.
    const isOptionDisabled = (adapterType: adapterTypes) =>
        readonly || !canContinueTraining(parentAdapterType, adapterType, baseModel);

    const loraDisabled = isOptionDisabled("lora");
    const adapterOptions = [
        {
            key: "lora",
            text: "LoRA",
            value: "lora",
            trigger: (
                <AdapterTypeLabel
                    title="LoRA"
                    description="Efficient method for training LLMs that introduces a small subset of task-specific model parameters alongside the original model parameters."
                    disabled={loraDisabled}
                />
            ),
            disabled: loraDisabled,
        },
    ];

    if (parentAdapterType === undefined) {
        // Fresh training run: offer Turbo LoRA.
        const turboLoraDisabled = isOptionDisabled("turbo_lora");
        adapterOptions.push({
            key: "turbo_lora",
            text: "Turbo LoRA",
            value: "turbo_lora",
            trigger: (
                <AdapterTypeLabel
                    title="Turbo LoRA"
                    description="LoRA adapter with a custom speculator. Improve inference speed via our proprietary fine-tuning method. Training jobs will take longer to train and are
                    priced at 2 times the standard fine-tuning pricing. Note that some base models will require [special deployment configurations](https://docs.predibase.com/user-guide/inference/private_deployments/#turbo-lora-and-turbo-adapters)
                    for inference. [Learn more.](https://docs.predibase.com/user-guide/fine-tuning/turbo_lora)"
                    disabled={turboLoraDisabled}
                />
            ),
            disabled: turboLoraDisabled,
        });
    } else {
        // Continued training: offer Turbo instead of Turbo LoRA.
        const turboDisabled = isOptionDisabled("turbo");
        adapterOptions.push({
            key: "turbo",
            text: "[New] Turbo",
            value: "turbo",
            trigger: (
                <AdapterTypeLabel
                    title="Turbo"
                    description="Custom speculator for an existing adapter. Improve inference speed up without affecting the output. Note that some base models will require
                    [special deployment configurations](https://docs.predibase.com/user-guide/inference/private_deployments/#turbo-lora-and-turbo-adapters)
                    for inference. [Learn more.](https://docs.predibase.com/user-guide/fine-tuning/turbo_lora#turbo-new)"
                    disabled={turboDisabled}
                />
            ),
            disabled: turboDisabled,
        });
    }

    return adapterOptions;
};

/**
 * Radio-button group for choosing an adapter type.
 *
 * The options come from `getAdapterDropdownOptions` (driven by `continueFromAdapterType`,
 * `selectedBaseModel` and `readonly`). Choosing an option passes its value, together with `path`,
 * `setLocalState` and `setConfig`, to `handleLocalState`. The checked option tracks the `value` prop.
 */
export const AdapterTypeRadioGroup = (props: {
    path: string;
    value: StrictDropdownProps["value"] | undefined;
    setConfig: (path: string, typedValue: any) => void;
    setLocalState: (localState: any, path: string) => void;
    selectedBaseModel: baseModel | undefined;
    continueFromAdapterType?: adapterTypes;
    readonly?: boolean;
}) => {
    const { path, value, setConfig, setLocalState, selectedBaseModel, continueFromAdapterType, readonly } = props;

    // ! NOTE: The `value` prop can be undefined on the first render even when it logically shouldn't be
    // (e.g. when viewing trained adapters, the context receives the up-to-date value with a delay). The
    // selection is therefore never set from the change handler; it is re-synced from the prop whenever
    // the prop changes.
    const [selectedValue, setSelectedValue] = useState(value);
    useEffect(() => {
        setSelectedValue(value);
    }, [value]);

    const options = getAdapterDropdownOptions(continueFromAdapterType, selectedBaseModel, readonly);

    return (
        <>
            {options.map(({ key, value: optionValue, trigger, disabled }) => {
                const handleChange = () => {
                    handleLocalState(path, optionValue, setLocalState, setConfig);
                };

                return (
                    <FormField key={`radioGroupAlt_FormField_${key}`}>
                        <div style={{ display: "flex" }}>
                            <Radio
                                checked={selectedValue === optionValue}
                                onChange={handleChange}
                                disabled={readonly || disabled}
                            />
                            {trigger}
                        </div>
                    </FormField>
                );
            })}
        </>
    );
};

export default AdapterTypeRadioGroup;
