import {
  SamModel,
  AutoProcessor,
  RawImage,
  Tensor,
} from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.3";

// Reference the elements we will use
const statusLabel = document.getElementById("status");
const fileUpload = document.getElementById("upload");
const imageContainer = document.getElementById("container");
const example = document.getElementById("example");
const uploadButton = document.getElementById("upload-button");
const resetButton = document.getElementById("reset-image");
const clearButton = document.getElementById("clear-points");
const cutButton = document.getElementById("cut-mask");
const starIcon = document.getElementById("star-icon");
const crossIcon = document.getElementById("cross-icon");
const maskCanvas = document.getElementById("mask-output");
const maskContext = maskCanvas.getContext("2d");

const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// State variables
let isEncoding = false;
let isDecoding = false;
let decodePending = false;
let lastPoints = null;
let isMultiMaskMode = false;
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;

async function decode() {
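  // Run the lightweight SAM prompt decoder against the cached image embeddings.
  // The heavy image encoder runs only once per image (see encode()), so this is
  // cheap enough to call on every click and mousemove.
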
  // If a decode is already in progress, queue one more run and return; it will
  // be started as soon as the current decode finishes.
  if (isDecoding) {
    decodePending = true;
    return;
  }
  isDecoding = true;

  // Prepare inputs for decoding
  const reshaped = imageProcessed.reshaped_input_sizes[0];
  const points = lastPoints
    .map((x) => [x.position[0] * reshaped[1], x.position[1] * reshaped[0]])
    .flat(Infinity);
  const labels = lastPoints.map((x) => BigInt(x.label)).flat(Infinity);

  const num_points = lastPoints.length;
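  // SAM expects point prompts with shape [batch, point_batch, num_points, 2] and
  // labels with shape [batch, point_batch, num_points]; label 1 marks a point to
  // include in the mask and label 0 a point to exclude.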
  const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
  const input_labels = new Tensor("int64", labels, [1, 1, num_points]);

  // Generate the mask
  const { pred_masks, iou_scores } = await model({
    ...imageEmbeddings,
    input_points,
    input_labels,
  });

  // Post-process the mask
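  // (removes padding and upscales the low-resolution predictions back to the
  // original image size, then binarizes them)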
  const masks = await processor.post_process_masks(
    pred_masks,
    imageProcessed.original_sizes,
    imageProcessed.reshaped_input_sizes,
  );

  isDecoding = false;

  updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);

  // Check if another decode is pending
  if (decodePending) {
    decodePending = false;
    decode();
  }
}

function updateMaskOverlay(mask, scores) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Allocate buffer for pixel data
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Select best mask
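  // The decoder predicts several candidate masks per prompt (here 3), each with
  // an estimated IoU score; keep the highest-scoring one.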
  const numMasks = scores.length; // 3
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill mask with colour
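  // mask.data stores numMasks values per pixel (interleaved), so the value of
  // the chosen mask at pixel i is mask.data[numMasks * i + bestIndex].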
  const pixelData = imageData.data;
  const numPixels = mask.width * mask.height;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha
    }
  }

  // Draw image data to context
  maskContext.putImageData(imageData, 0, 0);
}

function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;

  // Remove the point marker icons from the previous mask (if any)
  document.querySelectorAll(".icon").forEach((e) => e.remove());

  // Disable cut button
  cutButton.disabled = true;

  // Reset mask canvas
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener("click", clearPointsAndMask);

resetButton.addEventListener("click", () => {
  // Reset the state
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;

  // Clear points and mask (if present)
  clearPointsAndMask();

  // Update UI
  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});

async function encode(url) {
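  // Run the heavy SAM image encoder once for the selected image and cache the
  // resulting embeddings; every subsequent prompt only needs the fast decoder
  // in decode().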
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  imageInput = await RawImage.fromURL(url);

  // Update UI
  imageContainer.style.backgroundImage = `url(${url})`;
  uploadButton.style.display = "none";
  cutButton.disabled = true;

  // Recompute image embeddings
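  // (the processor resizes and normalizes the image, and records the original
  // and reshaped sizes that decode() and post_process_masks() use later)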
  imageProcessed = await processor(imageInput);
  imageEmbeddings = await model.get_image_embeddings(imageProcessed);

  statusLabel.textContent = "Embedding extracted!";
  isEncoding = false;
}

// Handle file selection
fileUpload.addEventListener("change", function (e) {
  const file = e.target.files[0];
  if (!file) return;

  const reader = new FileReader();

  // When the file has been read as a data URL, pass it to encode(); a data URL
  // works both with RawImage.fromURL and as the container's CSS background image.
  reader.onload = (e2) => encode(e2.target.result);

  reader.readAsDataURL(file);
});

example.addEventListener("click", (e) => {
  e.preventDefault();
  encode(EXAMPLE_URL);
});

// Handle clicks on the image container: left click adds a positive point, right click a negative point
imageContainer.addEventListener("mousedown", (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!imageEmbeddings) {
    return; // Ignore if not encoded yet
  }
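  // The first click switches from hover preview to multi-point mode and enables
  // the cut button; further clicks refine the same mask.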
  if (!isMultiMaskMode) {
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // Add a marker icon (star = positive, cross = negative) at the clicked point
  const icon = (point.label === 1 ? starIcon : crossIcon).cloneNode();
  icon.style.left = `${point.position[0] * 100}%`;
  icon.style.top = `${point.position[1] * 100}%`;
  imageContainer.appendChild(icon);

  // Run decode
  decode();
});

// Clamp a value inside a range [min, max]
function clamp(x, min = 0, max = 1) {
  return Math.max(Math.min(x, max), min);
}

function getPoint(e) {
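  // Convert a mouse event into a prompt point: the cursor position as a fraction
  // of the container ([0, 1] on each axis) plus a positive/negative label.
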
  // Get bounding box
  const bb = imageContainer.getBoundingClientRect();

  // Get the mouse coordinates relative to the container
  const mouseX = clamp((e.clientX - bb.left) / bb.width);
  const mouseY = clamp((e.clientY - bb.top) / bb.height);

  return {
    position: [mouseX, mouseY],
    label:
      e.button === 2 // right click
        ? 0 // negative prompt
        : 1, // positive prompt
  };
}

// Do not show context menu on right click
imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());

// Preview the mask as the cursor hovers over the image (single-point mode)
imageContainer.addEventListener("mousemove", (e) => {
  if (!imageEmbeddings || isMultiMaskMode) {
    // Ignore mousemove events if the image is not encoded yet,
    // or we are in multi-mask mode
    return;
  }
  lastPoints = [getPoint(e)];

  decode();
});

// Handle cut button click
cutButton.addEventListener("click", async () => {
  const [w, h] = [maskCanvas.width, maskCanvas.height];
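  // The mask canvas was sized to the original image in updateMaskOverlay, so its
  // pixels line up 1:1 with the pixels of imageInput.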

  // Get the mask pixel data (and use this as a buffer)
  const maskImageData = maskContext.getImageData(0, 0, w, h);

  // Create a new canvas to hold the cut-out
  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");

  // Copy the image pixel data to the cut canvas
  const maskPixelData = maskImageData.data;
  const imagePixelData = imageInput.data;
  for (let i = 0; i < w * h; ++i) {
    const sourceOffset = 3 * i; // RGB
    const targetOffset = 4 * i; // RGBA

    if (maskPixelData[targetOffset + 3] > 0) {
      // Only copy opaque pixels
      for (let j = 0; j < 3; ++j) {
        maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
      }
    }
  }
  cutContext.putImageData(maskImageData, 0, 0);

  // Download image
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = URL.createObjectURL(await cutCanvas.convertToBlob());
  link.click();
  link.remove();
});

const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
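// SlimSAM is a pruned and distilled variant of the Segment Anything Model that is
// small enough to run entirely in the browser. Running on WebGPU with fp16 keeps
// inference interactive; use device: "wasm" to fall back to CPU execution if
// WebGPU is unavailable.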
const model = await SamModel.from_pretrained(model_id, {
  dtype: "fp16", // or "fp32"
  device: "webgpu",
});
const processor = await AutoProcessor.from_pretrained(model_id);
statusLabel.textContent = "Ready";

// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";