import { DropdownItemProps } from "semantic-ui-react";
import {
    accelerator,
    acceleratorId,
    baseModel,
    deploymentAcceleratorOption,
    deploymentQuantization,
    tier,
} from "../../api_generated";
import TruncatedText from "../../components/TruncatedText";
import { SEMANTIC_GREY_DISABLED } from "../../utils/colors";
import { centsToDollars } from "../../utils/strings";
import { formatAcceleratorID } from "./utils";

// TODO: Refine this check later (it is based on accelerator tier availability, not on numParams):
/**
 * Returns true when none of the base model's serving accelerators — across every quantization option listed
 * under `accelerators.serving` — is available to the FREE or PREMIUM tier.
 *
 * A model that lists no serving accelerators at all is also reported as enterprise-tier.
 */
export const isEnterpriseTierModel = (baseModel: baseModel) => {
    const servingAccelerators = baseModel?.accelerators?.serving;
    if (!servingAccelerators) {
        return true;
    }

    // Each value is the accelerator list for one quantization type; the keys themselves are not needed.
    return Object.values(servingAccelerators).every((availableAccelerators) =>
        availableAccelerators.every(
            (accelerator: deploymentAcceleratorOption) =>
                !accelerator.availability.tiers.includes(tier.PREMIUM) &&
                !accelerator.availability.tiers.includes(tier.FREE),
        ),
    );
};

/**
 * Returns true when the accelerator is offered to neither the FREE nor the PREMIUM tier.
 */
export const isEnterpriseTierAccelerator = (accelerator: accelerator) => {
    const { tiers } = accelerator.availability;
    const offeredToFreeOrPremium = tiers.includes(tier.FREE) || tiers.includes(tier.PREMIUM);
    return !offeredToFreeOrPremium;
};

/**
 * Returns true when ENTERPRISE_VPC is the one and only tier the accelerator is available to.
 */
export const isVPCOnlyAccelerator = (accelerator: accelerator) => {
    const availableTiers = accelerator.availability.tiers;
    if (availableTiers.length !== 1) {
        return false;
    }
    return availableTiers.includes(tier.ENTERPRISE_VPC);
};

/**
 * Display priority of model families (most popular first). Entries are compared for exact equality against
 * the first "-"-separated segment of a base model name; families that are not listed sort after all of these.
 */
const modelFamilyOrder: string[] = [
    "solar",
    "llama",
    "mistral",
    "mixtral",
    "zephyr",
    "gemma",
    "phi",
    "qwen2",
    "qwen",
    "codellama",
];

// TODO: This needs to be explicitly defined in llms yaml file:
/**
 * Infers a numeric model version from a dash-separated base model name, so that newer models in the same
 * family can be listed first. Names that match none of the known patterns are treated as version 1.
 *
 * Checks are ordered from most to least specific: a name such as "phi-3-5-mini-instruct" also contains "-3-",
 * so the "-3-5-" check must run before the generic "-3-" one.
 */
const getModelVersion = (baseModelName: string): number => {
    // Phi-3.5 (e.g. "phi-3-5-mini-instruct"). This used to be an `endsWith("-3-5-")` check at the end of the
    // function, which could never match a real name and was shadowed by the "-3-" check anyway.
    if (baseModelName.includes("-3-5-")) {
        return 3.5;
    }
    // Llama family is - separated ("-3-1-" must be checked before "-3-"):
    if (baseModelName.includes("-3-1-")) {
        return 3.1;
    }
    if (baseModelName.includes("-3-")) {
        return 3;
    }
    if (baseModelName.includes("-2-")) {
        return 2;
    }
    // Mistral/Mixtral family is suffixed with "-v0-<n>":
    if (baseModelName.includes("-v0-3")) {
        return 3;
    }
    if (baseModelName.includes("-v0-2")) {
        return 2;
    }
    if (baseModelName.includes("-v0-1")) {
        return 1;
    }
    // Qwen family has version after family name ("qwen2" must be checked before "qwen"):
    if (baseModelName.includes("qwen2")) {
        return 2;
    }
    if (baseModelName.includes("qwen")) {
        return 1;
    }
    // Phi-2 has one model size, so end of string is version:
    if (baseModelName.endsWith("-2")) {
        return 2;
    }
    return 1;
};

