import { useEffect, useRef, useState } from "react";

import _ from "lodash";

import { grpoJobMetrics } from "@/autogen/openapi";

// Common types
/**
 * One row of chart data: the training `step` plus a numeric value for every series present at that step.
 * Series keys are metric names, optionally carrying the "@validation_" / "@ema_" prefixes defined in this file.
 */
export type ChartDataPoint = { [key: string]: number; step: number };
/** A list of chart rows. */
export type ChartData = ChartDataPoint[];
/** The series names (`metrics`) that are grouped together under a single reward (`rewardName`). */
export type RewardChartGroup = {
    rewardName: string;
    metrics: string[];
};

// Constants
/** Smoothing factor used by `calculateEMA` when the caller does not pass one. */
export const DEFAULT_EMA_FACTOR = 0.1;

/**
 * Advances an exponential moving average by one sample.
 *
 * @param currentValue The newest sample. When it is `undefined` the result is `undefined`.
 * @param previousEMA The EMA computed for the preceding sample. When it is `undefined` the newest sample
 *     is returned unchanged, which seeds the average.
 * @param emaFactor Weight given to the newest sample; the previous EMA is weighted by `1 - emaFactor`.
 * @returns The updated EMA, or `undefined` when there is no current sample.
 */
export const calculateEMA = (currentValue?: number, previousEMA?: number, emaFactor: number = DEFAULT_EMA_FACTOR) => {
    if (currentValue === undefined) {
        return undefined;
    }

    const nextEMA =
        previousEMA === undefined ? currentValue : emaFactor * currentValue + (1 - emaFactor) * previousEMA;
    return nextEMA;
};

/** Prefix that marks a chart series as validation data. */
export const VALIDATION_PREFIX = "@validation_";
/** Prefix that marks a chart series as an EMA series. */
export const EMA_PREFIX = "@ema_";

/**
 * Builds the chart series name for a metric.
 *
 * The validation prefix (if requested) sits directly in front of the metric name and the EMA prefix (if
 * requested) sits in front of that, e.g. "fn" -> "@ema_@validation_fn" when both flags are set.
 *
 * @param fullMetricName The metric name to decorate.
 * @param isValidation Whether to add the validation prefix.
 * @param isEMA Whether to add the EMA prefix.
 * @returns The decorated series name.
 */
export const transformMetricNameToChartSeriesName = (
    fullMetricName: string,
    isValidation?: boolean,
    isEMA?: boolean,
) => {
    const validationPart = isValidation ? VALIDATION_PREFIX : "";
    const emaPart = isEMA ? EMA_PREFIX : "";
    return `${emaPart}${validationPart}${fullMetricName}`;
};

/**
 * Strips a trailing "_v<digits>" version suffix from a metric name.
 *
 * @param metric The metric name, e.g. "fn_v2" or "@validation_fn_v2".
 * @returns The name without its version suffix (e.g. "fn"), or `null` when the name has no such suffix
 *     or nothing precedes it.
 */
export const getRootNameOfVersionedMetric = (metric: string) => {
    const versionedNamePattern = /(.+)_v\d+$/;
    const found = versionedNamePattern.exec(metric);
    if (found === null) {
        return null;
    }
    return found[1];
};

/**
 * Flattens one job-metrics payload into a single chart row.
 *
 * Training rewards are copied under their own names; validation rewards are copied under their
 * validation-prefixed series names. The row's `step` is taken from `payload.data.steps`.
 *
 * @param payload The metrics payload to convert.
 * @returns A chart row holding the step and every reward value found in the payload.
 */
export const convertJobMetricsToChartDataPoint = (payload: grpoJobMetrics): ChartDataPoint => {
    // `steps` is asserted non-null here; this function does not validate it.
    const row: ChartDataPoint = { step: payload.data.steps! };

    const trainRewards = payload.data.grpo_train_rewards;
    if (trainRewards) {
        for (const [metricName, reward] of Object.entries(trainRewards)) {
            row[metricName] = reward;
        }
    }

    const validationRewards = payload.data.grpo_validation_rewards;
    if (validationRewards) {
        for (const [metricName, reward] of Object.entries(validationRewards)) {
            const seriesName = transformMetricNameToChartSeriesName(metricName, true);
            row[seriesName] = reward;
        }
    }

    return row;
};

