import { QueryClient, QueryFunctionContext } from "@tanstack/react-query";
import { AxiosError } from "axios";
import _ from "lodash";

import { adapter } from "@/autogen/openapi";

import { parseStreamResponse } from "../api/stream";
import { getTraceId } from "../api/trace";
import { APIResponse, Auth0TokenOptions } from "../data";
import metrics from "../metrics/metrics";
import {
    createV1APIServer,
    createV2APIServer,
    getServingAPIEndpoint,
    redirectIfSessionInvalid,
    servingAPIServer,
} from "../utils/api";
import { getErrorMessage } from "../utils/errors";
import { injectVariablesIntoPromptTemplate, wrapPromptWithSystemTemplate } from "./utils/data-utils";
import { PredibaseStreamDetails, PredibaseStreamResponse, predibaseStreamHasFinished } from "./utils/lorax";
import { State } from "./utils/reducer";

// TODO: This mirrors `GetModelWithRunsResponse`, but is declared explicitly here on purpose: those types are out of
// date (their `runs` field is wrong), and only the `modelVersion` property is needed by the code below.

/**
 * Shape of a single model version as returned by the legacy v1 `models/version/:id` endpoint.
 *
 * Fields are grouped by topic only for readability; the set of fields and their types match the server payload the
 * rest of this module relies on. NOTE(review): `Date`-typed fields arrive over JSON, so confirm whether they are
 * actually revived into `Date` objects before relying on `Date` methods.
 */
interface Model {
    // Identity & description
    id: number;
    uuid: string;
    description: string;
    /** @deprecated */
    tag: string;
    url: string;
    workspaceURL: string;

    // Repository lineage
    repoID: number;
    repo?: ModelRepo;
    repoVersion: number;
    parentID?: number;
    parent?: number;

    // Dataset & configuration
    datasetID: number;
    dataset?: Dataset;
    llmBaseModelName?: string;
    config: CreateModelConfig;
    features: ModelFeature[];
    activeFields: ActiveModel[];

    // Status & timestamps
    status: string;
    progress?: any;
    errorText?: string;
    created: Date;
    completed?: Date;
    trainingCreated?: Date;
    trainingCompleted?: Date;

    // Runs & results
    experimentID: string;
    bestRunID: string;
    activeRunID?: string;
    modelMetrics?: ModelMetric[];
    modelLabels: any[];
    modelEngine: any;

    // Engine
    engineID: number;
    engine?: Engine;
    engineTemplate?: EngineTemplate;

    // Deployments
    deployments?: ModelDeployment[];

    // Ownership & flags
    createdByUserID: number;
    user?: CurrentUser; // TODO: Not exactly the right type, but close enough
    starred: boolean;
    archived: boolean;
}

/** Envelope returned by the v1 `models/version/:id` endpoint; `getLegacyModelByID` only reads `modelVersion`. */
interface LegacyModelByIDResponse {
    modelVersion: Model;
    modelsInRepoCount: number;
    runs: ModelRun[] | null;
    errorMessage: string;
}
/**
 * Fetches one model version through the legacy v1 `models/version/:id` endpoint and unwraps `modelVersion`.
 *
 * On failure the message is reported to metrics as an `api_error` (tagged with the request's trace ID), handed to
 * `redirectIfSessionInvalid`, and then re-thrown as the plain message string (not an `Error`) — callers depend on
 * receiving the string.
 *
 * @param modelID - ID of the model version to load.
 * @param auth0TokenOptions - Optional options forwarded to the v1 API client factory.
 */
export const getLegacyModelByID = async (modelID: string, auth0TokenOptions?: Auth0TokenOptions) => {
    const endpoint = `models/version/${modelID}`;
    const v1Server = await createV1APIServer(auth0TokenOptions);

    return v1Server
        .get<LegacyModelByIDResponse>(endpoint)
        .then(({ data }) => data.modelVersion)
        .catch((error) => {
            const message = getErrorMessage(error) ?? "";
            metrics.captureError("api_error", message, { method: "GET", endpoint, trace_id: getTraceId(error) });
            redirectIfSessionInvalid(message);
            throw message;
        });
};

/**
 * Fetches the adapter list from the v2 `adapters` endpoint and unwraps the `APIResponse` envelope.
 *
 * On failure the message is reported to metrics as an `api_error` (tagged with the request's trace ID), handed to
 * `redirectIfSessionInvalid`, and then re-thrown as the plain message string (not an `Error`).
 *
 * @param auth0TokenOptions - Optional options forwarded to the v2 API client factory.
 */
