import { useQuery, useQueryClient, UseQueryOptions } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import Sockette from "sockette";
import { finetuningJob, predifineMetricsPayload, repo } from "../../../../api_generated";
import { useAuth0TokenOptions } from "../../../../data";
import metrics from "../../../../metrics/metrics";
import { getWebSocketEndpointV2 } from "../../../../utils/api";
import { JOBS_CONSTANT } from "../../../query";
import { getHistoricalJobMetrics } from "./data";
import { deduplicateArray } from "./util";

// Queries:
/**
 * Builds the query key under which a job's historical (fetched) metrics are cached.
 *
 * @param jobUUID - Job whose historical metrics the key identifies.
 * @returns `[JOBS_CONSTANT, jobUUID, "metrics", "historical"]`.
 */
export const GET_HISTORICAL_JOB_METRICS_QUERY_KEY = (jobUUID: finetuningJob["uuid"]) => {
    const key = [JOBS_CONSTANT, jobUUID, "metrics", "historical"];
    return key;
};
/**
 * Fetches a job's historical metrics via `getHistoricalJobMetrics`, cached under
 * `GET_HISTORICAL_JOB_METRICS_QUERY_KEY(jobUUID)`.
 *
 * @param jobUUID - Job whose historical metrics are fetched.
 * @param options - Extra query options; spread last, so they override the
 *   default `queryKey` / `queryFn` when provided.
 * @returns The `useQuery` result.
 */
export const useHistoricalJobMetricsQuery = (
    jobUUID: finetuningJob["uuid"],
    options?: Partial<UseQueryOptions<predifineMetricsPayload[]>>,
) => {
    const tokenOptions = useAuth0TokenOptions();
    const fetchHistoricalMetrics = () => getHistoricalJobMetrics(jobUUID, tokenOptions);

    return useQuery<predifineMetricsPayload[]>({
        queryKey: GET_HISTORICAL_JOB_METRICS_QUERY_KEY(jobUUID),
        queryFn: fetchHistoricalMetrics,
        ...options,
    });
};

// Websockets:
// TODO: Maybe worth just keeping it all in the same cache, not a separate "live" cache:
/**
 * Builds the query key under which a job's live (websocket-streamed) metrics are cached.
 *
 * @param jobUUID - Job whose live metrics the key identifies.
 * @returns `[JOBS_CONSTANT, jobUUID, "metrics", "live"]`.
 */
export const GET_LIVE_JOB_METRICS_QUERY_KEY = (jobUUID: finetuningJob["uuid"]) => {
    const key = [JOBS_CONSTANT, jobUUID, "metrics", "live"];
    return key;
};

/**
 * Reads a job's live metrics from the query cache at
 * `GET_LIVE_JOB_METRICS_QUERY_KEY(jobUUID)`.
 *
 * No `queryFn` is supplied here; the cache entry is written by
 * `useLiveJobMetricsWebsocket` via `queryClient.setQueryData`.
 * NOTE(review): depending on the TanStack Query version, an enabled query with
 * no `queryFn` may report a "Missing queryFn" error; confirm callers handle this.
 *
 * @param jobUUID - Job whose live metrics are read.
 * @param options - Extra query options, spread after the default `queryKey`.
 * @returns The `useQuery` result.
 */
export const useLiveJobMetricsQuery = (
    jobUUID: finetuningJob["uuid"],
    options?: Partial<UseQueryOptions<predifineMetricsPayload[]>>,
) => {
    const queryKey = GET_LIVE_JOB_METRICS_QUERY_KEY(jobUUID);
    return useQuery<predifineMetricsPayload[]>({ queryKey, ...options });
};

// The hook that subscribes to the websocket and updates the metrics cache:
/**
 * Subscribes to a job's metrics websocket stream and appends each received
 * payload to the `GET_LIVE_JOB_METRICS_QUERY_KEY(jobUUID)` query cache.
 *
 * The socket is opened when the effect runs with `enabled === true` and is
 * closed in the effect cleanup (on unmount or before the effect re-runs).
 *
 * @param jobUUID - Job whose metrics stream is subscribed to; also used in the endpoint path.
 * @param repoUUID - Not used to build the endpoint; changing it re-opens the socket.
 * @param versionTag - Not used to build the endpoint; changing it re-opens the socket.
 * @param enabled - When false, no socket is opened. Defaults to true.
 */
