import { ComponentProps, useMemo, useState } from "react";

import type { ManipulateType } from "dayjs";
import { CartesianGrid, Legend, Line, LineChart, ResponsiveContainer, Tooltip, XAxis, YAxis } from "recharts";
import { CurveType } from "recharts/types/shape/Curve";
import { Divider, Dropdown, Icon, Loader, Message, Popup, Segment } from "semantic-ui-react";

import { acceleratorId, deployment } from "@/autogen/openapi";

import { SEMANTIC_GREY_DISABLED } from "../../../../utils/colors";
import dayjsExtended from "../../../../utils/dayjs";
import { getErrorMessage } from "../../../../utils/errors";
import { formatValueToNumericString } from "../../../../utils/numbers";
import {
    stringifiedValue,
    typesafeDeploymentMetrics,
    typesafeDeploymentMetricsTuple,
    unixEpoch,
} from "../../../data/data-utils";
import { useDeploymentMetricsQuery } from "../../../data/query";

/** One selectable entry of the "Timeframe" dropdown. */
type TimeRangeOption = {
    /** Label displayed for the option in the dropdown. */
    text: string;
    /** Identifier kept in component state while the option is selected. */
    value: string;
    /** `[amount, unit]` pair that is spread into `dayjs().subtract(...)` to find the start of the window. */
    semantic: [number, ManipulateType];
    /** dayjs format string applied to the default x-axis tick labels for this range. */
    format: string;
};

// Every timeframe the user can pick, ordered from shortest to longest. Longer ranges use progressively more
// verbose tick formats (date, then year) so ticks stay unambiguous.
const timeRangeDropdownOptions: TimeRangeOption[] = [
    // Sub-day ranges: the time of day alone is enough.
    { text: "Last 5 minutes", value: "5m", semantic: [5, "minute"], format: "HH:mm" },
    { text: "Last 15 minutes", value: "15m", semantic: [15, "minute"], format: "HH:mm" },
    { text: "Last 30 minutes", value: "30m", semantic: [30, "minute"], format: "HH:mm" },
    { text: "Last 1 hour", value: "1h", semantic: [1, "hour"], format: "HH:mm" },
    { text: "Last 3 hours", value: "3h", semantic: [3, "hour"], format: "HH:mm" },
    { text: "Last 6 hours", value: "6h", semantic: [6, "hour"], format: "HH:mm" },
    { text: "Last 12 hours", value: "12h", semantic: [12, "hour"], format: "HH:mm" },
    // A day or more: include the date as well.
    { text: "Last 24 hours", value: "24h", semantic: [24, "hour"], format: "MM/DD, HH:mm" },
    { text: "Last 2 days", value: "2d", semantic: [2, "day"], format: "MMM DD, HH:mm" },
    { text: "Last 7 days", value: "7d", semantic: [7, "day"], format: "MMM DD, HH:mm" },
    { text: "Last 30 days", value: "30d", semantic: [30, "day"], format: "YY MMM DD, HH:mm" },
];

// Upper bound on the number of data points a single metrics query should return (11000 data points).
const PROMETHEUS_QUERY_DATA_LIMIT = 11_000;
// Multiplier (> 1) applied to the computed query step. A larger step lowers the report resolution, which keeps us
// clear of the upper data limit and avoids overwhelming the app.
const STRETCH_FACTOR = 1.2;

// Helpers:
/**
 * Picks the query step (seconds between data points) for a window that starts at `fromTime` and ends now.
 *
 * The step is the window length divided by `PROMETHEUS_QUERY_DATA_LIMIT`, scaled by `STRETCH_FACTOR` and rounded
 * down, but never less than 60 seconds.
 *
 * @param fromTime Start of the window as a unix timestamp in seconds, or `undefined` when no window is selected.
 * @returns The step in seconds; 60 when `fromTime` is `undefined`.
 */