export const getAdaptersList = async (auth0TokenOptions?: Auth0TokenOptions) => {
    const endpoint = "adapters";
    const v2Server = await createV2APIServer(auth0TokenOptions);

    return v2Server
        .get<APIResponse<adapter[]>>(endpoint)
        .then(({ data: envelope }) => envelope.data)
        .catch((error) => {
            const message = getErrorMessage(error) ?? "";
            metrics.captureError("api_error", message, { method: "GET", endpoint, trace_id: getTraceId(error) });
            redirectIfSessionInvalid(message);
            throw message;
        });
};

/** Payload of the v1 `llms/:deploymentName/ready` endpoint. */
export interface DeploymentIsReadyResponse {
    /** Readiness flag reported by the server for the deployment. */
    ready: boolean;
    /** Name of the deployment the flag refers to. */
    name: string;
}
/**
 * Asks the v1 API whether the named LLM deployment is ready and returns the raw response body.
 *
 * On failure the message is reported to metrics as an `api_error` (tagged with the request's trace ID), handed to
 * `redirectIfSessionInvalid`, and then re-thrown as the plain message string (not an `Error`).
 *
 * @param deploymentName - Deployment name, interpolated into the URL path as-is.
 * @param auth0TokenOptions - Optional options forwarded to the v1 API client factory.
 */
export const getDeploymentIsReady = async (deploymentName: string, auth0TokenOptions?: Auth0TokenOptions) => {
    const endpoint = `llms/${deploymentName}/ready`;
    const v1Server = await createV1APIServer(auth0TokenOptions);

    return v1Server
        .get<DeploymentIsReadyResponse>(endpoint)
        .then(({ data }) => data)
        .catch((error) => {
            const message = getErrorMessage(error) ?? "";
            metrics.captureError("api_error", message, { method: "GET", endpoint, trace_id: getTraceId(error) });
            redirectIfSessionInvalid(message);
            throw message;
        });
};

/** Generation options are taken directly from the reducer `State`. */
export type GenerateParameters = State;

/** Result resolved by `generate` and `generate_stream`. */
export interface GenerateResponse {
    /** The generated completion text. */
    generated_text: string;
    /** Client-side throughput estimate, truncated to a whole number. */
    tokensPerSecond: number;
    /** LoRAX generation details (e.g. `generated_tokens`). */
    details: PredibaseStreamDetails;
}

/**
 * Builds the LoRAX `parameters` object for a non-streaming generate request.
 *
 * Sampling options are only sent when `doSample` is enabled (and `top_p` only when it differs from 1). Adapter
 * fields are only sent when an adapter is selected.
 */
const buildGenerateRequestParameters = (
    apiToken: string | undefined,
    { selectedAdapter, maxNewTokens, temperature, doSample, topP }: GenerateParameters,
): { [key: string]: any } => {
    const parameters: { [key: string]: any } = {
        max_new_tokens: maxNewTokens,
    };

    if (doSample) {
        parameters.do_sample = true;
        parameters.temperature = temperature;

        if (topP !== 1) {
            parameters.top_p = topP;
        }
    }

    if (selectedAdapter) {
        // Adapters are addressed as `<repo>/<versionTag>`, the same form `generate_stream` sends.
        parameters.adapter_id = `${selectedAdapter.repo}/${selectedAdapter.versionTag}`;
        parameters.adapter_source = "pbase";
        parameters.api_token = apiToken;
    }

    return parameters;
};

/**
 * Sends a single (non-streaming) generate request to the selected deployment.
 *
 * @param tenantShortcode - Tenant prefix for the serving URL.
 * @param apiToken - Bearer token for the serving API; also forwarded as `api_token` when an adapter is selected.
 * @param generateParameters - Reducer state holding the prompt, deployment, adapter and sampling options.
 * @returns The LoRAX response, with `details.generated_tokens` taken from the `x-generated-tokens` header and
 *   `tokensPerSecond` derived from the `x-time-per-token` header.
 * @throws An `Error` carrying `code` 502/503 while the deployment is still spinning up; otherwise the original
 *   axios error (after a possible session-invalid redirect).
 */