// Get a map of reward names to all metrics that exist for that reward. Each metric is listed once, in
// order of first appearance, no matter how many data points contain it.
// Example:
// For a preprocessed dataset with a reward named "fn" that has two versions and emits validation metrics,
// the input will have data points with the following metrics:
// {
//     "fn_v1": 0.1,
//     "fn_v2": 0.2,
//     "@validation_fn_v1": 0.1,
//     "@validation_fn_v2": 0.2,
//     ...
//     "step": 3,
// }
// The output will be a map that includes the following keys:
// {
//     "fn": ["fn_v1", "fn_v2"],
//     "@validation_fn": ["@validation_fn_v1", "@validation_fn_v2"],
// }
// Metrics without a version suffix (e.g. "total_reward", "total_reward_std") are keyed by their own name.
const getRewardToVersionsMap = (preprocessedData: ChartData | undefined): Map<string, string[]> => {
    // Accumulate into sets so that a metric present in many data points is recorded only once, and so that
    // an unversioned metric sharing a name with a versioned root ("fn" next to "fn_v1") is merged into the
    // same entry instead of replacing it.
    const metricsByReward = new Map<string, Set<string>>();

    for (const dataPoint of preprocessedData ?? []) {
        for (const fullMetricName of Object.keys(dataPoint)) {
            if (fullMetricName === "step") continue;

            // The base name will include prefixes like "@validation_". If this function were self-contained, it
            // would be better to extract the base name without these prefixes and return a generalized struct,
            // but because it is only used in this one place it is more efficient to let the map break keys
            // based on the prefixes.
            const rewardName = getRootNameOfVersionedMetric(fullMetricName) ?? fullMetricName;

            let metrics = metricsByReward.get(rewardName);
            if (metrics === undefined) {
                metrics = new Set();
                metricsByReward.set(rewardName, metrics);
            }
            metrics.add(fullMetricName);
        }
    }

    // Convert the sets to arrays; Map and Set both iterate in insertion order, so ordering is stable.
    const rewardToMetricsMap = new Map<string, string[]>();
    metricsByReward.forEach((metrics, rewardName) => {
        rewardToMetricsMap.set(rewardName, Array.from(metrics));
    });
    return rewardToMetricsMap;
};

/**
 * Returns a deep copy of `preprocessedData` with one EMA column added per reward.
 *
 * For both training and validation data, the EMA is computed as a single series across all versions of a
 * metric: there is one EMA line per reward, not one per version. To achieve this, the values of every
 * version of a reward are collated by step, and the EMA is then computed over that combined series in
 * ascending step order. Steps at which a reward has no value (e.g. validation data that is only produced
 * every N steps) are simply skipped; they neither receive an EMA value nor reset the running average.
 *
 * @param preprocessedData Chart rows without EMA columns. The input is not modified.
 * @param emaFactor Smoothing factor forwarded to `calculateEMA`; its default applies when omitted.
 * @returns The copied rows, each extended with an EMA-prefixed series for every reward that has a value at
 *     that row's step. An empty array is returned for `undefined` or empty input.
 */