/**
 * Sorts base models for display, in this priority order:
 *   1. model family — the first "-"-separated segment of the name — in `modelFamilyOrder` order; families not
 *      listed there come after all listed ones and are ordered by family name among themselves;
 *   2. model version as inferred by `getModelVersion`, highest first;
 *   3. `numParams`, smallest first;
 *   4. name, via `localeCompare`.
 *
 * NOTE: this delegates to `Array.prototype.sort`, so it sorts `baseModels` in place and returns that same
 * array (or `undefined` when no array is passed).
 */
export const sortBaseModels = (baseModels?: baseModel[]) => {
    // Position of a family in `modelFamilyOrder`; every unlisted family shares the last rank.
    const familyRank = (family: string) => {
        const index = modelFamilyOrder.indexOf(family);
        return index === -1 ? modelFamilyOrder.length : index;
    };

    return baseModels?.sort((a, b) => {
        const familyA = a.name.split("-")[0];
        const familyB = b.name.split("-")[0];

        // Sort by model family first:
        if (familyA !== familyB) {
            const rankA = familyRank(familyA);
            const rankB = familyRank(familyB);
            if (rankA !== rankB) {
                return rankA < rankB ? -1 : 1;
            }
            // Both families are unlisted. Returning 0 here would make the comparator inconsistent
            // ("foo-3-x" == "bar-x" == "foo-2-x", yet "foo-3-x" < "foo-2-x" by version), which leaves the
            // result of `sort` implementation-defined. Order by family name so each family stays grouped.
            return familyA < familyB ? -1 : 1;
        }

        // Then sort by model version / recency:
        const modelVersionA = getModelVersion(a.name);
        const modelVersionB = getModelVersion(b.name);
        if (modelVersionA !== modelVersionB) {
            return modelVersionA > modelVersionB ? -1 : 1;
        }

        // Sort by number of parameters:
        if (a.numParams !== b.numParams) {
            return a.numParams < b.numParams ? -1 : 1;
        }

        // Finally, sort alphabetically:
        return a.name.localeCompare(b.name);
    });
};

export type BaseModelLookup = { [key: baseModel["name"]]: baseModel };

/**
 * Builds the base-model dropdown.
 *
 * Returns a tuple of:
 *   - dropdown options, in `sortBaseModels` order, keyed and valued by model name;
 *   - a lookup from model name to the base model.
 *
 * Models with a missing or empty name are skipped. When `userTier` is known and is neither ENTERPRISE_SAAS nor
 * ENTERPRISE_VPC, models flagged by `isEnterpriseTierModel` get a greyed-out "(Enterprise only)" label.
 *
 * NOTE: `sortBaseModels` sorts `baseModels` in place.
 */
export const generateBaseModelSelectorOptions = (
    baseModels: baseModel[] | undefined,
    userTier: tier | undefined,
): [DropdownItemProps[], BaseModelLookup] => {
    const options: DropdownItemProps[] = [];
    const lookup: BaseModelLookup = {};

    const orderedModels = sortBaseModels(baseModels) ?? [];
    const showEnterpriseLabels = userTier && ![tier.ENTERPRISE_SAAS, tier.ENTERPRISE_VPC].includes(userTier);

    orderedModels.forEach((model) => {
        const name = model?.name ?? "";
        if (name === "") {
            return;
        }

        options.push({
            rawtext: name,
            text: (
                <div style={{ display: "flex", justifyContent: "space-between" }}>
                    <TruncatedText text={name} />
                    {/* Non-enterprise users see which models are only deployable on enterprise-tier accelerators: */}
                    {showEnterpriseLabels && isEnterpriseTierModel(model) && (
                        <span style={{ color: SEMANTIC_GREY_DISABLED }}>(Enterprise only)</span>
                    )}
                </div>
            ),
            value: name,
            key: name,
        });
        lookup[name] = model;
    });

    return [options, lookup];
};

// GPU types in display order (most to least capable); types not listed here sort after all of these.
const acceleratorTypeOrder = ["h100", "a100", "l40s", "l4", "a10", "t4"];

/**
 * Sorts accelerator options for display. Options are ordered first by GPU type (H100 > A100 > L40S > L4 >
 * A10G > T4), with unlisted types last and alphabetical among themselves. Options of the same GPU type are
 * then ordered by GPU fraction, largest first.
 *
 * The GPU type is the first "_"-separated segment of the id and the fraction is the third, parsed with
 * `Number` (NOTE(review): assumes ids shaped like "a100_80gb_100" — confirm against the API). If a fraction
 * does not parse, it compares as equal.
 *
 * NOTE: this delegates to `Array.prototype.sort`, so it sorts `accelerators` in place and returns that same
 * array (or `undefined` when no array is passed).
 */
