import { ValidateFunction } from "ajv";
import type { JSONSchema7 } from "json-schema";
import _ from "lodash";
import React, { memo, useEffect, useState } from "react";
import ReactFlow, {
    Controls,
    Edge,
    Node,
    Position,
    ReactFlowProvider,
    useEdgesState,
    useNodesState,
} from "react-flow-renderer";
import { ModelTypes } from "../../types/model/modelTypes";
import { SEMANTIC_GREY } from "../../utils/colors";
import { findAllOfSchema } from "../../utils/jsonSchema";
import { truncateStringWithPopup } from "../../utils/overflow";
import { noFunctionCompare } from "../../utils/reactUtils";
import { DecoderForm, EncoderForm } from "../create/forms/CoderForm";
import CombinerForm from "../create/forms/CombinerForm";
import ConfigSetModal from "../create/forms/ConfigSetModal";
import LLMForm from "../create/forms/LLMForm";
import { InputPreprocessingForm, OutputPostprocessingForm } from "../create/forms/PreprocessingForm";

/** Value handed to `ConfigSetModal` as its `metricRoot` prop. */
const MetricRoot = "Model.Config.Overview";

/** Horizontal distance between two adjacent graph columns. */
const XPAD = 180;
/** Vertical distance between two adjacent rows inside one column. */
const YPAD = 80;

/** Length limit passed to `truncateStringWithPopup` for feature names; derived from the column spacing. */
const TRUNCATE_LIMIT = 2 + XPAD / 10;

/** Base style shared by every node; each node adds its own `background` on top. */
const nodeStyle = { border: "1px solid #222138" };

/**
 * Renders a grey caption positioned 30px above the element it is placed in.
 * It is centred horizontally and ignores pointer events, so clicks fall
 * through to whatever is underneath.
 *
 * @param text Caption to display.
 */
const columnHeader = (text: string) => {
    const headerStyle: React.CSSProperties = {
        position: "absolute",
        color: SEMANTIC_GREY,
        top: "-30px",
        left: 0,
        right: 0,
        textAlign: "center",
        pointerEvents: "none",
    };

    return <span style={headerStyle}>{text}</span>;
};

/**
 * Renders a grey, single-line caption positioned 30px below the element it is
 * placed in. It is centred horizontally, never wraps, and ignores pointer
 * events.
 *
 * @param text Caption to display.
 */
const columnFooter = (text: string) => {
    const footerStyle: React.CSSProperties = {
        whiteSpace: "nowrap",
        position: "absolute",
        color: SEMANTIC_GREY,
        bottom: "-30px",
        left: 0,
        right: 0,
        textAlign: "center",
        pointerEvents: "none",
    };

    return <span style={footerStyle}>{text}</span>;
};

/**
 * Resolves the encoder type to show for the input feature at `featureIndex`.
 *
 * The explicitly configured `encoder.type` wins. When the feature has no
 * encoder, or has an encoder object without a `type`, the default declared in
 * the schema for that feature type is used instead (same fallback rule as
 * `getCombinerType`).
 *
 * @param config Model config whose `input_features` list is inspected.
 * @param featureIndex Index into `config.input_features`.
 * @param schema Config schema; its `properties.input_features` sub-schema is
 *     searched via `findAllOfSchema` for the entry matching the feature type.
 * @returns The encoder type, the schema default (possibly `undefined` when the
 *     schema declares none), or `null` when there is no feature at that index.
 */
const getEncoderType = (config: CreateModelConfig, featureIndex: number, schema: JSONSchema7) => {
    const feature = config.input_features?.[featureIndex];
    if (!feature) {
        return null;
    }
    const featureTypeSchema = findAllOfSchema(schema?.properties?.input_features as JSONSchema7, feature.type);
    // @ts-expect-error -- nested schema properties are not narrowed to object schemas
    const featureDefaultEncoder = featureTypeSchema?.properties?.encoder?.properties?.type?.default;

    // `||` (not an existence check on `encoder`) so an encoder without a type still gets the default.
    return feature.encoder?.type || featureDefaultEncoder;
};