const getStepSecsForTimeframe = (fromTime: unixEpoch | undefined) => {
    const minimumStepSecs = 60;

    // Without a start time there is nothing to scale against, so fall back to the minimum step.
    if (fromTime === undefined) {
        return minimumStepSecs;
    }

    const windowStart = dayjsExtended.unix(fromTime);
    const windowLengthSecs = dayjsExtended().diff(windowStart, "second");
    const scaledStepSecs = Math.floor((windowLengthSecs / PROMETHEUS_QUERY_DATA_LIMIT) * STRETCH_FACTOR);

    return Math.max(scaledStepSecs, minimumStepSecs);
};

/**
 * Formats a number for display using the shared `formatValueToNumericString` helper in its "lessPrecision" mode.
 *
 * @param val Number to format; `undefined` is passed straight through.
 * @param fractionDigits Fraction-digit argument forwarded to the helper (defaults to 2).
 * @returns The helper's result, or `undefined` when `val` is `undefined`.
 */
const commonNumberFormatter = (val?: number, fractionDigits = 2) => {
    if (val === undefined) {
        return undefined;
    }

    return formatValueToNumericString(val, fractionDigits, undefined, undefined, "lessPrecision");
};

/**
 * Converts raw `[time, stringValue]` metric tuples into chart points and computes their formatted mean.
 *
 * @param data Raw metric tuples; `undefined` is treated as "no data".
 * @param dimension Currently unused; retained so existing call signatures keep working.
 * @returns A `[points, mean]` pair. `points` holds `{ time, value }` objects with unparseable values removed
 *          (an empty array when there is no data). `mean` is the formatted average of the remaining values,
 *          or `undefined` when no points remain.
 */
const getDataWithAverage = (data?: typesafeDeploymentMetricsTuple[], dimension = 0) => {
    // Parse every sample and drop the ones whose value is not a number:
    const points = (data ?? [])
        .map(([time, rawValue]) => ({ time, value: parseFloat(rawValue) }))
        .filter((point) => !isNaN(point.value));

    // Only compute a mean when there is at least one valid point (avoids dividing by zero):
    let mean: ReturnType<typeof commonNumberFormatter> = undefined;
    if (points.length > 0) {
        const total = points.reduce((sum, point) => sum + point.value, 0);
        mean = commonNumberFormatter(total / points.length);
    }

    return [points, mean] as [typeof points, typeof mean];
};

/**
 * Shared layout for one metrics chart: a titled segment with an optional help popup, followed by either an error
 * message, a loader, or the summary statistics plus a single-series line chart.
 *
 * Props:
 * - `loading`: whether the underlying data is still being fetched.
 * - `chartData`: points to plot (`time` on the x-axis, `value` on the y-axis). An empty array renders the
 *   "No activity during this time period" label inside the chart.
 * - `title` / `tooltip`: header text and optional help text shown in a popup next to it.
 * - `legendLabel`: name of the plotted series.
 * - `summaryStatistics`: `[label, value]` pairs rendered above the chart. A value of `undefined` is treated as
 *   "not available yet"; while chart data exists, the whole segment keeps showing the loader until every value
 *   is defined.
 * - `selectedTimeRange`: supplies the dayjs format for the default x-axis tick formatter.
 * - `unixFromTime`: when defined, pins the x-axis domain to `[unixFromTime, now]`; otherwise the domain is "auto".
 * - `error`: when truthy, an error message replaces the chart. Both `null` and `undefined` mean "no error".
 * - `lineType` / `lineColor`: optional overrides for the line (defaults: "monotone", "#8884d8").
 * - `xAxisTickFormatter` / `yAxisTickFormatter` / `tooltipXFormatter` / `tooltipYFormatter`: optional overrides
 *   for the default formatters.
 */
