import React, { ReactNode, useEffect, useMemo, useState } from "react";

import { useQueryClient } from "@tanstack/react-query";
import { Slider } from "antd";
import ShowMoreText from "react-show-more-text";
import {
    Accordion,
    Button,
    Checkbox,
    Divider,
    Grid,
    Header,
    Icon,
    Input,
    Pagination,
    Popup,
    Segment,
    Select,
    SemanticICONS,
    Table,
} from "semantic-ui-react";

import {
    epochPromptCompletions,
    finetuningJob,
    grpoJobMetrics,
    grpoPrompt,
    grpoRewards,
    promptCompletion,
    promptRewards,
    sftMetricsPayload,
} from "@/autogen/openapi";

import { getDocsHome } from "../../../../utils/api";
import { PREDIBASE_LIGHT_GRAY, SEMANTIC_GREY_ACTIVE } from "../../../../utils/colors";
import {
    GET_GRPO_ALL_PROMPT_COMPLETIONS_QUERY_KEY,
    GET_GRPO_PROMPT_REWARDS_QUERY_KEY,
    useGRPOPromptCompletions,
    useGRPOPromptRewards,
    useGRPOPrompts,
} from "../../../query";

import "./Completions.css";

/** A prompt row as shown in the selector: prompt fields, optionally merged with its reward summary fields. */
type CombinedPromptData = Partial<promptRewards> & grpoPrompt;

/** Identifies a single completion by the epoch that produced it and its index within that epoch. */
type SelectedComparison = {
    /** Epoch the completion belongs to. */
    epoch: number;
    /** Index of the completion within its epoch (displayed to users as index + 1). */
    index: number;
};

/** Number of prompts requested and shown per page of the prompt selector. */
const promptsPerPage = 10;

/**
 * Lists the reward function names present in `rewards`, sorted alphabetically, excluding
 * the aggregate "total" entry (matched case-insensitively). Returns [] when `rewards` is absent.
 */
const rewardFunctionNamesSorted = (rewards?: grpoRewards) => {
    const names: string[] = [];
    for (const name of Object.keys(rewards ?? {})) {
        if (name.toLocaleLowerCase() === "total") {
            continue;
        }
        names.push(name);
    }
    names.sort();
    return names;
};

/** Fixed height for the two-row table headers used by the prompt and completion tables. */
const TableHeaderStyling = {
    height: "4.42857rem",
} as const;

/**
 * Style for the grouping header cells (e.g. "Average Latest Reward Score", "Reward Score")
 * that sit above a run of per-column sub-headers.
 *
 * `as const` keeps `textAlign` as the literal "center" instead of widening it to `string`,
 * so the object remains assignable to strictly typed CSS style props.
 */
const RewardsSummaryCellStyling = {
    fontSize: "0.85714rem",
    lineHeight: "1.28571rem",
    color: SEMANTIC_GREY_ACTIVE,
    textAlign: "center",
    borderBottom: "none",
    paddingBottom: 0,
} as const;

/** Large inline `h2` heading used to title a section of the page. */
const SectionHeader = ({ children }: { children: string }) => {
    return (
        <Header
            as="h2"
            size="large"
            style={{ display: "inline-block", fontSize: "1.286rem", margin: "0 1.14rem 1.14rem 0" }}
        >
            {children}
        </Header>
    );
};

/** Small inline `h3` heading used as a panel title (e.g. inside an accordion title). */
const SegmentHeader = ({ children }: { children: string }) => {
    return (
        <Header as="h3" size="small" style={{ display: "inline-block", marginTop: 0 }}>
            {children}
        </Header>
    );
};

/**
 * Light-grey, bordered Segment used as the container for each panel in this view
 * (prompt selector, comparison view, and all-completions table).
 *
 * @param children - Panel content; typed as `ReactNode` rather than `any` so misuse is caught.
 */
const StyledSegment = (props: { children: ReactNode }) => (
    <Segment
        style={{
            marginTop: 0,
            background: PREDIBASE_LIGHT_GRAY,
            border: "1px solid #dededf",
        }}
    >
        {props.children}
    </Segment>
);

/**
 * Renders a reward function name for a table header, abbreviating long names.
 *
 * Names shorter than 11 characters are returned as-is. Longer names are cut at the first
 * "-" or "_" (or after 10 characters when there is no usable delimiter). The full name is
 * available in a hover popup.
 */
const FunctionNameTableCell = (props: { functionName: string }) => {
    const { functionName } = props;
    if (functionName.length < 11) {
        return functionName;
    }

    // A delimiter at index 0 (e.g. "_private_reward_fn") would produce an empty, invisible
    // popup trigger, so that case also falls back to a fixed-length prefix.
    const firstDelimiter = functionName.search(/[-_]/);
    const cutoff = firstDelimiter > 0 ? firstDelimiter : 10;
    const shortName = functionName.substring(0, cutoff);

    return <Popup content={functionName} trigger={<span>{shortName}</span>} />;
};

/**
 * Renders prompt text for a table cell. Prompts longer than 100 characters are truncated
 * to their first 100 characters, with the full prompt shown in a hover popup.
 */
const PromptTableCell = ({ prompt }: { prompt: string }) => {
    const maxVisibleLength = 100;
    if (prompt.length <= maxVisibleLength) {
        return prompt;
    }

    const visiblePart = prompt.substring(0, maxVisibleLength);
    return <Popup content={prompt} trigger={<span>{visiblePart}</span>} />;
};

/**
 * Collapsible, paginated table of the job's prompts. Clicking a row's index cell selects
 * that prompt. Each row also shows the most recently trained epoch, the latest average
 * reward per function plus the total, and any metadata columns.
 *
 * @param combinedPromptData - Prompts on the current page merged with their reward summaries.
 * @param commonPromptPrefix - Prefix shared by the prompts; stripped from the displayed text
 *   unless the "Show detected common prefix" toggle is on.
 * @param selectedPromptId - ID of the currently highlighted prompt.
 * @param setSelectedPromptId - Called with a prompt's ID when its index cell is clicked.
 * @param pageNumber - 1-based page shown as active in the pagination control.
 * @param setOffset - Receives the 0-based offset of the first prompt on the chosen page.
 * @param totalNumberOfPages - Pagination is rendered only when this exceeds 1.
 */
const PromptSelector = (props: {
    combinedPromptData: CombinedPromptData[];
    commonPromptPrefix: string;
    selectedPromptId: string;
    setSelectedPromptId: React.Dispatch<React.SetStateAction<string>>;
    pageNumber: number;
    setOffset: React.Dispatch<React.SetStateAction<number>>;
    totalNumberOfPages: number;
}) => {
    const {
        combinedPromptData,
        commonPromptPrefix,
        selectedPromptId,
        setSelectedPromptId,
        pageNumber,
        setOffset,
        totalNumberOfPages,
    } = props;
    const functionNames = rewardFunctionNamesSorted(combinedPromptData[0]?.last_epoch_rewards_avg);
    // Metadata headers come from the first row. Every row renders cells for exactly these
    // columns, so the body stays aligned with the header even if rows differ in their keys.
    const metadataColumnNames = Object.keys(combinedPromptData[0]?.metadata_columns ?? {});

    const [isOpen, setIsOpen] = useState(true);
    const [showCommonPrefix, setShowCommonPrefix] = useState(false);

    return (
        <StyledSegment>
            <Accordion>
                <Accordion.Title
                    active={isOpen}
                    onClick={() => {
                        setIsOpen((open) => !open);
                    }}
                >
                    <Icon name="dropdown" />
                    <SegmentHeader>Select a Prompt</SegmentHeader>
                </Accordion.Title>
                <Accordion.Content active={isOpen}>
                    <Table className="completions-table">
                        <Table.Header style={TableHeaderStyling}>
                            <Table.Row>
                                <Table.HeaderCell rowSpan="2">Index</Table.HeaderCell>
                                <Table.HeaderCell rowSpan="2" className="border-right">
                                    Prompt
                                </Table.HeaderCell>
                                {/* Spans the Epoch column, one column per reward function, and Total. */}
                                <Table.HeaderCell
                                    colSpan={functionNames.length + 2}
                                    style={RewardsSummaryCellStyling}
                                    className="border-right"
                                >
                                    Average Latest Reward Score
                                </Table.HeaderCell>
                                {/* Spans only the metadata columns; omitted when there are none. */}
                                {metadataColumnNames.length > 0 && (
                                    <Table.HeaderCell
                                        colSpan={metadataColumnNames.length}
                                        style={RewardsSummaryCellStyling}
                                    >
                                        Additional Metadata
                                    </Table.HeaderCell>
                                )}
                            </Table.Row>
                            <Table.Row>
                                <Table.HeaderCell className="secondary-header-cell">Epoch</Table.HeaderCell>
                                {functionNames.map((funcName) => (
                                    <Table.HeaderCell key={`${funcName}-header`} className="secondary-header-cell">
                                        <FunctionNameTableCell functionName={funcName} />
                                    </Table.HeaderCell>
                                ))}
                                <Table.HeaderCell className="border-right secondary-header-cell">
                                    Total
                                </Table.HeaderCell>
                                {metadataColumnNames.map((columnName) => (
                                    <Table.HeaderCell
                                        key={`${columnName}-meta-header`}
                                        className="secondary-header-cell"
                                    >
                                        {columnName}
                                    </Table.HeaderCell>
                                ))}
                            </Table.Row>
                        </Table.Header>
                        <Table.Body>
                            {combinedPromptData.map((promptData) => {
                                const isSelectedPrompt = promptData.id === selectedPromptId;
                                return (
                                    <Table.Row
                                        key={promptData.id}
                                        style={{ backgroundColor: isSelectedPrompt ? "#DFF0FF" : undefined }}
                                    >
                                        <Table.Cell
                                            onClick={() => {
                                                setSelectedPromptId(promptData.id);
                                            }}
                                            style={{ cursor: "pointer" }}
                                        >
                                            {/* The check icon stays in the layout (transparent) when unselected, so the ID does not shift. */}
                                            <Icon
                                                name="check circle"
                                                color="blue"
                                                style={{ marginRight: "1rem", opacity: isSelectedPrompt ? 1 : 0 }}
                                                aria-hidden={isSelectedPrompt ? "false" : "true"}
                                            />
                                            {promptData.id}
                                        </Table.Cell>
                                        <Table.Cell>
                                            <PromptTableCell
                                                prompt={
                                                    showCommonPrefix
                                                        ? promptData.prompt_text
                                                        : promptData.prompt_text.substring(commonPromptPrefix.length)
                                                }
                                            />
                                        </Table.Cell>
                                        <Table.Cell className="reward-cell">
                                            {promptData.most_recently_trained_epoch}
                                        </Table.Cell>
                                        {functionNames.map((funcName) => (
                                            <Table.Cell
                                                className="reward-cell"
                                                key={`${funcName}-data-${promptData.id}`}
                                            >
                                                {promptData.last_epoch_rewards_avg?.[funcName]}
                                            </Table.Cell>
                                        ))}
                                        <Table.Cell className="reward-cell border-right">
                                            {promptData.last_epoch_rewards_avg?.total}
                                        </Table.Cell>
                                        {metadataColumnNames.map((columnName) => (
                                            <Table.Cell key={`${columnName}-meta-row-${promptData.id}`}>
                                                {promptData.metadata_columns?.[columnName]}
                                            </Table.Cell>
                                        ))}
                                    </Table.Row>
                                );
                            })}
                        </Table.Body>
                    </Table>
                    <div style={{ display: "flex", justifyContent: "center" }}>
                        <div style={{ flex: 1 }}>
                            <Checkbox
                                toggle={true}
                                checked={showCommonPrefix}
                                label="Show detected common prefix"
                                onClick={() => {
                                    setShowCommonPrefix((showPrefix) => !showPrefix);
                                }}
                                disabled={commonPromptPrefix.length === 0}
                            />
                        </div>
                        {totalNumberOfPages > 1 && (
                            <Pagination
                                activePage={pageNumber}
                                onPageChange={(e, data) => {
                                    e.preventDefault();

                                    // parseInt yields NaN (never null/undefined) on bad input, so a
                                    // `??` fallback would never apply; test for NaN explicitly.
                                    const parsedPage = parseInt(String(data.activePage), 10);
                                    const page = Number.isNaN(parsedPage) ? 1 : parsedPage;
                                    setOffset((page - 1) * promptsPerPage);
                                }}
                                totalPages={totalNumberOfPages}
                            />
                        )}
                        <div style={{ flex: 1 }}></div>
                    </div>
                </Accordion.Content>
            </Accordion>
        </StyledSegment>
    );
};

/**
 * Shows one completion: a reward bar (total first, then each reward function in
 * alphabetical order) above the completion text. Renders an empty shell when
 * `completion` is undefined, e.g. when the selected epoch/index does not exist.
 */
const CompletionDetails = (props: { completion?: promptCompletion }) => {
    const { completion } = props;
    const rewards = completion?.rewards;
    const functionNames = rewardFunctionNamesSorted(rewards);

    return (
        <>
            <div
                style={{
                    display: "flex",
                    padding: "0.64rem 1.14rem",
                    backgroundColor: "#DFF0FF",
                }}
            >
                <span style={{ marginRight: "10px" }}>total: {rewards?.total}</span>
                {functionNames.map((funcName, index) => (
                    // Every entry except the last gets a fixed gap; the last one uses "auto".
                    <span
                        key={funcName}
                        style={{ marginRight: index < functionNames.length - 1 ? "10px" : "auto" }}
                    >
                        {funcName}: {rewards?.[funcName]}
                    </span>
                ))}
            </div>
            <div
                style={{
                    padding: "0.64rem 1.14rem",
                    minHeight: "10rem",
                }}
            >
                {completion?.text}
            </div>
        </>
    );
};

/**
 * Borderless icon-only button for stepping between completions. The icon is rendered
 * greyed out when the button is disabled.
 */
const CompletionNavButton = ({
    ariaLabel,
    disabled,
    iconName,
    onClick,
}: {
    ariaLabel: string;
    disabled: boolean;
    iconName: SemanticICONS;
    onClick: () => void;
}) => {
    // Invoke the callback without forwarding the click event.
    const handleClick = () => {
        onClick();
    };

    return (
        <button
            aria-label={ariaLabel}
            className="button-reset"
            disabled={disabled}
            onClick={handleClick}
            style={{
                fontWeight: "bold",
                fontSize: "1.25rem",
            }}
        >
            <Icon color="blue" disabled={disabled} name={iconName} style={{ margin: "0" }} />
        </button>
    );
};

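// TODO: Re-enable per-epoch tick marks on the epoch slider (see the commented-out `marks` prop on the Slider in EditableCompletionHeader).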
// const generateMarks = (totalEpochs: number) => {
//     let marks = {};

//     for (let i = 1; i <= totalEpochs; i++) {
//         marks[i] = <Icon name='square outline' />
//     }

//     return marks;
// }

/**
 * Editable header for the left comparison pane. The user picks an epoch (number input or
 * slider, bounded to 1..totalCompletionEpochs) and steps through that epoch's completions
 * with previous/next buttons. Changing the epoch resets the completion index to 0.
 *
 * @param comparison - Current left-pane selection.
 * @param setComparison - Setter for the left-pane selection.
 * @param totalCompletionEpochs - Highest selectable epoch.
 * @param totalSelectedEpochCompletions - Number of completions in the selected epoch; bounds the "next" button.
 */
const EditableCompletionHeader = (props: {
    comparison: SelectedComparison;
    setComparison: React.Dispatch<React.SetStateAction<SelectedComparison>>;
    totalCompletionEpochs: number;
    totalSelectedEpochCompletions: number;
}) => {
    const { comparison, setComparison, totalCompletionEpochs, totalSelectedEpochCompletions } = props;

    return (
        <div
            style={{
                display: "flex",
                justifyContent: "space-between",
                alignItems: "center",
                padding: "0.57rem 1.14rem",
                height: "3.29rem",
            }}
        >
            <div style={{ display: "flex", alignItems: "center", minWidth: "50%" }}>
                <span style={{ marginRight: "0.25rem", fontWeight: "bold" }}>Epoch</span>
                <Input
                    value={comparison.epoch}
                    onChange={(event) => {
                        // Empty/non-numeric input (e.g. while the field is being cleared) and
                        // values below the first epoch fall back to epoch 1.
                        const parsedEpoch = parseInt(event.target.value, 10);
                        const newEpoch = Number.isNaN(parsedEpoch) || parsedEpoch < 1 ? 1 : parsedEpoch;

                        // Ignore epochs beyond the last one that has completions.
                        if (newEpoch > totalCompletionEpochs) {
                            return;
                        }

                        setComparison({ epoch: newEpoch, index: 0 });
                    }}
                    style={{
                        maxHeight: "2.14286rem",
                        maxWidth: "4.57143rem",
                        marginRight: "0.5rem",
                    }}
                />
                {/* Controlled via `value`, so no `defaultValue` is needed. */}
                <Slider
                    value={comparison.epoch}
                    min={1}
                    max={totalCompletionEpochs}
                    // marks={generateMarks(totalCompletionEpochs)}
                    onChange={(val) => {
                        setComparison({ epoch: val, index: 0 });
                    }}
                    style={{
                        width: "100%",
                    }}
                />
            </div>
            <div style={{ display: "flex", alignItems: "center", fontSize: "0.85714rem" }}>
                <span>Completion:</span>
                <CompletionNavButton
                    ariaLabel="View previous comparison"
                    disabled={comparison.index < 1}
                    iconName="angle left"
                    onClick={() => {
                        setComparison((currentSelection) => ({
                            ...currentSelection,
                            index: currentSelection.index - 1,
                        }));
                    }}
                />
                {/* minWidth (not width) so multi-digit completion numbers are not clipped. */}
                <span style={{ fontWeight: "bold", textAlign: "center", minWidth: "10px" }}>
                    {comparison.index + 1}
                </span>
                <CompletionNavButton
                    ariaLabel="View next comparison"
                    disabled={comparison.index >= totalSelectedEpochCompletions - 1}
                    iconName="angle right"
                    onClick={() => {
                        setComparison((currentSelection) => ({
                            ...currentSelection,
                            index: currentSelection.index + 1,
                        }));
                    }}
                />
            </div>
        </div>
    );
};

/**
 * Read-only pane header. The left side shows an optional bold title followed by an
 * optional call-to-action element. The right side shows the epoch and the 1-based
 * completion number of `comparison`.
 */
const StaticCompletionHeader = ({
    comparison,
    headerText,
    updateCTA,
}: {
    comparison: SelectedComparison;
    headerText?: string;
    updateCTA?: ReactNode;
}) => {
    const headerStyle = { display: "flex", justifyContent: "space-between", padding: "0.86rem 1.14rem", height: "3.29rem" };
    const titleStyle = { display: "flex" };
    const metaStyle = { fontSize: "0.85714rem" };

    return (
        <div style={headerStyle}>
            <div style={titleStyle}>
                <strong>{headerText}</strong>
                {updateCTA}
            </div>
            <div style={metaStyle}>
                Epoch: <span style={{ fontWeight: "bold", marginRight: "1.143rem" }}>{comparison.epoch}</span>
                &nbsp; Completion: <span style={{ fontWeight: "bold" }}>{comparison.index + 1}</span>
            </div>
        </div>
    );
};

/**
 * Returns the completion identified by `comparison` (epoch + completion index), or
 * undefined when that epoch or index is not present in `promptCompletions`.
 */
const findComparisonCompletion = (promptCompletions: epochPromptCompletions[], comparison: SelectedComparison) =>
    promptCompletions
        .find((epochData) => epochData.epoch === comparison.epoch)
        ?.completions?.find((completion) => completion.index === comparison.index);

/**
 * Side-by-side view of two completions for the selected prompt.
 *
 * In default comparison mode, the left pane has editable epoch/completion pickers and the
 * right pane is titled "Highest-scoring completion from most recent epoch". Otherwise,
 * each pane shows a read-only header for its own selection. When `isNewDataAvailable` is
 * set, the right header shows a "Load latest" button. That button points the right pane at
 * the last completion of the last epoch and calls `resetDataStateTracking`.
 *
 * NOTE(review): "Load latest" picks the last completion of the last epoch. That matches the
 * "highest-scoring" title only if completions are ordered by score — confirm.
 */
const CompletionsComparison = (props: {
    isDefaultComparisonMode: boolean;
    isNewDataAvailable: boolean;
    resetDataStateTracking: () => void;
    leftComparison: SelectedComparison;
    setLeftComparison: React.Dispatch<React.SetStateAction<SelectedComparison>>;
    rightComparison: SelectedComparison;
    setRightComparison: React.Dispatch<React.SetStateAction<SelectedComparison>>;
    promptCompletions: epochPromptCompletions[];
}) => {
    const {
        isDefaultComparisonMode,
        isNewDataAvailable,
        resetDataStateTracking,
        leftComparison,
        setLeftComparison,
        rightComparison,
        setRightComparison,
        promptCompletions,
    } = props;

    const latestCompletion = useMemo(() => {
        const latestEpoch = promptCompletions.at(-1);
        const lastCompletionInEpoch = latestEpoch?.completions?.at(-1);
        return { epoch: latestEpoch?.epoch ?? 1, index: lastCompletionInEpoch?.index ?? 0 };
    }, [promptCompletions]);
    // The last epoch's number is the upper bound of the left pane's epoch pickers (1 if none).
    const totalCompletionEpochs = useMemo(() => promptCompletions.at(-1)?.epoch ?? 1, [promptCompletions]);
    const totalLeftSelectedEpochCompletions = useMemo(() => {
        const epoch = promptCompletions.find((epochData) => epochData.epoch === leftComparison.epoch);
        return epoch?.completions?.length ?? 0;
    }, [leftComparison, promptCompletions]);

    const leftComparisonCompletion = useMemo(
        () => findComparisonCompletion(promptCompletions, leftComparison),
        [promptCompletions, leftComparison],
    );
    const rightComparisonCompletion = useMemo(
        () => findComparisonCompletion(promptCompletions, rightComparison),
        [promptCompletions, rightComparison],
    );

    return (
        <StyledSegment>
            <Grid>
                <Grid.Row style={{ padding: 0 }}>
                    <Grid.Column
                        computer={8}
                        tablet={8}
                        mobile={16}
                        style={{ padding: 0, borderRight: "1px solid #DEDEDF" }}
                    >
                        {isDefaultComparisonMode ? (
                            <EditableCompletionHeader
                                comparison={leftComparison}
                                setComparison={setLeftComparison}
                                totalCompletionEpochs={totalCompletionEpochs}
                                totalSelectedEpochCompletions={totalLeftSelectedEpochCompletions}
                            />
                        ) : (
                            // The left pane's header must describe the left selection.
                            <StaticCompletionHeader comparison={leftComparison} headerText="Completion" />
                        )}
                        <CompletionDetails completion={leftComparisonCompletion} />
                    </Grid.Column>
                    <Grid.Column computer={8} tablet={8} mobile={16} style={{ padding: 0 }}>
                        <StaticCompletionHeader
                            comparison={rightComparison}
                            updateCTA={
                                isNewDataAvailable ? (
                                    <button
                                        aria-label="Load latest completion"
                                        className="button-reset"
                                        onClick={() => {
                                            setRightComparison(latestCompletion);
                                            resetDataStateTracking();
                                        }}
                                    >
                                        Load latest
                                    </button>
                                ) : undefined
                            }
                            headerText={
                                isDefaultComparisonMode
                                    ? "Highest-scoring completion from most recent epoch"
                                    : "Completion"
                            }
                        />
                        <CompletionDetails completion={rightComparisonCompletion} />
                    </Grid.Column>
                </Grid.Row>
            </Grid>
        </StyledSegment>
    );
};

/** True when both selections point at the same epoch and completion index. */
const isSameComparison = (a: SelectedComparison, b: SelectedComparison) =>
    a.epoch === b.epoch && a.index === b.index;

/**
 * Collapsible table listing every completion for the selected prompt, grouped by epoch,
 * with per-function and total rewards, token length, and text. Each row's "L"/"R" buttons
 * load that completion into the left/right comparison pane and leave default comparison mode.
 *
 * Optional "advanced filters" restrict rows to a single epoch and/or to completions whose
 * reward for a chosen function is strictly greater than an entered threshold.
 */
const AllCompletions = (props: {
    leftComparison: SelectedComparison;
    rightComparison: SelectedComparison;
    setLeftComparison: React.Dispatch<React.SetStateAction<SelectedComparison>>;
    setRightComparison: React.Dispatch<React.SetStateAction<SelectedComparison>>;
    setIsDefaultComparisonMode: React.Dispatch<React.SetStateAction<boolean>>;
    promptCompletions: epochPromptCompletions[];
}) => {
    const {
        leftComparison,
        rightComparison,
        setLeftComparison,
        setRightComparison,
        setIsDefaultComparisonMode,
        promptCompletions,
    } = props;
    // Reward columns are taken from the first completion of the first epoch. Memoized so
    // `functionOptions` below is not rebuilt on every render.
    const functionNames = useMemo(
        () => rewardFunctionNamesSorted(promptCompletions[0]?.completions?.[0]?.rewards),
        [promptCompletions],
    );

    const [isOpen, setIsOpen] = useState(true);
    const [showAdvancedFilters, setShowAdvancedFilters] = useState(false);
    const [epochFilter, setEpochFilter] = useState<number | string>("all");
    // Start with concrete values so the Select/Input are controlled from the first render.
    const [functionFilter, setFunctionFilter] = useState<string>("all");
    const [functionFilterValue, setFunctionFilterValue] = useState<string>("");

    const epochOptions = useMemo(() => {
        // Distinct epochs in ascending numeric order, after the "all" option.
        const epochs = Array.from(new Set(promptCompletions.map((epochData) => epochData.epoch))).sort(
            (a, b) => a - b,
        );
        return [
            { key: "all", value: "all", text: "All epochs" },
            ...epochs.map((epoch) => ({ key: String(epoch), value: String(epoch), text: String(epoch) })),
        ];
    }, [promptCompletions]);

    const functionOptions = useMemo(
        () => [
            { key: "all", value: "all", text: "" },
            ...functionNames.map((funcName) => ({ key: funcName, value: funcName, text: funcName })),
        ],
        [functionNames],
    );

    // The reward filter applies only when a specific function is chosen and the threshold
    // parses to a finite number; thresholds of 0 or below are honored as well.
    const trimmedThreshold = functionFilterValue.trim();
    const rewardThreshold = trimmedThreshold === "" ? NaN : Number(trimmedThreshold);
    const rewardFilter =
        functionFilter !== "all" && Number.isFinite(rewardThreshold)
            ? { functionName: functionFilter, threshold: rewardThreshold }
            : undefined;

    // Flattened (epoch, completion) pairs that pass both filters, in original order.
    const visibleRows = promptCompletions
        .filter((epochData) => epochFilter === "all" || epochData.epoch === Number(epochFilter))
        .flatMap((epochData) =>
            (epochData.completions ?? [])
                .filter((completion) => {
                    if (!rewardFilter) {
                        return true;
                    }
                    const reward = completion.rewards?.[rewardFilter.functionName];
                    return reward != null && Number(reward) > rewardFilter.threshold;
                })
                .map((completion) => ({ epoch: epochData.epoch, completion })),
        );

    return (
        <StyledSegment>
            <Accordion>
                <div style={{ display: "flex", alignItems: "baseline", justifyContent: "space-between" }}>
                    <Accordion.Title
                        active={isOpen}
                        onClick={() => {
                            setIsOpen((open) => !open);
                        }}
                    >
                        <span>
                            <Icon name="dropdown" />
                            <SegmentHeader>View all completions</SegmentHeader>
                        </span>
                    </Accordion.Title>
                    <button
                        className="button-reset"
                        onClick={() => {
                            setShowAdvancedFilters((currentState) => !currentState);
                        }}
                    >
                        {showAdvancedFilters ? "Hide advanced filters" : "Show advanced filters"}
                    </button>
                </div>
                <Accordion.Content active={isOpen}>
                    {showAdvancedFilters && (
                        <div
                            style={{
                                display: "flex",
                                alignItems: "center",
                                justifyContent: "center",
                                padding: "1.143rem",
                            }}
                        >
                            <span style={{ color: SEMANTIC_GREY_ACTIVE, marginRight: "0.7142rem" }}>
                                Filter by epoch
                            </span>
                            <Select
                                style={{ marginRight: "2.286rem" }}
                                options={epochOptions}
                                value={epochFilter}
                                onChange={(event, data) => {
                                    setEpochFilter(String(data.value));
                                }}
                            />
                            <span style={{ color: SEMANTIC_GREY_ACTIVE, marginRight: "0.7142rem" }}>
                                Filter by reward function
                            </span>
                            <Select
                                style={{ marginRight: "0.7142rem" }}
                                options={functionOptions}
                                value={functionFilter}
                                onChange={(event, data) => {
                                    setFunctionFilter(String(data.value));
                                }}
                            />
                            <span
                                style={{
                                    color: SEMANTIC_GREY_ACTIVE,
                                    marginRight: "0.7142rem",
                                }}
                            >
                                has value greater than
                            </span>
                            <Input
                                value={functionFilterValue}
                                onChange={(event, data) => {
                                    setFunctionFilterValue(String(data.value));
                                }}
                            />
                        </div>
                    )}
                    <Table style={{ marginTop: "0.5rem" }}>
                        <Table.Header style={TableHeaderStyling}>
                            <Table.Row>
                                <Table.HeaderCell rowSpan="2">Epoch</Table.HeaderCell>
                                <Table.HeaderCell rowSpan="2" style={{ minWidth: "120px" }} className="border-right">
                                    Completion&nbsp;#
                                </Table.HeaderCell>
                                {/* Spans one column per reward function plus Total. */}
                                <Table.HeaderCell
                                    colSpan={functionNames.length + 1}
                                    style={RewardsSummaryCellStyling}
                                    className="border-right"
                                >
                                    Reward Score
                                </Table.HeaderCell>
                                <Table.HeaderCell rowSpan="2">Length</Table.HeaderCell>
                                <Table.HeaderCell rowSpan="2">Completion</Table.HeaderCell>
                                <Table.HeaderCell rowSpan="2" style={{ minWidth: "109px" }}>
                                    Comparison Column
                                </Table.HeaderCell>
                            </Table.Row>
                            <Table.Row>
                                {functionNames.map((funcName) => (
                                    <Table.HeaderCell key={funcName} className="secondary-header-cell">
                                        <FunctionNameTableCell functionName={funcName} />
                                    </Table.HeaderCell>
                                ))}
                                <Table.HeaderCell className="border-right secondary-header-cell">
                                    Total
                                </Table.HeaderCell>
                            </Table.Row>
                        </Table.Header>
                        <Table.Body>
                            {visibleRows.map(({ epoch, completion }) => {
                                const currentCompletion: SelectedComparison = { epoch, index: completion.index };
                                return (
                                    <Table.Row key={`${epoch}-${completion.index}`}>
                                        <Table.Cell>{epoch}</Table.Cell>
                                        <Table.Cell className="border-right">{completion.index + 1}</Table.Cell>
                                        {functionNames.map((funcName) => (
                                            <Table.Cell className="reward-cell" key={funcName}>
                                                {completion.rewards?.[funcName]}
                                            </Table.Cell>
                                        ))}
                                        <Table.Cell className="reward-cell border-right">
                                            {completion.rewards?.total}
                                        </Table.Cell>
                                        <Table.Cell>{completion.length_tokens}</Table.Cell>
                                        <Table.Cell>{completion.text}</Table.Cell>
                                        <Table.Cell>
                                            <div style={{ display: "flex", minWidth: "78px" }}>
                                                <Button
                                                    aria-label="Left"
                                                    color={
                                                        isSameComparison(leftComparison, currentCompletion)
                                                            ? "blue"
                                                            : "grey"
                                                    }
                                                    size="mini"
                                                    onClick={() => {
                                                        setLeftComparison(currentCompletion);
                                                        setIsDefaultComparisonMode(false);
                                                    }}
                                                >
                                                    L
                                                </Button>
                                                <Button
                                                    aria-label="Right"
                                                    color={
                                                        isSameComparison(rightComparison, currentCompletion)
                                                            ? "blue"
                                                            : "grey"
                                                    }
                                                    size="mini"
                                                    onClick={() => {
                                                        setRightComparison(currentCompletion);
                                                        setIsDefaultComparisonMode(false);
                                                    }}
                                                >
                                                    R
                                                </Button>
                                            </div>
                                        </Table.Cell>
                                    </Table.Row>
                                );
                            })}
                        </Table.Body>
                    </Table>
                </Accordion.Content>
            </Accordion>
        </StyledSegment>
    );
};

/**
 * Completions tab for a GRPO fine-tuning job.
 *
 * Lets the user pick a prompt (paginated, `promptsPerPage` per page), read its full text,
 * and compare two of its sampled completions side by side. In the default comparison mode
 * the left side is epoch 1 / completion 0 and the right side is the completion with the
 * highest total reward in the latest epoch; the user can override either side, and
 * "Reset to default" restores that pairing.
 *
 * While `websocketOpen` is true, each change to `streamingData` schedules (debounced by
 * `websocketDelay`) an invalidation of the prompt-reward and prompt-completion queries,
 * and raises `isNewDataAvailable` when a newer epoch than the last one seen arrives.
 *
 * @param props.job - Job whose prompts/completions are shown; when absent, queries use an empty UUID.
 * @param props.websocketOpen - Whether live metrics are currently streaming.
 * @param props.streamingData - Streamed metrics; only the last entry's `data.epoch` is read here.
 */
const Completions = (props: {
    job?: finetuningJob;
    websocketOpen: boolean;
    streamingData?: sftMetricsPayload[] | grpoJobMetrics[];
}) => {
    const { job, websocketOpen, streamingData } = props;
    const jobUUID = job?.uuid ?? "";

    const [offset, setOffset] = useState(0);
    const [selectedPromptId, setSelectedPromptId] = useState("");
    const [leftComparison, setLeftComparison] = useState<SelectedComparison>({ epoch: 1, index: 0 });
    const [rightComparison, setRightComparison] = useState<SelectedComparison>({ epoch: 1, index: 0 });
    const [isDefaultComparisonMode, setIsDefaultComparisonMode] = useState(true);
    const [isNewDataAvailable, setIsNewDataAvailable] = useState(false);
    const [lastEpochSeen, setLastEpochSeen] = useState(0);
    // Debounce window between a streamed update and the cache invalidation it triggers.
    const websocketDelay = 500; // milliseconds

    const queryClient = useQueryClient();

    const { data: prompts } = useGRPOPrompts(jobUUID, offset, promptsPerPage);
    const { data: promptRewards } = useGRPOPromptRewards(jobUUID);
    const { data: promptCompletions } = useGRPOPromptCompletions(jobUUID, selectedPromptId, {
        enabled: selectedPromptId.length > 0,
    });

    // `prompts` is fetched one page at a time, so the page count is derived from the rewards list instead.
    const totalNumberOfPages = promptRewards?.prompts ? Math.ceil(promptRewards.prompts.length / promptsPerPage) : 1;

    // Clears the "new data" flag and records the most recent streamed epoch as seen.
    const resetDataStateTracking = () => {
        setIsNewDataAvailable(false);
        setLastEpochSeen(streamingData?.at(-1)?.data?.epoch ?? 0);
    };

    // Longest prefix shared by every prompt text on the current page ("" when there is no useful prefix).
    const commonPromptPrefix = useMemo(() => {
        const promptTexts = prompts?.prompts?.map((prompt) => prompt.prompt_text) ?? [];

        // With zero or one prompt there is nothing to compare against; a lone prompt would
        // otherwise become its own "common prefix". An empty page would also make [0] undefined.
        if (prompts?.num_prompts === 1 || promptTexts.length <= 1) {
            return "";
        }

        let commonPrefix = promptTexts[0] ?? "";
        for (const text of promptTexts) {
            let count = 0;
            while (count < commonPrefix.length && text[count] === commonPrefix[count]) {
                count++;
            }
            commonPrefix = commonPrefix.slice(0, count);
            if (commonPrefix.length === 0) {
                // Nothing left in common; further comparisons cannot grow the prefix.
                break;
            }
        }

        return commonPrefix;
    }, [prompts]);

    // Current page's prompts sorted by numeric id, each merged with its reward summary (if any).
    const combinedPromptData = useMemo(() => {
        // Index rewards by prompt id once so each prompt lookup is O(1) instead of a linear scan.
        const rewardsById = new Map(
            (promptRewards?.prompts ?? []).map((promptReward) => [promptReward.id, promptReward] as const),
        );

        // Copy before sorting: `prompts` is react-query cache data and must not be mutated in place.
        return [...(prompts?.prompts ?? [])]
            .sort((a, b) => parseInt(a.id) - parseInt(b.id))
            .map((promptText): CombinedPromptData => ({ ...promptText, ...rewardsById.get(promptText.id) }));
    }, [prompts, promptRewards]);

    // Full text of the selected prompt; "" if it is not on the current page.
    const selectedPrompt = useMemo(() => {
        return combinedPromptData.find((promptData) => promptData.id === selectedPromptId)?.prompt_text ?? "";
    }, [selectedPromptId, combinedPromptData]);

    // Select the first prompt (which enables its completions query) once prompt data is loaded.
    useEffect(() => {
        if (selectedPromptId.length > 0 || combinedPromptData.length === 0) {
            return;
        }

        setSelectedPromptId(combinedPromptData[0].id);
    }, [combinedPromptData, selectedPromptId]);

    // Return to the default comparison mode whenever the user selects a different prompt.
    useEffect(() => {
        setIsDefaultComparisonMode(true);
    }, [selectedPromptId]);

    // Recalculate the default pairing when entering default mode (reset button or new prompt)
    // or when fresh completions arrive while still in default mode.
    useEffect(() => {
        // Once the user picks a comparison manually, incoming data must not override it.
        if (!isDefaultComparisonMode) {
            return;
        }

        const latestEpoch = promptCompletions?.epochs?.at(-1);
        const latestEpochNumber = latestEpoch?.epoch ?? 0;
        let bestCompletion: SelectedComparison = { epoch: latestEpochNumber, index: 0 };
        // Start below any real value so the best completion is found even when all totals are negative.
        let maxTotal = -Infinity;
        latestEpoch?.completions?.forEach((completion) => {
            if (completion.rewards.total > maxTotal) {
                maxTotal = completion.rewards.total;
                bestCompletion = { epoch: latestEpochNumber, index: completion.index };
            }
        });

        setLeftComparison({ epoch: 1, index: 0 });
        setRightComparison(bestCompletion);
    }, [isDefaultComparisonMode, promptCompletions]);

    // Refresh cached rewards/completions and flag new epochs as streamed data comes in.
    useEffect(() => {
        if (!websocketOpen || !Array.isArray(streamingData)) {
            return;
        }

        const timeoutId = setTimeout(() => {
            queryClient.invalidateQueries({ queryKey: GET_GRPO_PROMPT_REWARDS_QUERY_KEY(jobUUID) });
            queryClient.invalidateQueries({ queryKey: GET_GRPO_ALL_PROMPT_COMPLETIONS_QUERY_KEY(jobUUID) });
            const currentEpoch = streamingData.at(-1)?.data?.epoch ?? 0;
            if (currentEpoch > lastEpochSeen) {
                setLastEpochSeen(currentEpoch);
                setIsNewDataAvailable(true);
            }
        }, websocketDelay);

        // A newer update within the delay window cancels the pending invalidation.
        return () => clearTimeout(timeoutId);
    }, [streamingData, websocketOpen, websocketDelay, jobUUID, queryClient]);

    if (promptCompletions === undefined || promptCompletions.epochs.length === 0) {
        return (
            <div style={{ display: "flex", flexDirection: "column", alignItems: "center" }}>
                <Divider hidden />
                <img src={"/model/emptyRepos.svg"} alt="" />
                <Header as="h2" size={"medium"} style={{ marginBottom: "0.5rem" }}>
                    No completions yet!
                </Header>
                <Divider hidden />
            </div>
        );
    }

    return (
        <>
            <div style={{ display: "flex", justifyContent: "space-between" }}>
                <SectionHeader>Prompt</SectionHeader>
                <a
                    href={`${getDocsHome()}/user-guide/fine-tuning/grpo#how-do-i-use-the-completions-tab`}
                    target="_blank"
                    rel="noreferrer"
                >
                    <Icon name="book" />
                    Learn how to analyze your model's generations during training
                </a>
            </div>
            <PromptSelector
                selectedPromptId={selectedPromptId}
                setSelectedPromptId={setSelectedPromptId}
                combinedPromptData={combinedPromptData}
                commonPromptPrefix={commonPromptPrefix}
                pageNumber={offset / promptsPerPage + 1}
                setOffset={setOffset}
                totalNumberOfPages={totalNumberOfPages}
            />
            <StyledSegment>
                <SegmentHeader>{`Full text of prompt ${selectedPromptId}`}</SegmentHeader>
                <ShowMoreText lines={6} keepNewLines={selectedPrompt.includes("\n")} anchorClass="show-more-text-link">
                    {selectedPrompt}
                </ShowMoreText>
            </StyledSegment>
            <div style={{ display: "flex", alignItems: "baseline", marginTop: "2.286rem" }}>
                <SectionHeader>Compare Completions</SectionHeader>
                {!isDefaultComparisonMode && (
                    <Button
                        color="blue"
                        size="tiny"
                        onClick={() => {
                            setIsDefaultComparisonMode(true);
                        }}
                        style={{ padding: "0.21429rem 1.5rem", height: "1.57143rem" }}
                    >
                        Reset to default
                    </Button>
                )}
            </div>
            <CompletionsComparison
                isDefaultComparisonMode={isDefaultComparisonMode}
                isNewDataAvailable={isNewDataAvailable}
                resetDataStateTracking={resetDataStateTracking}
                leftComparison={leftComparison}
                setLeftComparison={setLeftComparison}
                rightComparison={rightComparison}
                setRightComparison={setRightComparison}
                promptCompletions={promptCompletions.epochs}
            />
            <AllCompletions
                leftComparison={leftComparison}
                rightComparison={rightComparison}
                setLeftComparison={setLeftComparison}
                setRightComparison={setRightComparison}
                setIsDefaultComparisonMode={setIsDefaultComparisonMode}
                promptCompletions={promptCompletions.epochs}
            />
        </>
    );
};

export default Completions;