/**
 * Resolves the decoder type to show for the output feature at `featureIndex`.
 *
 * The explicitly configured `decoder.type` wins. When the feature has no
 * decoder, or has a decoder object without a `type`, the default declared in
 * the schema for that feature type is used instead (same fallback rule as
 * `getCombinerType`).
 *
 * @param config Model config whose `output_features` list is inspected.
 * @param featureIndex Index into `config.output_features`.
 * @param schema Config schema; its `properties.output_features` sub-schema is
 *     searched via `findAllOfSchema` for the entry matching the feature type.
 * @returns The decoder type, the schema default (possibly `undefined` when the
 *     schema declares none), or `null` when there is no feature at that index.
 */
const getDecoderType = (config: CreateModelConfig, featureIndex: number, schema: JSONSchema7) => {
    const feature = config.output_features?.[featureIndex];
    if (!feature) {
        return null;
    }
    const featureTypeSchema = findAllOfSchema(schema?.properties?.output_features as JSONSchema7, feature.type);
    // @ts-expect-error -- nested schema properties are not narrowed to object schemas
    const featureDefaultDecoder = featureTypeSchema?.properties?.decoder?.properties?.type?.default;

    // `||` (not an existence check on `decoder`) so a decoder without a type still gets the default.
    return feature.decoder?.type || featureDefaultDecoder;
};

/**
 * Returns the combiner type set in the config, or the schema's default
 * combiner type when the config does not specify one.
 */
const getCombinerType = (config: CreateModelConfig, schema: JSONSchema7) => {
    const configuredType = config.combiner?.type;
    if (configuredType) {
        return configuredType;
    }
    return _.get(schema, "properties.combiner.properties.type.default");
};

/**
 * Returns the `model_name` set in the config, or the schema's default
 * `model_name` when the config does not specify one.
 */
const getLLMType = (config: CreateModelConfig, schema: JSONSchema7) => {
    const configuredName = config?.model_name;
    if (configuredName) {
        return configuredName;
    }
    return _.get(schema, "properties.model_name.default");
};

/**
 * Lookup table from the component key passed to `openModal` to the form
 * component rendered inside the config modal. Entries are listed in the order
 * the corresponding stages appear in the graph, left to right.
 */
const components: Record<string, React.ComponentType<any>> = {
    // Input side
    input_preproc: InputPreprocessingForm,
    encoder: EncoderForm,
    // Model core
    combiner: CombinerForm,
    llm: LLMForm,
    // Output side
    decoder: DecoderForm,
    output_postproc: OutputPostprocessingForm,
};

/** Everything that varies between two nodes of the graph. */
type FlowNodeSpec = {
    id: string;
    background: string;
    /** Zero-based column index; multiplied by `XPAD` to get the x coordinate. */
    column: number;
    y: number;
    label: React.ReactNode;
    /** Invoked by the flow's node-click handler; defaults to a no-op. */
    onClick?: () => void;
    /** React Flow node type ("input" / "output"); omitted for default nodes. */
    type?: string;
};

/** Puts the column caption above the first row's content; other rows get the content unchanged. */
const labelWithHeader = (header: string, idx: number, content: React.ReactNode): React.ReactNode =>
    idx === 0 ? (
        <>
            {columnHeader(header)}
            {content}
        </>
    ) : (
        content
    );

/** Builds one left-to-right React Flow node with the shared border style and the given background. */
const makeFlowNode = ({ id, background, column, y, label, onClick = () => {}, type }: FlowNodeSpec): Node => ({
    id,
    ...(type ? { type } : {}),
    style: { ...nodeStyle, background },
    sourcePosition: Position.Right,
    targetPosition: Position.Left,
    data: { label, onClick },
    position: { x: column * XPAD, y },
});

/** Builds one animated "smoothstep" edge between two node ids. */
const makeFlowEdge = (id: string, source: string, target: string): Edge => ({
    id,
    source,
    type: "smoothstep",
    target,
    animated: true,
});

