import * as tf from "@tensorflow/tfjs";

/**
 * Converts a user-facing contrast amount into a multiplicative contrast factor.
 *
 * The amount is coerced with `Number(...)` (so numeric strings from untyped
 * callers still work), clamped to [-100, 100], mapped onto [0, 2] via
 * `(100 + amount) / 100`, and then squared, giving a factor in [0, 4].
 * An amount of 0 yields exactly 1. A NaN amount yields NaN.
 *
 * @param contrast - Contrast amount, nominally in [-100, 100].
 * @returns The squared contrast factor.
 */
const getContrast = (contrast: number) => {
    const amount = Math.min(Math.max(Number(contrast), -100), 100);
    const scale = (100.0 + amount) / 100.0;
    return scale * scale;
};

/**
 * Packs the leading `numChannels` bytes of every 4-byte (RGBA) pixel into a
 * dense Int32Array, pixel after pixel, dropping the remaining channels.
 *
 * @param image - Source image. Its `data` is read with a stride of 4 per pixel.
 * @param numChannels - How many leading channels to keep per pixel (e.g. 3 for RGB).
 * @returns A new Int32Array of length `width * height * numChannels`.
 */
const imageByteArray = (image: ImageData, numChannels: number) => {
    const { data: rgba, width, height } = image;
    const pixelCount = width * height;
    const packed = new Int32Array(pixelCount * numChannels);

    let dst = 0;
    for (let src = 0; src < pixelCount * 4; src += 4) {
        for (let c = 0; c < numChannels; c++) {
            packed[dst++] = rgba[src + c];
        }
    }

    return packed;
};

/**
 * Builds an int32 tensor of shape [height, width, numChannels] from an image,
 * keeping only the leading `numChannels` channels of each pixel.
 *
 * @param image - Source RGBA image.
 * @param numChannels - Number of leading channels to keep per pixel.
 * @returns A rank-3 int32 tensor. The caller owns it and must dispose of it, or create it inside `tf.tidy`.
 */
const imageToInput = (image: ImageData, numChannels: number) => {
    const { width, height } = image;
    const shape: [number, number, number] = [height, width, numChannels];
    return tf.tensor3d(imageByteArray(image, numChannels), shape, "int32");
};

/**
 * Applies a contrast factor to a single 0-255 channel value.
 *
 * Values above 150 are first lifted by `contrast / 2` (capped at 255). The
 * value is then stretched around the mid-point 0.5 (in normalized [0, 1] space)
 * by `contrast`, clamped to [0, 255] and floored.
 *
 * NOTE(review): `contrast` is presumably the factor produced by `getContrast`,
 * so the lift is at most 2. Confirm that this small lift is intended.
 *
 * @param r - Channel value, nominally 0-255.
 * @param contrast - Multiplicative contrast factor (1 leaves the stretch step unchanged).
 * @returns The adjusted channel value as an integer in [0, 255].
 */
export const applyPixelContrast = (r: number, contrast: number) => {
    const lifted = r > 150 ? Math.min(r + contrast / 2, 255) : r;
    // Normalize to [0, 1], stretch around 0.5, then scale back to [0, 255].
    const stretched = ((lifted / 255 - 0.5) * contrast + 0.5) * 255;
    const clamped = stretched < 0 ? 0 : stretched > 255 ? 255 : stretched;
    return Math.floor(clamped);
};

/**
 * Applies a contrast adjustment to the R, G and B channels of `imageData`, in place.
 *
 * The alpha channel is left untouched. The passed-in object is mutated and also
 * returned, so identity-keyed lookups (e.g. the Map returned by `aiScan`) keep
 * working on the same instance.
 *
 * @param imageData - RGBA pixel buffer to adjust. It is modified in place.
 * @param contrast - Contrast amount, converted to a factor by `getContrast`
 *   (which clamps it to [-100, 100]). Defaults to 10.
 * @returns The same `imageData` instance that was passed in.
 */
export const applyContrast = (imageData: ImageData, contrast = 10): ImageData => {
    const factor = getContrast(contrast);
    const { data, width, height } = imageData;
    const end = width * height * 4;

    // Each pixel's output depends only on its own RGB values, so reading and
    // writing the same buffer is safe; no snapshot copy is needed. Walk the
    // buffer linearly (row-major) instead of column-by-column for locality.
    for (let pos = 0; pos < end; pos += 4) {
        data[pos] = applyPixelContrast(data[pos], factor);
        data[pos + 1] = applyPixelContrast(data[pos + 1], factor);
        data[pos + 2] = applyPixelContrast(data[pos + 2], factor);
    }
    return imageData;
};

/**
 * Classifies a batch of images with `model`. Each image maps to `"ai"` when its
 * predicted class index is 0, and to `null` otherwise.
 *
 * Only the first three channels of each pixel are fed to the model. All
 * intermediate tensors are released by `tf.tidy`.
 *
 * NOTE(review): the declared value type is `string`, but non-"ai" entries are
 * stored as `null`. Callers should treat a falsy value as "not ai".
 * NOTE(review): assumes every image has the same width/height (tf.stack needs
 * equal shapes) and that the model outputs [batch, numClasses]. Confirm this.
 *
 * @param model - Loaded tf.js model. When falsy, an empty map is returned.
 * @param images - Images to classify. An empty array yields an empty map.
 * @returns Map from each input image (by object identity) to `"ai"` or `null`.
 */
export const aiScan = (model: tf.LayersModel, images: ImageData[]): Map<ImageData, string> => {
    const res = new Map();
    // tf.stack throws on an empty tensor list, so short-circuit an empty batch.
    if (!model || images.length === 0) {
        return res;
    }
    tf.tidy(() => {
        const tensors = images.map(imgData => imageToInput(imgData, 3));
        const pred = model.predict(tf.stack(tensors)) as tf.Tensor;
        (pred.argMax(1).arraySync() as number[]).forEach((classIdx, idx) => {
            res.set(images[idx], classIdx === 0 ? "ai" : null);
        });
    });
    return res;
};