const ChartSegmentLayout = (props: {
    loading: boolean;
    chartData: { time: unixEpoch; value: number }[] | undefined;
    title: string;
    tooltip?: string;
    legendLabel: string;
    summaryStatistics: [string, number | string | undefined][];
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    error?: Error | null;
    lineType?: CurveType;
    lineColor?: string;
    yAxisTickFormatter?: ComponentProps<typeof YAxis>["tickFormatter"];
    xAxisTickFormatter?: ComponentProps<typeof XAxis>["tickFormatter"];
    tooltipXFormatter?: ComponentProps<typeof Tooltip>["labelFormatter"];
    tooltipYFormatter?: ComponentProps<typeof Tooltip>["formatter"];
}) => {
    // Parent state:
    const {
        loading,
        chartData,
        title,
        legendLabel,
        summaryStatistics,
        selectedTimeRange,
        unixFromTime,
        error,
        lineType,
        lineColor,
        xAxisTickFormatter,
        yAxisTickFormatter,
        tooltip,
        tooltipXFormatter,
        tooltipYFormatter,
    } = props;

    // Force segment content to load all at once (no partial rendering or pop-in from late-loading data such as the
    // summary stats):
    const thereIsValidChartData = chartData !== undefined && chartData.length > 0;
    const summaryStatsAreLoading = summaryStatistics.some((stat) => stat[1] === undefined);
    const stillLoading = loading || (thereIsValidChartData && summaryStatsAreLoading);
    // `error` is optional, so "no error" can arrive as either `null` or `undefined`. A strict `=== null` check
    // would hide the summary stats whenever the prop is simply omitted.
    const canShowSummaryStats = !stillLoading && !error && thereIsValidChartData && !summaryStatsAreLoading;

    return (
        <Segment>
            <b style={{ marginRight: "0.5rem" }}>{title}</b>
            {tooltip && (
                <Popup
                    className="transition-scale"
                    hoverable
                    wide={"very"}
                    position={"right center"}
                    trigger={<Icon name={"question circle"} color={"grey"} />}
                    content={<span>{tooltip}</span>}
                />
            )}
            <Divider />
            {error ? (
                <Message negative header="Error" content={getErrorMessage(error)} />
            ) : stillLoading ? (
                <Loader inline active style={{ width: "100%" }} />
            ) : (
                <>
                    {canShowSummaryStats && (
                        <div style={{ display: "flex", marginBottom: "2rem" }}>
                            {summaryStatistics.map((stat) => {
                                const [label, value] = stat;
                                if (value === undefined) {
                                    return null;
                                }
                                return (
                                    <div style={{ marginRight: "5rem" }} key={`${title}_${label}`}>
                                        <h2>{value}</h2>
                                        {label}
                                    </div>
                                );
                            })}
                        </div>
                    )}

                    <ResponsiveContainer aspect={6}>
                        <LineChart
                            data={chartData}
                            // NOTE(review): "none" is not a valid CSS `overflow` value, so browsers ignore it.
                            // Left untouched to avoid changing the current layout -- confirm the intended value.
                            style={{ overflow: "none" }}
                        >
                            <XAxis
                                dataKey={"time"}
                                domain={[
                                    unixFromTime !== undefined ? unixFromTime : "auto",
                                    unixFromTime !== undefined ? dayjsExtended().unix() : "auto",
                                ]}
                                name="Time"
                                tickFormatter={
                                    xAxisTickFormatter ??
                                    ((unixTime) => dayjsExtended.unix(unixTime).format(selectedTimeRange?.format))
                                }
                                type="number"
                            />
                            <YAxis
                                tickFormatter={
                                    yAxisTickFormatter ?? ((val) => commonNumberFormatter(val as number) ?? "")
                                }
                            />
                            {chartData?.length === 0 && (
                                <text x="50%" y="30%" dy={12} style={{ fontSize: "1rem" }} textAnchor="middle">
                                    No activity during this time period
                                </text>
                            )}
                            <Line
                                type={lineType ?? "monotone"}
                                dataKey="value"
                                name={legendLabel}
                                stroke={lineColor ?? "#8884d8"}
                                strokeWidth={2}
                                dot={false}
                                isAnimationActive={false}
                            />
                            <Tooltip
                                labelFormatter={
                                    tooltipXFormatter ??
                                    // `Z` already prints the UTC offset (e.g. "+02:00"). Appending a literal
                                    // "Z" (the UTC designator) after it would produce a contradictory timestamp.
                                    ((unixTime) => dayjsExtended.unix(unixTime).format("YYYY-MM-DDTHH:mm:ssZ"))
                                }
                                formatter={
                                    tooltipYFormatter ?? ((value) => commonNumberFormatter(value as number, 4))
                                }
                            />
                            <Legend wrapperStyle={{ bottom: -5, paddingBottom: "8px" }} />
                            <CartesianGrid stroke="#eee" strokeDasharray="5 5" />
                        </LineChart>
                    </ResponsiveContainer>
                </>
            )}
        </Segment>
    );
};

