import { DropdownItemProps } from "semantic-ui-react";
import {
    accelerator,
    acceleratorId,
    baseModel,
    computeReservation,
    deploymentAcceleratorOption,
    deploymentQuantization,
    tier,
} from "../../api_generated";
import TruncatedText from "../../components/TruncatedText";
import { SEMANTIC_GREY_DISABLED } from "../../utils/colors";
import { centsToDollars } from "../../utils/strings";
import { formatAcceleratorID } from "./utils";

// Display order for base model families, consumed by sortBaseModels: a family earlier in
// this list sorts first, and families that are not listed sort after every listed one.
const modelFamilyOrder = [
    "solar",
    "solar-pro",
    "llama",
    "mistral",
    "mistral-nemo",
    "mixtral",
    "zephyr",
    "gemma",
    "phi",
    "qwen",
    "codellama",
];

// Display order for GPU types, consumed by sortAccelerators, which matches these values
// against the first "_"-separated segment of an accelerator ID.
const acceleratorGPUOrder = ["a100", "h100", "l40s", "l4", "a10", "t4"];

// Display order for quantization types, consumed by sortQuantizationTypes: earlier entries
// sort first, and values that are not listed sort last.
const quantizationOrder = [
    deploymentQuantization.NONE,
    deploymentQuantization.FP8,
    deploymentQuantization.BITSANDBYTES_NF4,
];

/**
 * Reports whether a base model is enterprise-only.
 *
 * A model counts as enterprise-only when none of its serving accelerators, across all
 * quantization types, is available to the PREMIUM or FREE tier. A model that lists no
 * serving accelerators at all is therefore reported as enterprise-only too.
 *
 * TODO: Refine this check later (the original note called it a "numParams check", but no
 * parameter count is inspected here).
 *
 * @param baseModel Base model whose `accelerators.serving` map is inspected.
 * @returns `true` when no serving accelerator is offered to the PREMIUM or FREE tier.
 */
export const isEnterpriseTierModel = (baseModel: baseModel) => {
    const servingByQuantization = baseModel?.accelerators?.serving;
    if (!servingByQuantization) {
        return true;
    }

    // The quantization keys are irrelevant here; only the accelerator lists matter.
    for (const availableAccelerators of Object.values(servingByQuantization)) {
        for (const option of availableAccelerators) {
            const { tiers } = option.availability;
            if (tiers.includes(tier.PREMIUM) || tiers.includes(tier.FREE)) {
                return false;
            }
        }
    }

    return true;
};

/**
 * Reports whether an accelerator is enterprise-only, i.e. it is offered to neither the
 * FREE tier nor the PREMIUM tier.
 *
 * @param accelerator Accelerator whose `availability.tiers` list is inspected.
 * @returns `true` when the tier list contains neither FREE nor PREMIUM.
 */
export const isEnterpriseTierAccelerator = (accelerator: accelerator) => {
    const { tiers } = accelerator.availability;
    const offeredBelowEnterprise = tiers.includes(tier.FREE) || tiers.includes(tier.PREMIUM);
    return !offeredBelowEnterprise;
};

/**
 * Reports whether an accelerator is offered exclusively to the ENTERPRISE_VPC tier.
 *
 * @param accelerator Accelerator whose `availability.tiers` list is inspected.
 * @returns `true` when the tier list has exactly one entry and that entry is ENTERPRISE_VPC.
 */
export const isVPCOnlyAccelerator = (accelerator: accelerator) => {
    const { tiers } = accelerator.availability;
    if (tiers.length !== 1) {
        return false;
    }
    return tiers.includes(tier.ENTERPRISE_VPC);
};

/**
 * Returns a sorted copy of the given base models, ordered by:
 *   1. model family, following `modelFamilyOrder` (families not in that list go last),
 *   2. version, latest first,
 *   3. number of parameters, smallest first,
 *   4. name, alphabetically.
 *
 * The input array is left untouched, so data owned by the caller (props, cached query
 * results, ...) is never reordered as a side effect.
 *
 * @param baseModels Base models to sort.
 * @returns A new sorted array, or `undefined` when no array was passed.
 */