/**
 * Translates a model config into the nodes and edges of the overview graph.
 *
 * Columns are laid out left to right, `XPAD` apart:
 *   input -> preprocessing -> [encoder -> combiner | tree | llm] -> [decoder] -> postprocessing -> output
 * Encoder and combiner columns exist only for neural networks, the tree column
 * only for decision trees, the LLM column only for large language models, and
 * the decoder column for neural networks and large language models. Rows in a
 * column are `YPAD` apart and the input and output stacks are vertically
 * centred against each other.
 *
 * @param config Model config; a missing `input_features` or `output_features`
 *     list is treated as empty.
 * @param modelType Selects which middle columns and which edge wiring are used.
 * @param openModal Called from a node's `data.onClick` with the modal header,
 *     the component key and the feature index (0 for the combiner, -1 for the LLM).
 * @param schema Config schema used to resolve default encoder / decoder /
 *     combiner / LLM names. Without it an empty graph is returned.
 * @returns The nodes and edges to hand to React Flow.
 */
export const buildGraphElements = (
    config: CreateModelConfig,
    modelType: ModelTypes,
    openModal: (header: string, component: string, idx: number) => void,
    schema?: JSONSchema7,
): { nodes: Node[]; edges: Edge[] } => {
    if (!schema) {
        return { nodes: [], edges: [] };
    }

    // Tolerate configs where one of the feature lists has not been populated yet.
    const inputFeatures = config.input_features ?? [];
    const outputFeatures = config.output_features ?? [];
    const inputCount = inputFeatures.length;
    const outputCount = outputFeatures.length;

    // Centre the shorter stack against the taller one.
    const inputHeight = (inputCount - 1) * YPAD;
    const outputHeight = (outputCount - 1) * YPAD;
    const height = Math.max(inputHeight, outputHeight);

    const yCenter = height / 2;
    const yInputOrigin = yCenter - inputHeight / 2;
    const yOutputOrigin = yCenter - outputHeight / 2;

    // `column` advances only for columns that are actually rendered.
    let column = 0;
    const inputNodes: Node[] = inputFeatures.map((feature, idx) =>
        makeFlowNode({
            id: `input-${idx}`,
            type: "input",
            background: "#efefef",
            column,
            y: yInputOrigin + YPAD * idx,
            label: labelWithHeader("Input Feature", idx, truncateStringWithPopup(feature.name, TRUNCATE_LIMIT)),
        }),
    );

    column += 1;
    const preprocNodes: Node[] = inputFeatures.map((feature, idx) =>
        makeFlowNode({
            id: `preproc-${idx}`,
            background: "#fff1cc",
            column,
            y: yInputOrigin + YPAD * idx,
            label: labelWithHeader("Preprocessing", idx, feature.type),
            onClick: () => openModal(feature.name + " Preprocessing", "input_preproc", idx),
        }),
    );

    let encoderNodes: Node[] = [];
    if (modelType === ModelTypes.NEURAL_NETWORK) {
        column += 1;
        encoderNodes = inputFeatures.map((feature, idx) =>
            makeFlowNode({
                id: `encoder-${idx}`,
                background: "#ffd866",
                column,
                y: yInputOrigin + YPAD * idx,
                label: labelWithHeader("Encoder", idx, getEncoderType(config, idx, schema)),
                onClick: () => openModal(feature.name + " Encoder", "encoder", idx),
            }),
        );
    }

    let combinerNodes: Node[] = [];
    if (modelType === ModelTypes.NEURAL_NETWORK) {
        column += 1;
        combinerNodes = [
            makeFlowNode({
                id: "combiner",
                background: "#f6b26b",
                column,
                y: yCenter,
                label: labelWithHeader("Combiner", 0, getCombinerType(config, schema)),
                onClick: () => openModal("Combiner", "combiner", 0),
            }),
        ];
    }

    let treeNodes: Node[] = [];
    if (modelType === ModelTypes.DECISION_TREE) {
        column += 1;
        treeNodes = [
            makeFlowNode({
                id: "tree",
                background: "#f6b26b",
                column,
                y: yCenter,
                // Plain text, without quote characters around it.
                label: labelWithHeader("Tree", 0, "tree"),
            }),
        ];
    }

    let llmNodes: Node[] = [];
    if (modelType === ModelTypes.LARGE_LANGUAGE_MODEL) {
        column += 1;
        llmNodes = [
            makeFlowNode({
                id: "llm",
                background: "#f6b26b",
                column,
                y: yCenter,
                label: (
                    <>
                        {columnHeader("LLM")}
                        {getLLMType(config, schema) || "language_model"}
                        {config?.base_model && columnFooter(config.base_model)}
                    </>
                ),
                onClick: () => openModal("Large Language Model", "llm", -1),
            }),
        ];
    }

    let decoderNodes: Node[] = [];
    if (modelType === ModelTypes.NEURAL_NETWORK || modelType === ModelTypes.LARGE_LANGUAGE_MODEL) {
        column += 1;
        decoderNodes = outputFeatures.map((feature, idx) =>
            makeFlowNode({
                id: `decoder-${idx}`,
                background: "#e99a99",
                column,
                y: yOutputOrigin + YPAD * idx,
                label: labelWithHeader("Decoder", idx, getDecoderType(config, idx, schema)),
                onClick: () => openModal(feature.name + " Decoder", "decoder", idx),
            }),
        );
    }

    column += 1;
    const postprocNodes: Node[] = outputFeatures.map((feature, idx) =>
        makeFlowNode({
            id: `postproc-${idx}`,
            background: "#f5cbcc",
            column,
            y: yOutputOrigin + YPAD * idx,
            label: labelWithHeader("Postprocessing", idx, feature.type),
            onClick: () => openModal(feature.name + " Postprocessing", "output_postproc", idx),
        }),
    );

    column += 1;
    const outputNodes: Node[] = outputFeatures.map((feature, idx) =>
        makeFlowNode({
            id: `output-${idx}`,
            type: "output",
            background: "#efefef",
            column,
            y: yOutputOrigin + YPAD * idx,
            label: labelWithHeader("Output Feature", idx, truncateStringWithPopup(feature.name, TRUNCATE_LIMIT)),
        }),
    );

    // Input side: every feature row is chained up to the single central node.
    const inputEdges: Edge[] = [];
    for (let i = 0; i < inputCount; i++) {
        inputEdges.push(makeFlowEdge(`input-preproc-${i}`, `input-${i}`, `preproc-${i}`));

        switch (modelType) {
            case ModelTypes.LARGE_LANGUAGE_MODEL:
                inputEdges.push(makeFlowEdge(`preproc-llm-${i}`, `preproc-${i}`, "llm"));
                break;
            case ModelTypes.DECISION_TREE:
                inputEdges.push(makeFlowEdge(`preproc-tree-${i}`, `preproc-${i}`, "tree"));
                break;
            default:
                inputEdges.push(makeFlowEdge(`preproc-encoder-${i}`, `preproc-${i}`, `encoder-${i}`));
                inputEdges.push(makeFlowEdge(`encoder-combiner-${i}`, `encoder-${i}`, "combiner"));
                break;
        }
    }

    // Output side: the central node fans out to every output feature row.
    const outputEdges: Edge[] = [];
    for (let i = 0; i < outputCount; i++) {
        switch (modelType) {
            case ModelTypes.DECISION_TREE:
                outputEdges.push(makeFlowEdge(`tree-postproc-${i}`, "tree", `postproc-${i}`));
                break;
            case ModelTypes.LARGE_LANGUAGE_MODEL:
                outputEdges.push(makeFlowEdge(`llm-decoder-${i}`, "llm", `decoder-${i}`));
                outputEdges.push(makeFlowEdge(`decoder-postproc-${i}`, `decoder-${i}`, `postproc-${i}`));
                break;
            default:
                outputEdges.push(makeFlowEdge(`combiner-decoder-${i}`, "combiner", `decoder-${i}`));
                outputEdges.push(makeFlowEdge(`decoder-postproc-${i}`, `decoder-${i}`, `postproc-${i}`));
                break;
        }

        outputEdges.push(makeFlowEdge(`postproc-output-${i}`, `postproc-${i}`, `output-${i}`));
    }

    const nodes: Node[] = [
        ...inputNodes,
        ...preprocNodes,
        ...encoderNodes,
        ...combinerNodes,
        ...treeNodes,
        ...llmNodes,
        ...decoderNodes,
        ...postprocNodes,
        ...outputNodes,
    ];

    const edges: Edge[] = [...inputEdges, ...outputEdges];

    return { nodes, edges };
};