export const useLiveJobMetricsWebsocket = (
    jobUUID: finetuningJob["uuid"],
    repoUUID: repo["uuid"],
    versionTag: number,
    enabled: boolean = true,
) => {
    const queryClient = useQueryClient();

    useEffect(() => {
        if (!enabled) {
            return;
        }

        const websocketServerAddress = getWebSocketEndpointV2();
        const endpoint = websocketServerAddress + `/finetuning/jobs/${jobUUID}/metrics/stream`;
        const queryKey = GET_LIVE_JOB_METRICS_QUERY_KEY(jobUUID);

        const captureError = (e: Event, type: string) => {
            const code = e instanceof CloseEvent ? e.code : undefined;
            metrics.captureError("ws_error", String(code), {
                type,
                jobUUID,
                endpoint,
            });
        };

        // Per-connection flag, deliberately NOT React state: a state value read
        // inside these socket handlers would stay frozen at the value from the
        // render in which this effect ran, so a reset requested in `onreconnect`
        // would never be observed by `onmessage`.
        let resetOnNextMessage = false;

        const websocket = new Sockette(endpoint, {
            timeout: 5e3,
            maxAttempts: 3,
            onmaximum: (e) => captureError(e, "onmaximum"),
            onerror: (e) => captureError(e, "onerror"),
            onreconnect: (e) => {
                // After reconnecting, the cache is restarted from the next message received.
                resetOnNextMessage = true;
                captureError(e, "onreconnect");
            },
            onmessage: (e: MessageEvent<string>) => {
                let liveData: predifineMetricsPayload;
                try {
                    liveData = JSON.parse(e.data);
                } catch {
                    // Drop malformed frames instead of throwing out of the socket handler.
                    captureError(e, "onmessage");
                    return;
                }

                // Consume the reset flag outside the updater so the updater stays pure.
                const shouldReset = resetOnNextMessage;
                resetOnNextMessage = false;

                queryClient.setQueryData<predifineMetricsPayload[]>(
                    queryKey,
                    (prev: predifineMetricsPayload[] | undefined) => {
                        // First message, an empty cache, or the first message after a
                        // reconnect: start a fresh series.
                        if (shouldReset || prev === undefined || prev.length === 0) {
                            return [liveData];
                        }

                        // TODO: Sometimes the websocket sends duplicate final events?
                        // Probably because of: https://predibase.slack.com/archives/C03HRSHQBA6/p1714690478934529?thread_ts=1714689528.910759&cid=C03HRSHQBA6
                        // Once the last cached event is marked completed, ignore further
                        // events; returning undefined leaves the cache untouched.
                        if (prev.at(-1)?.meta.is_completed) {
                            return undefined;
                        }

                        return deduplicateArray([...prev, liveData]);
                    },
                );
            },
        });

        return () => {
            websocket.close();
        };
    }, [jobUUID, enabled, repoUUID, versionTag, queryClient]);
};

// Custom hooks:

/**
 * Convenience hook that opens the live-metrics websocket for a job and returns
 * the cache-backed query exposing the streamed data.
 *
 * `options.enabled` gates the websocket. When it is absent, the websocket hook's
 * default (`true`) applies. The whole `options` object is also forwarded to the query.
 *
 * @returns The result of `useLiveJobMetricsQuery(jobUUID, options)`.
 */
export const useLiveJobMetrics = (
    jobUUID: finetuningJob["uuid"],
    repoUUID: repo["uuid"],
    versionTag: number,
    options?: Partial<UseQueryOptions<any>>,
) => {
    const subscribe = options?.enabled;
    useLiveJobMetricsWebsocket(jobUUID, repoUUID, versionTag, subscribe);

    const liveMetricsQuery = useLiveJobMetricsQuery(jobUUID, options);
    return liveMetricsQuery;
};
