import { useMutation } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { Header, Message, Segment } from "semantic-ui-react";
import NumberInput from "semantic-ui-react-numberinput";
import { useAuth0TokenOptions } from "../../data";
import { detectEngines } from "../../engines/data";
import EngineNodeTable from "../../engines/EngineNodeTable";
import EngineSelector from "../../engines/EngineSelector";
import { EngineServiceType } from "../../types/engineServiceType";
import { SEMANTIC_BLACK, SEMANTIC_BLUE } from "../../utils/colors";
import { getErrorMessage } from "../../utils/errors";
import { useConfigState, useDispatch } from "./store";

// Inline style for the small blue "Recommended: N" hint rendered beneath the worker input.
// Sizes are written as ratios and expressed in rem units.
const defaultTextStyling = {
    marginTop: String(10 / 28) + "rem",
    color: SEMANTIC_BLUE,
    fontSize: String(12 / 14) + "rem",
};

/**
 * Rounds `x` up to the smallest power of two that is greater than or equal to it.
 *
 * Exact powers of two are returned unchanged. Fractional inputs between 0 and 1 round up to
 * fractional powers of two (e.g. 0.3 -> 0.5), 0 yields 0, and negative or NaN inputs yield NaN.
 *
 * @param x - The value to round up.
 * @returns The smallest power of two `>= x`.
 */
function pow2ceil(x: number): number {
    let result = Math.pow(2, Math.ceil(Math.log2(x)));
    // The logarithm is only an approximation: for an exact power of two it may land a hair above
    // the integer exponent (making the ceil overshoot by one doubling), and for values just above
    // a power of two it may round down onto the integer (undershooting). Nudge the result by one
    // step so it is always the smallest power of two that is still >= x.
    if (result / 2 >= x) {
        result /= 2;
    } else if (result < x) {
        result *= 2;
    }
    return result;
}

/**
 * Compute configuration panel shown when preparing to train a model version.
 *
 * Lets the user pick a training engine and, when the detected engine template is adaptive,
 * choose the number of training workers in power-of-two steps. Engine detection is re-run
 * whenever the config, dataset, or engine changes; the detected template is rendered in a node
 * table, and any detection error is shown above it.
 *
 * @param props.dataset - Dataset whose id is passed to engine detection; detection is skipped while unset.
 * @param props.engine - Currently selected engine; detection is skipped while unset.
 * @param props.setEngine - Passed to the engine selector as its selection callback.
 */
function ComputeConfigManager(props: { dataset?: Dataset; engine?: Engine; setEngine: (engine: Engine) => void }) {
    const dispatch = useDispatch();
    const [template, setTemplate] = useState<EngineTemplate | null>(null);
    const [errorMessage, setErrorMessage] = useState<string | null>(null);
    const { config } = useConfigState();

    const auth0TokenOptions = useAuth0TokenOptions();

    // The non-null assertions are safe because `detect` is only invoked from the effect below,
    // after it has checked that both the dataset and the engine are set.
    const { mutate: detect } = useMutation({
        mutationFn: () => detectEngines(props.dataset!.id, config, props.engine!.id, auth0TokenOptions),
        onSuccess: (data) => {
            setTemplate(data);
            setErrorMessage(null);
        },
        onError: (error) => {
            setErrorMessage(getErrorMessage(error));
        },
    });

    // Re-detect the engine template whenever an input to detection changes.
    useEffect(() => {
        if (!props.dataset || !props.engine) {
            return;
        }
        detect();
    }, [config, props.dataset, props.engine]);

    const setNumWorkers = (newValue: string) => {
        const requested = Number(newValue);
        // Unparseable or negative input would round to NaN; never write that into the config.
        if (!Number.isFinite(requested) || requested < 0) {
            return;
        }
        const rounded = pow2ceil(requested);
        // Rounding up to a power of two can overshoot the template's limit when maxWorkers is not
        // itself a power of two (e.g. max 6: 4 -> 6 -> 8), so cap the result at that limit.
        const maxWorkers = template?.maxWorkers;
        const clamped = maxWorkers != null ? Math.min(rounded, maxWorkers) : rounded;
        dispatch({ type: "UPDATE_CONFIG_PROPERTY", field: "backend.trainer.num_workers", value: clamped });
    };

    if (!config) {
        return null;
    }

    // Use the explicitly set number of workers in the config if provided, otherwise fallback to the
    // adaptive engine recommendation. `||` (not `??`) is intentional: a stored value of 0 also
    // falls back to the recommendation.
    const numWorkers = config?.backend?.trainer?.num_workers || template?.suggestedWorkers;

    // Because the NumberInput component can only increase / decrease by a fixed amount, in order to instead
    // go up and down by powers of 2, we need to set the step amount to the current value / 2 (equal to the true
    // decrement amount), then round up to the nearest power of 2 in the onChange callback (to handle incrementing).
    //
    // Example:
    //     numWorkers = 4
    //     stepAmount = 4 / 2 = 2
    //     numWorkersDown = pow2ceil(4 - 2) = 2
    //     numWorkersUp = pow2ceil(4 + 2) = 8
    const stepAmount = numWorkers > 1 ? numWorkers / 2 : 1;

    return (
        <div>
            <Message floating info style={{ color: SEMANTIC_BLACK }}>
                This page shows the compute resources that will be used to train this model version.
            </Message>
            <Segment raised style={{ padding: "1.75rem" }}>
                <Header className="header" as="h2" size="small">
                    Engine
                </Header>
                <EngineSelector
                    label={null}
                    engineID={props.engine?.id}
                    onEngineSelect={props.setEngine}
                    validate={true}
                    serviceTypes={[EngineServiceType.BATCH, EngineServiceType.RAY]}
                    onlyAllowTrainableEngines={true}
                />
                {template?.adaptive && (
                    <>
                        <Header className="header" as="h2" size="small">
                            Training Workers
                            <Header.Subheader>
                                Increase to speed up training, decrease to reduce costs and effective batch size.
                            </Header.Subheader>
                        </Header>
                        <NumberInput
                            className="numberInput"
                            value={numWorkers.toString()}
                            onChange={setNumWorkers}
                            stepAmount={stepAmount}
                            minValue={1}
                            maxValue={template.maxWorkers}
                        />
                        <p style={{ ...defaultTextStyling }}>Recommended: {template.suggestedWorkers}</p>
                    </>
                )}
            </Segment>
            <Segment raised style={{ padding: "1.75rem" }}>
                {errorMessage && <Message negative>{errorMessage}</Message>}
                {template && <EngineNodeTable template={template} />}
            </Segment>
        </div>
    );
}

export default ComputeConfigManager;