export const sortBaseModels = (baseModels?: baseModel[]) => {
    if (!baseModels) {
        return undefined;
    }

    // Array.prototype.sort reorders in place, so sort a shallow copy instead:
    return [...baseModels].sort((a, b) => {
        // Sort by model family first:
        if (a.family !== b.family) {
            let modelFamilyAIndex = modelFamilyOrder.indexOf(a.family);
            let modelFamilyBIndex = modelFamilyOrder.indexOf(b.family);

            // Families missing from the ordering list are ranked after every known family:
            if (modelFamilyAIndex === -1) {
                modelFamilyAIndex = modelFamilyOrder.length;
            }

            if (modelFamilyBIndex === -1) {
                modelFamilyBIndex = modelFamilyOrder.length;
            }

            // Only reachable when both families are unknown: treat them as a tie.
            if (modelFamilyAIndex === modelFamilyBIndex) {
                return 0;
            }

            return modelFamilyAIndex < modelFamilyBIndex ? -1 : 1;
        }

        // Show the latest version first.
        // NOTE(review): this uses `>` directly; if versions are strings the comparison is
        // lexicographic ("10" sorts before "9") -- confirm the type of `version`.
        if (a.version !== b.version) {
            return a.version > b.version ? -1 : 1;
        }

        // Show the smallest model first:
        if (a.numParams !== b.numParams) {
            return a.numParams < b.numParams ? -1 : 1;
        }

        // Finally, sort alphabetically:
        return a.name.localeCompare(b.name);
    });
};

/** Dictionary from base model name to the base model it identifies. */
export type BaseModelLookup = { [key: baseModel["name"]]: baseModel };

/**
 * Builds the dropdown options for the base model selector together with a name -> model
 * lookup dictionary.
 *
 * Models are emitted in the order produced by `sortBaseModels`; models without a name are
 * skipped. When the user's tier is known and is not an enterprise tier, models for which
 * `isEnterpriseTierModel` returns true get an "(Enterprise only)" suffix.
 *
 * @param baseModels Base models to offer; `undefined` yields empty results.
 * @param userTier Tier of the current user, if known.
 * @returns A tuple of `[dropdown options, lookup dictionary]`.
 */
export const generateBaseModelSelectorOptions = (
    baseModels: baseModel[] | undefined,
    userTier: tier | undefined,
): [DropdownItemProps[], BaseModelLookup] => {
    const options: DropdownItemProps[] = [];
    const lookup: BaseModelLookup = {};

    const orderedModels = sortBaseModels(baseModels) ?? [];
    const enterpriseTiers = [tier.ENTERPRISE_SAAS, tier.ENTERPRISE_VPC];
    const userIsNotInEnterpriseTier = userTier && !enterpriseTiers.includes(userTier);

    orderedModels.forEach((model) => {
        const modelName = model?.name ?? "";
        if (modelName === "") {
            // Nothing to display or key on, so leave this entry out entirely.
            return;
        }

        // Only non-enterprise users see the marker, and only on enterprise-only models:
        const enterpriseOnlyMarker = userIsNotInEnterpriseTier && isEnterpriseTierModel(model) && (
            <span style={{ color: SEMANTIC_GREY_DISABLED }}>(Enterprise only)</span>
        );

        options.push({
            rawtext: modelName,
            text: (
                <div style={{ display: "flex", justifyContent: "space-between" }}>
                    <TruncatedText text={modelName} />
                    {enterpriseOnlyMarker}
                </div>
            ),
            value: modelName,
            key: modelName,
        });

        // Register the model under its name:
        lookup[modelName] = model;
    });

    return [options, lookup];
};

/**
 * Returns a sorted copy of the given accelerator options: first by GPU type, following
 * `acceleratorGPUOrder` (GPU types missing from that list go last), then by fraction,
 * largest first.
 *
 * The GPU type is the first "_"-separated segment of the accelerator ID and the fraction
 * is read from the third segment.
 * NOTE(review): this assumes IDs shaped like `<gpu>_<memory>_<fraction>`; an ID whose third
 * segment is not numeric parses to NaN and ties with its neighbours -- confirm against the
 * actual accelerator ID values.
 *
 * @param accelerators Accelerator options to sort; the array itself is not modified.
 * @returns A new sorted array, or `undefined` when no array was passed.
 */
