import { InferenceSession, Tensor } from "onnxruntime-web";
import npyjs from "npyjs";
import * as _ from "underscore";
import concaveman from "concaveman"
import simplify from "simplify-js"
import { onnxMaskClip, onnxMaskToPolygon, getImageByMask, downLoadImage } from "@antv/sam/dist/utils/mask";
import { tilesServer } from "../../../../AppConstants";
// import {cv2} from "opencv-react-ts"
// ---------------------- Draw on image start -------------------------------------//
// const InferenceSession = "4", Tensor  = "" ,npyjs="";
// 
/**
 * Densify, optionally concave-hull, and simplify a set of flat polygon rings.
 *
 * All vertices of every ring are pooled into a single point set. Long edges get
 * evenly spaced guide points (roughly one per 5 px) so that a concave hull cannot
 * cut far through them. The hull (or the raw densified points when `concavity` is 0)
 * is then simplified with simplify-js in high-quality mode.
 *
 * @param {Array<number[]>} multiPolygon - rings as flat [x0, y0, x1, y1, ...] arrays.
 * @param {number} [simplification=0.5] - tolerance forwarded to simplify-js.
 * @param {number} [concavity=0] - concaveman concavity; 0 disables the hull step so
 *   complex shapes are kept (e.g. for smart polygons).
 * @returns {Array} a pair `[points, box]`: the simplified `[x, y]` points, and their
 *   bounding box `{left, right, top, bottom, width, height, x, y}` where x/y is the
 *   box centre. With no points the box edges stay at +/-Infinity.
 */
const simplifyMultiPolygon = (multiPolygon, simplification = 0.5, concavity = 0) => {
    const densified = [];

    _.each(multiPolygon, (segmentation) => {
        const len = segmentation.length;
        for (let idx = 0; idx <= len - 2; idx += 2) {
            const px = segmentation[idx];
            const py = segmentation[idx + 1];

            // Interpolate from the previously emitted vertex (this carries over between
            // rings); for the very first vertex, wrap around to this ring's last vertex.
            const anchor = densified.length > 0
                ? densified[densified.length - 1]
                : [segmentation[len - 2], segmentation[len - 1]];

            const dx = anchor[0] - px;
            const dy = anchor[1] - py;
            const steps = Math.floor(Math.sqrt(dx * dx + dy * dy) / 5);

            for (let step = 1; step < steps; step++) {
                densified.push([
                    anchor[0] + (px - anchor[0]) * step / steps,
                    anchor[1] + (py - anchor[1]) * step / steps
                ]);
            }
            densified.push([px, py]);
        }
    });

    // The concave hull plus simplification removes most of the guide points again.
    const hull = concavity === 0 ? densified : concaveman(densified, concavity);
    const simplified = simplify(
        hull.map((pt) => ({ x: pt[0], y: pt[1] })),
        simplification,
        true
    );

    const bounds = { left: Infinity, right: -Infinity, top: Infinity, bottom: -Infinity };
    const points = [];
    for (const { x, y } of simplified) {
        if (x < bounds.left) bounds.left = x;
        if (x > bounds.right) bounds.right = x;
        if (y > bounds.bottom) bounds.bottom = y;
        if (y < bounds.top) bounds.top = y;
        points.push([x, y]);
    }

    const { left, right, top, bottom } = bounds;
    return [points, {
        left,
        right,
        top,
        bottom,
        width: right - left,
        height: bottom - top,
        x: (left + right) / 2,
        y: (top + bottom) / 2
    }];
};
// Convert the onnx model mask prediction to ImageData
/**
 * Paint a model mask into RGBA ImageData. Pixels whose value is > 0 get `color`;
 * every other pixel stays fully transparent.
 *
 * @param {ArrayLike<number>} input - mask values, assumed row-major (index = y * width + x).
 * @param {number} width - mask width in pixels. For a SAM mask tensor this is presumably
 *   `output.dims[3]`, the same value drawPolygon passes as width to onnxMaskToPolygon.
 * @param {number} height - mask height in pixels (presumably `output.dims[2]`).
 * @param {number[]} [color=[0, 114, 189, 255]] - RGBA colour for mask pixels (default: blue).
 * @returns {Promise<ImageData>} a `width` x `height` image.
 */
async function arrayToImageData(input, width, height, color = [0, 114, 189, 255]) {
    const [r, g, b, a] = color;
    // Uint8ClampedArray is zero-initialised, i.e. transparent black.
    const arr = new Uint8ClampedArray(4 * width * height);

    for (let i = 0; i < input.length; i++) {
        if (input[i] > 0.0) {
            const offset = 4 * i;
            arr[offset] = r;
            arr[offset + 1] = g;
            arr[offset + 2] = b;
            arr[offset + 3] = a;
        }
    }

    // The ImageData constructor takes (data, width, height), in that order.
    return new ImageData(arr, width, height);
}