export const generate = (
    tenantShortcode: string | undefined,
    apiToken: string | undefined,
    generateParameters: GenerateParameters,
) => {
    const { selectedDeployment, selectedAdapter, prompt, promptTemplate, promptTemplateVariables } =
        generateParameters;
    const parameters = buildGenerateRequestParameters(apiToken, generateParameters);

    return servingAPIServer
        .post<GenerateResponse>(
            `${tenantShortcode}/deployments/v2/llms/${selectedDeployment?.name}/generate`,
            {
                inputs: wrapPromptWithSystemTemplate(
                    _.isEmpty(promptTemplateVariables)
                        ? prompt
                        : injectVariablesIntoPromptTemplate(promptTemplateVariables, promptTemplate),
                    selectedDeployment,
                    selectedAdapter,
                ),
                parameters,
            },
            {
                headers: {
                    Authorization: `Bearer ${apiToken}`,
                },
            },
        )
        .then((res) => {
            // TODO: explore inversion of control or just return a more extensive payload with its own type and push
            // that to consumers explicitly (right now the type is an implicit union - not great).
            const details = {
                ...res.data.details,
                generated_tokens: parseInt(res.headers["x-generated-tokens"], 10),
            };
            return {
                ...res.data,
                tokensPerSecond: Math.trunc(1000 / parseInt(res.headers["x-time-per-token"], 10)),
                details,
            };
        })
        .catch((error) => {
            // TODO: WTF is going on with these axios types.
            const axiosError = error as AxiosError;
            // TODO: Elegantly handler 5XX errors:
            if (axiosError.response?.status === 502 || axiosError.response?.status === 503) {
                throw Object.assign(new Error("Your deployment is still spinning up! Try again in a few moments. "), {
                    code: axiosError.response?.status,
                });
            }
            // TODO: WTF is going on with these error types
            const errorMsg = getErrorMessage(error) ?? "";
            redirectIfSessionInvalid(errorMsg);
            throw axiosError;
        });
};

/**
 * Extracts a readable message from a failed serving response.
 *
 * Prefers the `error` field of a JSON body. Falls back to the trimmed raw body text, or "" when the body cannot be
 * read.
 */
const readServingErrorDetail = async (response: Response): Promise<string> => {
    const raw = await response.text().catch(() => "");
    try {
        const message = (JSON.parse(raw) as { error?: unknown } | null)?.error;
        if (typeof message === "string") {
            return message;
        }
    } catch {
        // Not JSON - fall back to the raw text below.
    }
    return raw.trim();
};

/**
 * Consumes a `generate_stream` SSE body and resolves with the final response.
 *
 * Partial text is mirrored into the query cache as tokens arrive.
 */
const readGenerateStream = async (
    body: ReadableStream<Uint8Array>,
    context: QueryFunctionContext,
    queryClient: QueryClient,
): Promise<GenerateResponse> => {
    const reader = body.pipeThrough(new TextDecoderStream()).getReader();
    let answer = "";
    // Network chunks need not align with SSE event boundaries, so any trailing partial event is held back until the
    // rest of it arrives (or until the stream ends, when whatever is left is flushed).
    let pending = "";

    // Parses the events in `text`; returns the final response once LoRAX reports that generation has finished.
    const consume = (text: string): GenerateResponse | undefined => {
        let payload: PredibaseStreamResponse | null = null;
        for (const event of parseStreamResponse(text)) {
            try {
                payload = JSON.parse(event?.data ?? "");
            } catch {
                throw new Error(`Unable to parse response into JSON: ${event?.data}`);
            }

            // TODO: Update LoRAX to return a 422 when validation fails
            if (payload?.error) {
                throw new Error(payload.error);
            }

            const tokenText = payload?.token?.text;
            // Skip events without token text (instead of appending "undefined") and ignore the start token.
            if (typeof tokenText === "string" && tokenText !== "<s>") {
                answer += tokenText;
            }

            if (payload?.details && predibaseStreamHasFinished(payload.details)) {
                performance.mark(`end_prompt`);
                const promptMeasure = performance.measure(`prompt`, `start_prompt`, `end_prompt`);
                return {
                    details: payload.details,
                    generated_text: payload.generated_text,
                    tokensPerSecond: Math.trunc((payload.details.generated_tokens / promptMeasure.duration) * 1000),
                } as GenerateResponse;
            }
        }

        // Update the cache with the latest generated text
        // https://medium.com/@arpitmalik04/implementing-streaming-data-with-eventsource-in-react-query-4b91b794abfd
        queryClient.setQueryData(context.queryKey, { generated_text: answer, details: payload?.details });
        return undefined;
    };

    try {
        while (true) {
            const { value, done } = await reader.read();
            if (done) {
                break;
            }

            pending += value;
            const boundary = pending.lastIndexOf("\n\n");
            if (boundary === -1) {
                continue;
            }
            const complete = pending.slice(0, boundary + 2);
            pending = pending.slice(boundary + 2);

            const result = consume(complete);
            if (result) {
                return result;
            }
        }

        if (pending.trim() !== "") {
            const result = consume(pending);
            if (result) {
                return result;
            }
        }
    } finally {
        // Release the underlying connection on every exit path (finish, parse/error payload, abort).
        void reader.cancel().catch(() => undefined);
    }

    throw new Error("The response stream ended before generation finished.");
};