export const generateEMASeries = (preprocessedData: ChartData | undefined, emaFactor?: number): ChartData => {
    if (preprocessedData === undefined || preprocessedData.length === 0) return [];

    // Clone the data so that we have a copy that can be modified in-place:
    const clonedData = _.cloneDeep(preprocessedData);
    const rewardToVersionsMap = getRewardToVersionsMap(clonedData);

    rewardToVersionsMap.forEach((metrics, rewardName) => {
        // Collate the values of all versions of this reward, keyed by step. A map is used rather than an
        // array because there may be gaps (e.g. validation data). The metric list is de-duplicated and
        // copied before sorting so the shared array is not mutated; if two versions report a value for the
        // same step, the version that sorts last wins.
        const valueByStep = new Map<number, number>();
        const orderedMetrics = Array.from(new Set(metrics)).sort();
        for (const metric of orderedMetrics) {
            for (const item of preprocessedData) {
                const value = item[metric];
                if (value !== undefined) {
                    valueByStep.set(item.step, value);
                }
            }
        }

        // Metric names sort lexicographically ("fn_v10" before "fn_v2"), so insertion order is not
        // guaranteed to be chronological. Order the collated series by step before smoothing.
        const chronologicalSeries = Array.from(valueByStep.entries()).sort(([stepA], [stepB]) => stepA - stepB);

        // Walk the series once, feeding each EMA into the next calculation. The first sample has no
        // previous EMA, so `calculateEMA` returns the sample itself and thereby seeds the average.
        const emaByStep = new Map<number, number>();
        let previousEMA: number | undefined = undefined;
        for (const [step, metricValue] of chronologicalSeries) {
            const currentEMA = calculateEMA(metricValue, previousEMA, emaFactor);
            if (currentEMA === undefined) continue;
            emaByStep.set(step, currentEMA);
            previousEMA = currentEMA;
        }

        // Insert the EMA values as a new column. The reward name here is the base name of the metric
        // (e.g. "fn" for "fn_v1" and "fn_v2"). Rows whose step has no EMA for this reward are left untouched.
        const emaSeriesName = transformMetricNameToChartSeriesName(rewardName, false, true);
        for (const point of clonedData) {
            const ema = emaByStep.get(point.step);
            if (ema !== undefined) {
                point[emaSeriesName] = ema;
            }
        }
    });

    return clonedData;
};

/**
 * Builds one chart group per root reward found in `processedData`.
 *
 * A root reward is any series key that carries neither the validation nor the EMA prefix. Each group lists
 * the reward's own metrics followed by, when present, its validation metrics, its training EMA series and
 * its validation EMA series, without duplicates.
 *
 * @param rawData The raw payloads; only the last entry is inspected, to check that it has training rewards.
 * @param processedData Chart rows, including any EMA columns.
 * @returns The chart groups, or an empty array when the last payload has no training rewards or there are
 *     no chart rows.
 */
export const createChartDataGroups = (rawData: grpoJobMetrics[], processedData: ChartData): RewardChartGroup[] => {
    const latestTrainRewards = rawData.at(-1)?.data.grpo_train_rewards;
    if (latestTrainRewards === undefined || processedData.length === 0) return [];

    // This map has separate keys for the validation and EMA series of a reward.
    const versionsBySeries = getRewardToVersionsMap(processedData);

    // Prefixes of the derived series that belong to a root reward, in the order they are appended.
    const derivedSeriesPrefixes = [VALIDATION_PREFIX, EMA_PREFIX, `${EMA_PREFIX}${VALIDATION_PREFIX}`];

    // Anything starting with the combined EMA + validation prefix also starts with the EMA prefix, so two
    // checks are enough to exclude every derived series.
    const isRootReward = (seriesName: string) =>
        !(seriesName.startsWith(VALIDATION_PREFIX) || seriesName.startsWith(EMA_PREFIX));

    return Array.from(versionsBySeries.keys())
        .filter(isRootReward)
        .map((rewardName) => {
            const groupMetrics = new Set(versionsBySeries.get(rewardName));
            derivedSeriesPrefixes.forEach((prefix) => {
                const derivedMetrics = versionsBySeries.get(`${prefix}${rewardName}`);
                if (derivedMetrics === undefined) return;
                derivedMetrics.forEach((metric) => groupMetrics.add(metric));
            });
            return { rewardName, metrics: Array.from(groupMetrics) };
        });
};

/**
 * Converts a full list of raw metric payloads into chart rows (with EMA columns) and chart groups.
 *
 * @param rawData The raw payloads to chart.
 * @param emaFactor Smoothing factor forwarded to `generateEMASeries`.
 * @returns `chartData` and `chartDataGroups`; both are `undefined` when `rawData` is `undefined` or empty.
 */