// Use a Canvas element to produce an image from ImageData
/**
 * Wrap ImageData in an HTMLImageElement via a canvas data URL.
 * The element is returned as soon as `src` is assigned; it may still be decoding.
 * @param {ImageData} imageData
 * @returns {Promise<HTMLImageElement>}
 */
async function imageDataToImage(imageData) {
    const dataUrl = imageDataToCanvas(imageData).toDataURL();
    const img = new Image();
    img.src = dataUrl;
    return img;
}

// Canvas elements can be created from ImageData
/**
 * Create a <canvas> exactly the size of `imageData` and draw it at the origin.
 * If no 2D context is available, the sized but blank canvas is returned.
 * @param {ImageData} imageData
 * @returns {HTMLCanvasElement}
 */
function imageDataToCanvas(imageData) {
    const canvasEl = document.createElement("canvas");
    const context = canvasEl.getContext("2d");
    const { width, height } = imageData;
    Object.assign(canvasEl, { width, height });
    if (!context) {
        return canvasEl;
    }
    context.putImageData(imageData, 0, 0);
    return canvasEl;
}

// Convert the onnx model mask output to an HTMLImageElement
/**
 * Render a raw mask array as an HTMLImageElement (arrayToImageData → canvas → image).
 * @param {ArrayLike<number>} input - mask values
 * @param {number} width - forwarded to arrayToImageData
 * @param {number} height - forwarded to arrayToImageData
 * @returns {Promise<HTMLImageElement>}
 */
async function onnxMaskToImage(input, width, height) {
    const imageData = await arrayToImageData(input, width, height);
    return imageDataToImage(imageData);
}
// ---------------------- Draw on image end -------------------------------------//

// ---------------------- ONNX setup and model setup -------------------------------------//
/**
 * Build the input feed for the SAM ONNX decoder from click prompts.
 *
 * @param {Object} args
 * @param {Array<{x: number, y: number, clickType: number}>} args.clicks - click prompts;
 *   coordinates are multiplied by `modelScale.samScale` and `clickType` is used as the label.
 * @param {*} args.tensor - pre-computed image embedding, passed through unchanged.
 * @param {{height: number, width: number, samScale: number}} args.modelScale
 * @returns {Object|undefined} feed keyed by the decoder's input names, or undefined
 *   when `clicks` is missing.
 */
const modelData = ({ clicks, tensor, modelScale }) => {
    let promptCoords;
    let promptLabels;

    if (clicks) {
        const count = clicks.length;
        // With no box prompt, SAM expects one extra padding point at (0, 0) with label -1,
        // so reserve count + 1 slots. Float32Array starts zeroed, which already places
        // the padding point at the origin.
        const coords = new Float32Array(2 * (count + 1));
        const labels = new Float32Array(count + 1);

        for (let k = 0; k < count; k++) {
            const click = clicks[k];
            coords[2 * k] = click.x * modelScale.samScale;
            coords[2 * k + 1] = click.y * modelScale.samScale;
            labels[k] = click.clickType;
        }
        labels[count] = -1;

        promptCoords = new Tensor("float32", coords, [1, count + 1, 2]);
        promptLabels = new Tensor("float32", labels, [1, count + 1]);
    }

    const origImSize = new Tensor("float32", [
        modelScale.height,
        modelScale.width,
    ]);

    if (promptCoords === undefined || promptLabels === undefined) {
        return undefined;
    }

    // No previous low-res mask is supplied: all-zero 256x256 mask plus has_mask_input = 0.
    return {
        image_embeddings: tensor,
        point_coords: promptCoords,
        point_labels: promptLabels,
        orig_im_size: origImSize,
        mask_input: new Tensor("float32", new Float32Array(256 * 256), [1, 1, 256, 256]),
        has_mask_input: new Tensor("float32", [0]),
    };
};

/**
 * Compute SAM's resize factor for an image: SAM works on inputs whose longer side is 1024 px.
 * @param {HTMLImageElement} image - a loaded image (naturalWidth/naturalHeight are read).
 * @returns {Promise<{height: number, width: number, samScale: number}>}
 */
const getModelScale = async (image) => {
    const LONG_SIDE_LENGTH = 1024;
    const { naturalWidth: width, naturalHeight: height } = image;
    const longestSide = Math.max(height, width);
    return { height, width, samScale: LONG_SIDE_LENGTH / longestSide };
};