/**
 * Renders the read-only React Flow canvas for a model config.
 *
 * The graph is rebuilt with `buildGraphElements` whenever the config, model
 * type, schema or `openModal` callback changes, and stored in React Flow's
 * node/edge state. Clicking a node runs that node's `data.onClick`. Renders
 * nothing while no schema is available.
 */
const ModelFlow = ({
    config,
    schema,
    modelType,
    openModal,
}: {
    config: CreateModelConfig;
    schema?: JSONSchema7;
    modelType: ModelTypes;
    openModal: (header: string, component: string, idx: number) => void;
}) => {
    const [flowNodes, setNodes] = useNodesState([]);
    const [flowEdges, setEdges] = useEdgesState([]);

    useEffect(() => {
        const graph = buildGraphElements(config, modelType, openModal, schema);
        setNodes(graph.nodes);
        setEdges(graph.edges);
    }, [config, modelType, schema, openModal, setEdges, setNodes]);

    // Hooks above must run unconditionally, so the schema guard comes after them.
    if (!schema) {
        return null;
    }

    return (
        <ReactFlow
            nodes={flowNodes}
            edges={flowEdges}
            onNodeClick={(_event, clickedNode) => clickedNode.data.onClick()}
            nodesDraggable={false}
            zoomOnScroll={false}
            preventScrolling={false}
            minZoom={0.1}
            fitView
        >
            <Controls showInteractive={false} />
        </ReactFlow>
    );
};