export const getRewardChartData = (rawData?: grpoJobMetrics[], emaFactor?: number) => {
    if (rawData === undefined || rawData.length === 0) {
        return { chartData: undefined, chartDataGroups: undefined };
    }

    // Step 1: one chart row per payload. These rows do not carry EMA values yet.
    const rowsWithoutEMA = rawData.map((payload) => convertJobMetricsToChartDataPoint(payload));

    // Step 2: add the EMA columns. Afterwards each row is a slice across all series that exist at its step.
    const processedData = generateEMASeries(rowsWithoutEMA, emaFactor);

    // Step 3: one group per reward, holding all versions of the reward, its validation versions if
    // available, and the consolidated EMA series. Groups are only built when the last payload is truthy.
    let chartDataGroups: RewardChartGroup[] = [];
    if (rawData.at(-1)) {
        chartDataGroups = createChartDataGroups(rawData, processedData);
    }

    return {
        chartData: processedData,
        chartDataGroups,
    };
};

/**
 * Appends one new row to existing chart data, extending each reward's EMA by one sample.
 *
 * Neither `existingChartData` nor `lastEMAValues` is mutated; updated copies are returned.
 *
 * @param existingChartData The rows accumulated so far.
 * @param newPreprocessedPoint The new row, without EMA columns.
 * @param lastEMAValues The most recent EMA per reward root name (version suffix stripped).
 * @param emaFactor Smoothing factor forwarded to `calculateEMA`.
 * @returns `updatedChartData` (the existing rows plus the new row with its EMA columns) and
 *     `updatedLastEMAValues` (the EMA bookkeeping to pass to the next call).
 */
export const updateEMASeries = (
    existingChartData: ChartData,
    newPreprocessedPoint: ChartDataPoint,
    lastEMAValues: Record<string, number | undefined>,
    emaFactor?: number,
) => {
    const rowWithEMA: ChartDataPoint = { step: newPreprocessedPoint.step };
    const nextEMAValues = { ...lastEMAValues };

    for (const [rawMetric, value] of Object.entries(newPreprocessedPoint)) {
        if (rawMetric === "step") continue;

        // Carry the raw metric over unchanged.
        rowWithEMA[rawMetric] = value;

        // All versions of a reward share one EMA, tracked under the reward's root name.
        const rootName = getRootNameOfVersionedMetric(rawMetric) ?? rawMetric;
        const currentEMA = calculateEMA(value, nextEMAValues[rootName], emaFactor);

        if (currentEMA !== undefined) {
            const emaKey = transformMetricNameToChartSeriesName(rootName, false, true);
            rowWithEMA[emaKey] = currentEMA;
        }

        // Remember the result (even when undefined) for the next calculation.
        nextEMAValues[rootName] = currentEMA;
    }

    return {
        updatedChartData: [...existingChartData, rowWithEMA],
        updatedLastEMAValues: nextEMAValues,
    };
};

/**
 * React hook that turns a growing list of metric payloads into chart rows (with EMA columns) and chart
 * groups.
 *
 * Two effects cooperate:
 * 1. An incremental effect that converts payloads not yet processed, appends them to the chart data and
 *    advances the running EMA values.
 * 2. A recalculation effect that rebuilds every row from the stored raw points using the current
 *    `emaFactor`.
 *
 * @param websocketData The payloads received so far. Nothing is processed while it is `undefined` or empty.
 * @param emaFactor Smoothing factor forwarded to `updateEMASeries`.
 * @returns The current `chartData` and `chartDataGroups` state.
 */