// Decode a Numpy file into a tensor. 
/**
 * Fetch a NumPy .npy file and wrap its contents in an onnxruntime Tensor.
 *
 * @param {string} tensorFile - URL of the .npy file, passed to npyjs `load`.
 * @param {string} [dType="float32"] - Tensor element type.
 * @returns {Promise<Tensor>} tensor with the file's data and shape.
 * @throws {Error} if loading or decoding fails; the original error is kept as `cause`.
 */
const loadNpyTensor = async (tensorFile, dType = "float32") => {
    try {
        const npArray = await new npyjs().load(tensorFile);
        return new Tensor(dType, npArray.data, npArray.shape);
    } catch (err) {
        // Previously this rejected with no reason, hiding why the embedding failed to load.
        throw new Error(`Failed to load NumPy tensor from ${tensorFile}`, { cause: err });
    }
};


// Run the Segment Anything ONNX decoder on click prompts using the current image's pre-computed embedding
/**
 * Run the SAM ONNX decoder for a set of click prompts on the current image.
 *
 * @param {InferenceSession} model - loaded decoder session.
 * @param {*} modelOutput - unused; kept for call-site compatibility. The new output is
 *   returned, and the caller is responsible for storing it.
 * @param {Array<{x: number, y: number, clickType: number}>} clicks - click prompts.
 * @param {{modelScale: Object, tensor: *}} imageData - scale info and pre-computed
 *   embedding for the current image.
 * @param {boolean} [isPolygonOutput=true] - unused; kept for call-site compatibility.
 * @returns {Promise<{output: Tensor}|undefined>} the decoder's first output, or undefined
 *   when there were no prompts or inference failed (the error is logged).
 */
const runModel = async (model, modelOutput, clicks, { modelScale, tensor }, isPolygonOutput = true) => {
    try {
        // Prepare the decoder inputs in the format SAM expects.
        const feeds = modelData({ clicks, tensor, modelScale });
        if (feeds === undefined) return undefined;

        const results = await model.run(feeds);
        // Only the first output (the predicted mask) is used.
        return { output: results[model.outputNames[0]] };
    } catch (e) {
        // Best-effort: report the failure and signal it with undefined instead of throwing.
        console.error(e);
        return undefined;
    }
};

/**
 * Build the tiles-server URL that serves the pre-computed SAM embedding (.npy) for an image.
 * Nothing is fetched here; the URL is later handed to loadNpyTensor.
 * @param {string} imageURL - URL of the source image (percent-encoded into the query).
 * @returns {Promise<string>}
 */
const getImageEmbeddings = async (imageURL) => {
    console.log("...getting image embeddings")
    const encodedImage = encodeURIComponent(imageURL);
    const endpoint = `${tilesServer}/get_image_embeddings`;
    return `${endpoint}?imageUrl=${encodedImage}`;
};


/**
 * Create an onnxruntime-web InferenceSession for the given model.
 * Failures are logged and reported by resolving to undefined (this never rejects
 * because of the session creation itself).
 * @param {string} MODEL_DIR - model path/URL passed to InferenceSession.create.
 * @returns {Promise<InferenceSession|undefined>}
 */
const initModel = async (MODEL_DIR) => {
    console.log("entred in intit model!", MODEL_DIR)
    let session;
    try {
        session = await InferenceSession.create(MODEL_DIR);
    } catch (e) {
        console.log(e);
    }
    return session;
};

// ---------------------- ONNX setup and model setup END -------------------------------------//


/**
 * Client around a Segment Anything (SAM) ONNX decoder.
 *
 * Flow:
 *   1. Construct with the decoder URL; loading starts immediately.
 *   2. Call `changeImage(url)` to load the image and its pre-computed embedding.
 *   3. Call `getPrediction(point)` per click; the mask tensor lands in `modelOutput`.
 *   4. Call `drawPolygon(modelOutput)` to turn that mask into a polygon.
 */
class SamModel {
    /**
     * @param {string} modelURL - decoder model path/URL handed to initModel.
     */
    constructor(modelURL) {
        this.model = null;
        this.imageData = {};
        this.maskImg = null; // not read anywhere in this file
        this.modelURL = modelURL;
        this.modelOutput = null;
        // In-flight model load, shared by all callers so nobody races ahead of it.
        this.modelPromise = null;
        // Start loading eagerly. getPrediction awaits the same load, so a failure here
        // is only logged and is retried on the next call.
        this.initModelData().catch((err) => console.error(err));
    }

    /**
     * Load the ONNX session once.
     * Resolves after the load attempt. If it failed (initModel resolves undefined),
     * `this.model` stays falsy and the next call retries.
     */
    initModelData = async () => {
        if (this.model) return;
        if (!this.modelPromise) {
            this.modelPromise = initModel(this.modelURL);
        }
        const pending = this.modelPromise;
        try {
            const session = await pending;
            if (session) this.model = session;
        } finally {
            // Drop a failed load so a later call can retry, unless a retry already started.
            if (!this.model && this.modelPromise === pending) this.modelPromise = null;
        }
    }