const sortAccelerators = (accelerators?: deploymentAcceleratorOption[]) => {
    // Rank-based comparison keeps the comparator consistent (compare(a, b) === -compare(b, a)); the previous
    // pairwise switch returned 1 in both directions for e.g. t4/a10 vs. an unknown type.
    const typeRank = (gpuType: string) => {
        const index = acceleratorTypeOrder.indexOf(gpuType);
        return index === -1 ? acceleratorTypeOrder.length : index;
    };

    return accelerators?.sort((a, b) => {
        const segmentsA = a.id.split("_");
        const segmentsB = b.id.split("_");
        const gpuTypeA = segmentsA[0];
        const gpuTypeB = segmentsB[0];

        // Sort by GPU type first:
        if (gpuTypeA !== gpuTypeB) {
            const rankA = typeRank(gpuTypeA);
            const rankB = typeRank(gpuTypeB);
            if (rankA !== rankB) {
                return rankA < rankB ? -1 : 1;
            }
            // Both types are unlisted: order alphabetically so each type stays grouped.
            return gpuTypeA < gpuTypeB ? -1 : 1;
        }

        // Then sort by fraction, largest first:
        const fractionA = Number(segmentsA[2]);
        const fractionB = Number(segmentsB[2]);
        if (fractionA > fractionB) {
            return -1;
        }
        if (fractionA < fractionB) {
            return 1;
        }
        return 0;
    });
};

export type AcceleratorLookup = { [key in acceleratorId]?: deploymentAcceleratorOption };

/**
 * Builds the accelerator dropdown for a user's tier.
 *
 * Returns a tuple of:
 *   - dropdown options, in `sortAccelerators` order, keyed and valued by accelerator id;
 *   - a lookup from accelerator id to the accelerator option.
 *
 * Only accelerators whose `availability.tiers` include `userTier` are offered; with an undefined tier the
 * result is empty. When `userTier` is neither ENTERPRISE_SAAS nor ENTERPRISE_VPC, the label also shows the
 * hourly cost, e.g. "<formatted id> ($1.23/hr)".
 */
export const generateAcceleratorSelectorOptions = (
    accelerators: deploymentAcceleratorOption[] | undefined,
    userTier: tier | undefined,
): [DropdownItemProps[], AcceleratorLookup] => {
    const options: DropdownItemProps[] = [];
    const lookup: AcceleratorLookup = {};

    const orderedAccelerators = sortAccelerators(accelerators);
    const showPricing = userTier && ![tier.ENTERPRISE_SAAS, tier.ENTERPRISE_VPC].includes(userTier);
    const availableAccelerators = orderedAccelerators?.filter(
        (option) => userTier !== undefined && option.availability.tiers.includes(userTier),
    );

    for (const option of availableAccelerators ?? []) {
        // TODO: Switch to `compute` when possible.
        // NOTE(review): cost is already read from `compute` here — confirm whether this TODO is still relevant.
        const hourlyCost = centsToDollars(option.compute.cost.centsPerHour);
        const baseLabel = formatAcceleratorID(option.id);
        const label = showPricing ? `${baseLabel} (${hourlyCost}/hr)` : baseLabel;

        options.push({
            rawtext: label,
            text: (
                <div style={{ display: "flex", justifyContent: "space-between" }}>
                    <TruncatedText text={label} />
                </div>
            ),
            value: option.id,
            key: option.id,
        });
        lookup[option.id] = option;
    }

    return [options, lookup];
};

/**
 * Comparator for quantization options, for use with `Array.prototype.sort`. The order is NONE, then FP8, then
 * BITSANDBYTES_NF4, then any other value. Returns -1, 0 or 1.
 *
 * Each argument is ranked on its own, which keeps the comparator consistent: equal values compare as 0 (the
 * previous pairwise switch returned -1 for NONE vs NONE and 1 for NF4 vs NF4), and
 * compare(a, b) === -compare(b, a) (previously NF4 vs an unknown value returned 1 in both directions).
 */
export const sortQuantizationTypes = (a: deploymentQuantization, b: deploymentQuantization) => {
    const rank = (quantization: deploymentQuantization): number => {
        switch (quantization) {
            case deploymentQuantization.NONE:
                return 0;
            case deploymentQuantization.FP8:
                return 1;
            case deploymentQuantization.BITSANDBYTES_NF4:
                return 2;
            default:
                return 3;
        }
    };

    const rankA = rank(a);
    const rankB = rank(b);
    if (rankA === rankB) {
        return 0;
    }
    return rankA < rankB ? -1 : 1;
};