export const useStreamingMetrics = (websocketData: grpoJobMetrics[] | undefined, emaFactor?: number) => {
    // Local state to maintain all accumulated data
    const [chartData, setChartData] = useState<ChartData>([]);
    const [chartDataGroups, setChartDataGroups] = useState<RewardChartGroup[]>([]);

    // Track the last processed data point to detect new ones (-1 means nothing has been processed yet)
    const lastProcessedIndexRef = useRef<number>(-1);

    // Ref for EMA calculations: the most recent EMA per reward root name
    const lastEMAValuesRef = useRef<Record<string, number | undefined>>({});

    // Store raw data points (without EMA) to allow recalculation when emaFactor changes
    const rawDataPointsRef = useRef<ChartData>([]);

    // Process all data, including both initial and new streaming points
    useEffect(() => {
        if (!websocketData || websocketData.length === 0) return;

        // Get the latest data point for chart groups
        const latestEvent = websocketData.at(-1);
        if (!latestEvent) return;

        // Find new data points that haven't been processed yet
        const currentDataLength = websocketData.length;
        const lastProcessedIndex = lastProcessedIndexRef.current;

        // If we have new data points
        // NOTE(review): there is no branch for `websocketData` becoming shorter than what was already
        // processed; in that case the refs and state keep their old contents. Confirm whether callers can
        // ever reset the array.
        if (currentDataLength > lastProcessedIndex + 1) {
            // Get only the new data points
            const newDataPoints = websocketData.slice(lastProcessedIndex + 1);

            // Store the raw data points (without EMAs) for potential recalculation
            for (const dataPoint of newDataPoints) {
                const convertedPoint = convertJobMetricsToChartDataPoint(dataPoint);
                rawDataPointsRef.current.push(convertedPoint);
            }

            // Process each new data point
            let updatedChartData = [...chartData];
            let updatedLastEMAValues = { ...lastEMAValuesRef.current };

            // NOTE(review): each new payload is converted a second time here; the conversion above could be
            // reused.
            for (const dataPoint of newDataPoints) {
                const convertedPoint = convertJobMetricsToChartDataPoint(dataPoint);

                const result = updateEMASeries(updatedChartData, convertedPoint, updatedLastEMAValues, emaFactor);

                updatedChartData = result.updatedChartData;
                updatedLastEMAValues = result.updatedLastEMAValues;
            }

            // Update the state with new data
            setChartData(updatedChartData);
            lastEMAValuesRef.current = updatedLastEMAValues;

            // Create or update chart data groups
            if (updatedChartData.length > 0) {
                const groups = createChartDataGroups(websocketData, updatedChartData);
                setChartDataGroups(groups);
            }

            // Update the last processed index
            lastProcessedIndexRef.current = currentDataLength - 1;
        }
        // `chartData` is a dependency, so this effect also re-runs after its own `setChartData`; the index
        // check above makes that re-run a no-op.
    }, [websocketData, chartData, emaFactor]);

    // Recalculate all EMAs when emaFactor changes
    // NOTE(review): `websocketData` is also in the dependency array below, so this full recalculation runs
    // on every `websocketData` change as well, not only when `emaFactor` changes. Confirm whether that is
    // intended.
    useEffect(() => {
        // Skip if no data or first render
        if (rawDataPointsRef.current.length === 0) return;

        // Get the latest point for chart groups
        const latestEvent = websocketData?.at(-1);
        if (!latestEvent) return;

        // Recalculate all EMAs from scratch with the new factor
        let recalculatedChartData: ChartData = [];
        let newEMAValues: Record<string, number | undefined> = {};

        for (const rawPoint of rawDataPointsRef.current) {
            const result = updateEMASeries(recalculatedChartData, rawPoint, newEMAValues, emaFactor);

            recalculatedChartData = result.updatedChartData;
            newEMAValues = result.updatedLastEMAValues;
        }

        // Update state with recalculated data
        setChartData(recalculatedChartData);
        lastEMAValuesRef.current = newEMAValues;

        // Update chart groups
        if (recalculatedChartData.length > 0) {
            // `websocketData` is known to be defined here because `latestEvent` above was truthy.
            const groups = createChartDataGroups(websocketData!, recalculatedChartData);
            setChartDataGroups(groups);
        }
    }, [emaFactor, websocketData]);

    return {
        chartData,
        chartDataGroups,
    };
};