    /**
     * Switch to a new image: load it and fetch its SAM embedding.
     * @param {string} imageUrl
     * @returns {Promise<void>} rejects if the image or its embedding cannot be loaded.
     */
    changeImage = async (imageUrl) => {
        await this.initImageEmbeddings(imageUrl);
    }

    /**
     * Load `imageUrl` into an <img>, compute its SAM scale, and load its embedding tensor.
     * On success, `this.imageData` becomes `{ modelScale, tensor, imageUrl, imageHTML }`.
     * @param {string} imageUrl
     * @returns {Promise<void>}
     */
    initImageEmbeddings = (imageUrl) => {

        console.log("entered in init imageembedding")
        return new Promise((resolve, reject) => {

            const imageHTML = document.createElement('img');
            imageHTML.addEventListener('load', async () => {
                try {
                    console.log("loaded in init imageembedding")

                    const modelScale = await getModelScale(imageHTML);
                    const IMAGE_EMBEDDING = await getImageEmbeddings(imageUrl);
                    const tensor = await loadNpyTensor(IMAGE_EMBEDDING);
                    this.imageData = {
                        modelScale,
                        tensor,
                        imageUrl,
                        imageHTML
                    };
                    resolve();
                } catch (err) {
                    reject(err);
                }
            }, { once: true });
            // Without this handler the promise would never settle for a broken or unreachable URL.
            imageHTML.addEventListener('error', () => {
                reject(new Error(`Failed to load image: ${imageUrl}`));
            }, { once: true });
            imageHTML.src = imageUrl;
        });
    }

    /**
     * Run the decoder for one positive click (clickType 1) and store the mask in `modelOutput`.
     * Waits for the model load first. If no model is available, or runModel reports a
     * failure by returning undefined, `modelOutput` keeps its previous value.
     * @param {{x: number, y: number}} point - click position; runModel scales it by samScale.
     * @param {boolean} [isPolygonOutput] - forwarded to runModel.
     */
    getPrediction = async (point, isPolygonOutput) => {
        if (!point) return;

        await this.initModelData();
        if (!this.model) {
            console.warn("SAM model is not loaded; skipping prediction");
            return;
        }

        const prediction = await runModel(this.model, this.modelOutput, [{ ...point, clickType: 1 }], this.imageData, isPolygonOutput);
        if (prediction) this.modelOutput = prediction.output;
    }

    /**
     * Convert a mask tensor into one simplified polygon.
     *
     * The mask is traced with onnxMaskToPolygon, using dims[3] as width and dims[2] as height.
     * Each traced point list is then greedily reordered so that spatially close points are
     * adjacent; this mutates those arrays in place. Each list is simplified with
     * simplifyMultiPolygon, and the piece with the most vertices is returned. Other pieces
     * are treated as stray detections. Note that this picks by vertex count, not by area.
     * NOTE(review): assumes `polygons.geometry.coordinates` is a list of [x, y] point lists;
     * confirm against @antv/sam.
     *
     * @param {{data: ArrayLike<number>, dims: number[]}} output - mask tensor, e.g. `this.modelOutput`.
     * @returns {Promise<number[][]|undefined>} polygon vertices, or undefined if nothing was traced.
     */
    drawPolygon = async (output) => {
        const polygons = onnxMaskToPolygon(output.data, output.dims[3], output.dims[2], 2);

        const simplified = _.map(polygons.geometry.coordinates, (polygon) => {
            // Greedy nearest-neighbour ordering to help simplifyMultiPolygon.
            for (let i = 0; i < polygon.length - 1; i++) {
                const cursor = polygon[i];
                let closestIndex = i + 1;
                let closestDistance = Infinity;

                for (let j = i + 1; j < polygon.length; j++) {
                    const dx = cursor[0] - polygon[j][0];
                    const dy = cursor[1] - polygon[j][1];
                    const distance = Math.sqrt(dx * dx + dy * dy);
                    if (distance < closestDistance) {
                        closestDistance = distance;
                        closestIndex = j;
                    }
                }

                // Move the closest point into the next slot.
                [polygon[i + 1], polygon[closestIndex]] = [polygon[closestIndex], polygon[i + 1]];
            }

            // Settings chosen by trial and error; keep only the points, not the bounding box.
            return simplifyMultiPolygon([_.flatten(polygon)], 1, 1.5)[0];
        });

        return simplified.reduce(
            (best, current) => (current.length > best.length ? current : best),
            simplified[0]
        );
    }

    downloadImage = (image) => { downLoadImage(image) }
}

// Public API of this module: only SamModel is exported; every helper above is module-private.
export {
    SamModel
}