/**
 * Model overview: the flow graph plus the config modal opened by clicking a
 * node.
 *
 * Clicking a node records which form component to show (looked up by key in
 * `components`), the modal header and the feature index, then opens
 * `ConfigSetModal`. Renders nothing while no config is available.
 */
const ModelGraph = (props: {
    applyConfig: (config?: CreateModelConfig) => void;
    config?: CreateModelConfig;
    modelType: ModelTypes;
    readOnly: boolean;
    schema?: JSONSchema7;
    validator?: ValidateFunction;
    expectedImpactFilter: number;
    setExpectedImpactFilter: (value: number) => void;
}) => {
    const [modalOpen, setModalOpen] = useState(false);
    const [modalIndex, setModalIndex] = useState(0);
    const [modalHeader, setModalHeader] = useState<string>();
    const [modalComponentName, setModalComponentName] = useState<string>();

    // Stable identity: ModelFlow lists this callback as an effect dependency, so a
    // fresh function per render would rebuild the whole graph on every re-render
    // (including the ones caused by opening or closing the modal). Declared before
    // the early return below so the hook order never changes.
    const openModal = React.useCallback((header: string, component: string, idx: number) => {
        setModalIndex(idx);
        setModalHeader(header);
        setModalComponentName(component);
        setModalOpen(true);
    }, []);

    if (!props.config) {
        return null;
    }

    // An unknown component key yields `undefined`; the truthiness check below keeps
    // the modal from being rendered without a form component.
    const ModalComponent = modalComponentName ? components[modalComponentName] : undefined;
    return (
        <div style={{ width: "100%", height: "80vh" }}>
            <ReactFlowProvider>
                <ModelFlow
                    config={props.config}
                    schema={props.schema}
                    modelType={props.modelType}
                    openModal={openModal}
                />
            </ReactFlowProvider>
            {ModalComponent && (
                <ConfigSetModal
                    config={props.config}
                    schema={props.schema}
                    validator={props.validator}
                    applyConfig={props.applyConfig}
                    readOnly={props.readOnly}
                    metricRoot={MetricRoot}
                    featureIndex={modalIndex}
                    modalOpen={modalOpen}
                    setModalOpen={setModalOpen}
                    modalHeader={modalHeader ?? ""}
                    ModalComponent={ModalComponent}
                    expectedImpactFilter={props.expectedImpactFilter}
                    setExpectedImpactFilter={props.setExpectedImpactFilter}
                />
            )}
        </div>
    );
};

export default memo(ModelGraph, noFunctionCompare);