const sortAccelerators = (accelerators?: deploymentAcceleratorOption[]) => {
    if (!accelerators) {
        return undefined;
    }

    // Array.prototype.sort reorders in place, so sort a shallow copy instead:
    return [...accelerators].sort((a, b) => {
        const idPartsA = a.id.split("_");
        const idPartsB = b.id.split("_");

        // Sort by GPU type first. Compare the type segments themselves: the two arrays
        // returned by split() are never identical objects, so comparing the arrays would
        // always take this branch and the fraction ordering below would never run.
        if (idPartsA[0] !== idPartsB[0]) {
            let gpuTypeAIndex = acceleratorGPUOrder.indexOf(idPartsA[0]);
            let gpuTypeBIndex = acceleratorGPUOrder.indexOf(idPartsB[0]);

            // GPU types missing from the ordering list are ranked after every known type:
            if (gpuTypeAIndex === -1) {
                gpuTypeAIndex = acceleratorGPUOrder.length;
            }

            if (gpuTypeBIndex === -1) {
                gpuTypeBIndex = acceleratorGPUOrder.length;
            }

            // Only reachable when both GPU types are unknown: treat them as a tie.
            if (gpuTypeAIndex === gpuTypeBIndex) {
                return 0;
            }

            return gpuTypeAIndex < gpuTypeBIndex ? -1 : 1;
        }

        // Same GPU type: sort by fraction, largest first.
        const fractionA = Number(idPartsA[2]);
        const fractionB = Number(idPartsB[2]);

        if (fractionA > fractionB) {
            return -1;
        }

        if (fractionA < fractionB) {
            return 1;
        }

        return 0;
    });
};

/** Dictionary from accelerator ID to the accelerator option it identifies. */
export type AcceleratorLookup = { [key in acceleratorId]?: deploymentAcceleratorOption };

/**
 * Builds the dropdown options for the accelerator selector together with an ID -> option
 * lookup dictionary.
 *
 * Accelerators are sorted with `sortAccelerators` and then restricted to those available
 * to the user's tier (an unknown tier yields no options). Users outside the enterprise
 * tiers see the hourly cost appended to the name. An accelerator that requires a
 * reservation is listed as disabled unless one of the given reservations targets it.
 *
 * @param accelerators Accelerator options to offer; `undefined` yields empty results.
 * @param userTier Tier of the current user, if known.
 * @param reservations Compute reservations of the current user, if any.
 * @returns A tuple of `[dropdown options, lookup dictionary]`.
 */
export const generateAcceleratorSelectorOptions = (
    accelerators: deploymentAcceleratorOption[] | undefined,
    userTier: tier | undefined,
    reservations: computeReservation[] | undefined,
): [DropdownItemProps[], AcceleratorLookup] => {
    const options: DropdownItemProps[] = [];
    const lookup: AcceleratorLookup = {};

    // Sort first, then keep only what the user's tier may use:
    const orderedAccelerators = sortAccelerators(accelerators);
    const enterpriseTiers = [tier.ENTERPRISE_SAAS, tier.ENTERPRISE_VPC];
    const userIsNotInEnterpriseTier = userTier && !enterpriseTiers.includes(userTier);
    const availableAccelerators = orderedAccelerators?.filter(
        (option) => userTier !== undefined && option.availability.tiers.includes(userTier),
    );

    for (const option of availableAccelerators ?? []) {
        // User-facing name, with the hourly cost for non-enterprise users
        // (e.g. "1 x A10G ($0.00/hr)"):
        const costPerHour = centsToDollars(option.compute.cost.centsPerHour);
        const plainName = formatAcceleratorID(option.id);
        const acceleratorName = userIsNotInEnterpriseTier ? `${plainName} (${costPerHour}/hr)` : plainName;

        // A reservation-gated accelerator stays disabled unless at least one reservation
        // is for this exact accelerator:
        const hasMatchingReservation = () =>
            reservations?.some((reservation) => reservation.accelerator === option.id);
        const disabled = option.reservationRequired ? !hasMatchingReservation() : false;

        const popoverText = disabled
            ? "This accelerator requires a reservation, please contact sales to purchase."
            : acceleratorName;

        options.push({
            rawtext: acceleratorName,
            text: (
                <div style={{ display: "flex", justifyContent: "space-between" }}>
                    <TruncatedText text={acceleratorName} popoverText={popoverText} />
                </div>
            ),
            value: option.id,
            key: option.id,
            disabled: disabled,
        });
        lookup[option.id] = option;
    }

    return [options, lookup];
};

/**
 * Comparator that orders quantization types by their position in `quantizationOrder`.
 * Values missing from that list rank after every listed value and tie with each other.
 *
 * @param a First quantization type.
 * @param b Second quantization type.
 * @returns `-1` when `a` sorts before `b`, `1` when it sorts after, `0` on a tie.
 */
export const sortQuantizationTypes = (a: deploymentQuantization, b: deploymentQuantization) => {
    if (a === b) {
        return 0;
    }

    // Position in the ordering list, with unlisted values pushed to the end:
    const rankOf = (quantization: deploymentQuantization) => {
        const position = quantizationOrder.indexOf(quantization);
        return position === -1 ? quantizationOrder.length : position;
    };

    const rankA = rankOf(a);
    const rankB = rankOf(b);

    if (rankA === rankB) {
        return 0;
    }

    return rankA < rankB ? -1 : 1;
};