// Chart segments:
/**
 * "Requests" chart: plots `queriesPerSecond` and summarises the total request count, the mean requests per
 * second, and the average output/input tokens per request.
 */
const RequestsChartSegment = ({
    selectedTimeRange,
    unixFromTime,
    rawData,
    rawDataLoading,
    rawDataError,
}: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    rawData?: typesafeDeploymentMetrics;
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // A scalar metric counts as missing when it is absent or is the literal string "NaN":
    const toNumber = (raw: stringifiedValue | undefined) =>
        raw !== undefined && raw !== "NaN" ? parseFloat(raw) : undefined;

    // The stringified value sits at index 1 of each scalar metric:
    const requestCount = toNumber(rawData?.queriesOverTime.at(1) as stringifiedValue | undefined);
    const outputTokens = toNumber(rawData?.generatedTokensPerRequest.at(1) as stringifiedValue | undefined);
    const inputTokens = toNumber(rawData?.inputTokensPerRequest.at(1) as stringifiedValue | undefined);

    // Note: Locale string usage here just inserts commas for thousands separators. We do not use the common formatter
    // because we want to display the full number instead of scientific notation:
    const totalRequests = requestCount !== undefined ? Math.floor(requestCount).toLocaleString() : undefined;
    // The common formatter passes `undefined` through, so missing values stay missing:
    const averageOutputTokensPerRequest = commonNumberFormatter(outputTokens);
    const averageInputTokensPerRequest = commonNumberFormatter(inputTokens);

    const queriesPerSecond = rawData?.queriesPerSecond;
    const [chartData, averageRequests] = useMemo(() => getDataWithAverage(queriesPerSecond), [queriesPerSecond]);

    return (
        <ChartSegmentLayout
            title={"Requests"}
            tooltip={
                "Even when no requests are sent, the deployment is configured to do periodic health checks. These health checks are infrequent and have little to no impact on performance, but you might see their requests in this graph."
            }
            legendLabel={"Requests per second"}
            lineColor={"#2185D0"}
            loading={Boolean(rawDataLoading)}
            error={rawDataError}
            chartData={chartData}
            summaryStatistics={[
                ["Total", totalRequests],
                ["Requests per second", averageRequests],
                ["Avg Output Tokens", averageOutputTokensPerRequest],
                ["Avg Input Tokens", averageInputTokensPerRequest],
            ]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
        />
    );
};

/** "Throughput" chart: plots the supplied series (generated tokens per second) together with its mean. */
const ThroughputChartSegment = ({
    selectedTimeRange,
    unixFromTime,
    rawData,
    rawDataLoading,
    rawDataError,
}: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    rawData?: typesafeDeploymentMetricsTuple[];
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // Parsed points and their formatted mean; recomputed only when the raw series changes:
    const [chartData, averageThroughput] = useMemo(() => getDataWithAverage(rawData), [rawData]);

    return (
        <ChartSegmentLayout
            title={"Throughput"}
            legendLabel={"Generated tokens per second"}
            loading={Boolean(rawDataLoading)}
            error={rawDataError}
            chartData={chartData}
            summaryStatistics={[["Throughput (output tok/s)", averageThroughput]]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
        />
    );
};

/** "LoRAX Inference Time" chart: plots the supplied series (in seconds) together with its mean. */
const LatencyChartSegment = ({
    selectedTimeRange,
    unixFromTime,
    rawData,
    rawDataLoading,
    rawDataError,
}: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    rawData?: typesafeDeploymentMetricsTuple[];
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // Parsed points and their formatted mean; recomputed only when the raw series changes:
    const [chartData, averageLatency] = useMemo(() => getDataWithAverage(rawData), [rawData]);

    return (
        <ChartSegmentLayout
            title={"LoRAX Inference Time"}
            tooltip={"This is not the full response time since it does not include network latency."}
            legendLabel={"Seconds"}
            lineColor={"#2185D0"}
            loading={Boolean(rawDataLoading)}
            error={rawDataError}
            chartData={chartData}
            summaryStatistics={[["Mean (secs)", averageLatency]]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
        />
    );
};

/** "Queue Duration" chart: plots the supplied series together with its mean (labelled in seconds). */
const QueueDurationSegment = ({
    selectedTimeRange,
    unixFromTime,
    rawData,
    rawDataLoading,
    rawDataError,
}: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    rawData?: typesafeDeploymentMetricsTuple[];
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // Parsed points and their formatted mean; recomputed only when the raw series changes:
    const [chartData, averageDuration] = useMemo(() => getDataWithAverage(rawData), [rawData]);

    return (
        <ChartSegmentLayout
            title={"Queue Duration"}
            legendLabel={"Queue duration"}
            loading={Boolean(rawDataLoading)}
            error={rawDataError}
            chartData={chartData}
            summaryStatistics={[["Mean (secs)", averageDuration]]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
        />
    );
};

/**
 * "Replicas" chart: plots the replica-count series as a step line and shows the current, minimum and maximum
 * replica counts passed in by the parent as summary statistics.
 */
const ReplicasSegment = ({
    selectedTimeRange,
    unixFromTime,
    currentReplicas,
    minReplicas,
    maxReplicas,
    rawData,
    rawDataLoading,
    rawDataError,
}: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    currentReplicas?: number;
    minReplicas?: number;
    maxReplicas?: number;
    rawData?: typesafeDeploymentMetricsTuple[];
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // Parse the raw samples and filter out invalid data points (missing data yields an empty array):
    const chartData = useMemo(
        () =>
            (rawData ?? [])
                .map(([time, rawValue]) => ({ time, value: parseFloat(rawValue) }))
                .filter((point) => !isNaN(point.value)),
        [rawData],
    );

    return (
        <ChartSegmentLayout
            title={"Replicas"}
            legendLabel={"Replicas"}
            lineType={"step"}
            lineColor={"#198F35"}
            loading={Boolean(rawDataLoading)}
            error={rawDataError}
            chartData={chartData}
            summaryStatistics={[
                ["Current", currentReplicas],
                ["min", minReplicas],
                ["max", maxReplicas],
            ]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
        />
    );
};

/**
 * "GPU Utilization" chart: plots the supplied series as percentages (raw values are multiplied by 100) and shows
 * the average utilization as a summary statistic.
 */
const UtilizationSegment = (props: {
    selectedTimeRange?: (typeof timeRangeDropdownOptions)[number];
    unixFromTime?: unixEpoch;
    rawData?: typesafeDeploymentMetricsTuple[];
    rawDataLoading?: boolean;
    rawDataError?: Error | null;
}) => {
    // Parent state:
    const { selectedTimeRange, unixFromTime, rawData, rawDataLoading, rawDataError } = props;

    // Derived state:
    const [chartData, averageUtilization] = useMemo(() => {
        // Filter out invalid data points, and if there is actually no data then return an empty array:
        // ! NOTE: We want percentage values, so we multiply the values by 100 as well:
        const chartData = (rawData ?? [])
            .map((value) => {
                return { time: value[0], value: parseFloat(value[1]) * 100 };
            })
            .filter((value) => !isNaN(value.value));
        const average =
            chartData.length > 0
                ? commonNumberFormatter(chartData.reduce((acc, curr) => acc + curr.value, 0) / chartData.length)
                : undefined;
        return [chartData, average] as [typeof chartData, typeof average];
    }, [rawData]);

    // Keep a missing average as `undefined` instead of interpolating it into the string "undefined%", which the
    // layout would otherwise treat as an already-loaded statistic:
    const averageUtilizationLabel = averageUtilization !== undefined ? `${averageUtilization}%` : undefined;

    return (
        <ChartSegmentLayout
            loading={!!rawDataLoading}
            chartData={chartData}
            title={"GPU Utilization"}
            legendLabel={"Utilization"}
            summaryStatistics={[["Average", averageUtilizationLabel]]}
            selectedTimeRange={selectedTimeRange}
            unixFromTime={unixFromTime}
            error={rawDataError}
            yAxisTickFormatter={(val) => `${commonNumberFormatter(val)}%`}
            tooltipYFormatter={(value) => `${commonNumberFormatter(value as number, 4)}%`}
            lineColor={"#8D9E13"}
        />
    );
};

/**
 * Health dashboard for a single deployment: a timeframe picker followed by the requests, throughput, latency,
 * queue-duration, replica and (unless the deployment runs on a fractional GPU) GPU-utilization charts.
 *
 * Metrics are fetched with `useDeploymentMetricsQuery` once the deployment UUID is known and are refetched every
 * minute.
 */
const DeploymentHealth = ({ deployment }: { deployment?: deployment }) => {
    // Local state:
    const [timeRange, setTimeRange] = useState<string>("3h"); // 3 hour range by default
    const selectedTimeRange = timeRangeDropdownOptions.find((option) => option.value === timeRange);

    // Derived state:
    const deploymentUUID = deployment?.uuid;
    const acceleratorInUse = deployment?.accelerator.id;
    const usingFractionalGPU =
        acceleratorInUse === acceleratorId.A100_80GB_025 || acceleratorInUse === acceleratorId.A100_80GB_050;

    // Start of the query window and the matching step. Both are recomputed only when the selected range changes,
    // so periodic refetches reuse the same start time:
    const [unixFromTime, stepSecs] = useMemo(() => {
        const windowStart = selectedTimeRange
            ? dayjsExtended()
                  .subtract(...selectedTimeRange.semantic)
                  .unix()
            : undefined;
        return [windowStart, getStepSecsForTimeframe(windowStart)];
    }, [selectedTimeRange]);

    // Query state:
    const queryOptions = {
        enabled: deploymentUUID !== undefined,
        refetchInterval: 1000 * 60 * 1, // 1 minute
        refetchOnWindowFocus: false,
        retry: false,
    };
    // NOTE(review): the third argument is left undefined -- presumably the end of the window; confirm against
    // the query hook's signature.
    const { data, isLoading, error } = useDeploymentMetricsQuery(
        deploymentUUID ?? "",
        unixFromTime,
        undefined,
        stepSecs,
        queryOptions,
    );

    // Props that every chart segment receives unchanged:
    const sharedSegmentProps = {
        selectedTimeRange,
        unixFromTime,
        rawDataLoading: isLoading,
        rawDataError: error,
    };

    return (
        <>
            <div style={{ marginBottom: "1rem" }}>
                <div style={{ marginBottom: "0.5rem" }}>
                    <b>Timeframe</b>
                </div>
                <Dropdown
                    labeled
                    selection
                    options={timeRangeDropdownOptions}
                    value={timeRange}
                    onChange={(_, { value }) => setTimeRange(value as string)}
                />
                <span style={{ marginLeft: "1rem", color: SEMANTIC_GREY_DISABLED }}>Refreshes every minute</span>
            </div>

            <RequestsChartSegment {...sharedSegmentProps} rawData={data} />
            <ThroughputChartSegment {...sharedSegmentProps} rawData={data?.generatedTokensPerSecond} />
            <LatencyChartSegment {...sharedSegmentProps} rawData={data?.inferenceDuration} />
            <QueueDurationSegment {...sharedSegmentProps} rawData={data?.queueDuration} />
            <ReplicasSegment
                {...sharedSegmentProps}
                currentReplicas={deployment?.currentReplicas}
                minReplicas={deployment?.config?.minReplicas}
                maxReplicas={deployment?.config?.maxReplicas}
                rawData={data?.numReplicas}
            />
            {!usingFractionalGPU && <UtilizationSegment {...sharedSegmentProps} rawData={data?.gpuUtilization} />}
        </>
    );
};

export default DeploymentHealth;