/**
 * Streams a completion from the deployment's `generate_stream` endpoint.
 *
 * The partial answer is written into the react-query cache under `context.queryKey` as tokens arrive. The request
 * is aborted through `context.signal`.
 *
 * @param tenantShortcode - Tenant prefix for the serving URL.
 * @param apiToken - Bearer token for the serving API; also forwarded as `api_token` when an adapter is selected.
 * @returns The final event's text and details, plus a client-side tokens-per-second estimate.
 * @throws An `Error` with `code` 429 when rate limited, or 502/503 while the deployment is spinning up.
 * @throws An `Error` with `code` set to the HTTP status for any other failed response; its message is the server's
 *   error text when available.
 * @throws A plain `Error` for malformed, errored, or truncated streams.
 */
export const generate_stream = async (
    tenantShortcode: string | undefined,
    apiToken: string | undefined,
    {
        selectedDeployment,
        selectedAdapter,
        prompt,
        promptTemplate,
        promptTemplateVariables,
        maxNewTokens,
        temperature,
    }: GenerateParameters,
    context: QueryFunctionContext,
    queryClient: QueryClient,
) => {
    const parameters: { [key: string]: any } = {
        max_new_tokens: maxNewTokens,
        temperature,
        details: true,
    };

    if (selectedAdapter) {
        parameters.adapter_id = `${selectedAdapter.repo}/${selectedAdapter.versionTag}`;
        parameters.adapter_source = "pbase";
        parameters.api_token = apiToken;
    }

    performance.clearMarks(`start_prompt`);
    performance.clearMarks(`end_prompt`);
    performance.clearMeasures(`prompt`);
    performance.mark(`start_prompt`);

    // TODO: Updated Axios to latest version and use the new streaming API
    const url =
        getServingAPIEndpoint() + `/${tenantShortcode}/deployments/v2/llms/${selectedDeployment?.name}/generate_stream`;
    return fetch(url, {
        method: "POST",
        headers: {
            "Content-Type": "application/json",
            Authorization: `Bearer ${apiToken}`,
        },
        body: JSON.stringify({
            inputs: wrapPromptWithSystemTemplate(
                _.isEmpty(promptTemplateVariables)
                    ? prompt
                    : injectVariablesIntoPromptTemplate(promptTemplateVariables, promptTemplate),
                selectedDeployment,
                selectedAdapter,
            ),
            parameters,
        }),
        signal: context.signal,
    })
        .then(async (response): Promise<GenerateResponse> => {
            if (response.ok && response.body) {
                return readGenerateStream(response.body, context, queryClient);
            }

            /**
             * Immediately stop if the server returns a rate limited response code.
             */
            if (response.status === 429) {
                throw Object.assign(
                    new Error("You have exceeded the number of prompts allowed. Please try again later."),
                    {
                        code: response.status,
                    },
                );
            }

            /**
             * Retry the request when server is spinning up.
             */
            if (response.status === 502 || response.status === 503) {
                throw Object.assign(new Error("Your deployment is still spinning up! Try again in a few moments."), {
                    code: response.status,
                });
            }

            if (response.ok) {
                throw new Error("The deployment returned an empty response.");
            }

            // Surface every other failure (4XX validation/auth errors, other 5XX) instead of resolving with undefined.
            const detail = await readServingErrorDetail(response);
            throw Object.assign(new Error(detail || `Prompt request failed with status ${response.status}.`), {
                code: response.status,
            });
        })
        .catch((error) => {
            const errorMsg = getErrorMessage(error) ?? "";
            redirectIfSessionInvalid(errorMsg);
            throw error;
        });
};
