// Remove Background — runs an ONNX segmentation model entirely in the browser
// via onnxruntime-web. Default model is the quantized briaai RMBG-1.4 (~44 MB,
// downloaded once then cached by the browser). Users can paste the URL of any
// compatible ONNX segmentation model (HTTPS only, from an allow-listed host) if
// the default is unreachable.
//
// Pipeline:
//  1. Resize image to MODEL_SIZE × MODEL_SIZE on a canvas.
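//  2. Convert RGBA → planar CHW float32, normalized per channel as
//     (v / 255 - REMOVE_BG_MEAN) / REMOVE_BG_STD; with RMBG-1.4's mean 0.5 /
//     std 1.0 this yields values in [-0.5, 0.5].
//  3. Run ort.InferenceSession.run() and use the first output as the mask tensor.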
//  4. Resize the mask back to original dimensions, multiply into the image's
//     alpha channel, export as PNG.

// onnxruntime-web is pinned to a single version. The loader script and the directory
// the runtime fetches its .wasm binaries from must match, so both derive from one base.
const REMOVE_BG_ORT_VERSION = '1.17.1';
const REMOVE_BG_ORT_BASE  = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${REMOVE_BG_ORT_VERSION}/dist/`;
const REMOVE_BG_ORT_URL   = `${REMOVE_BG_ORT_BASE}ort.min.js`;
const REMOVE_BG_DEFAULT_MODEL = 'https://huggingface.co/briaai/RMBG-1.4/resolve/main/onnx/model_quantized.onnx';
// Side length of the square RGB input the model is fed.
const REMOVE_BG_MODEL_SIZE = 1024;
// RMBG-1.4 normalization: (pixel / 255 - 0.5) / 1.0 per channel, so model inputs
// lie in [-0.5, 0.5] (not [-1, 1]).
const REMOVE_BG_MEAN = 0.5;
const REMOVE_BG_STD  = 1.0;
// localStorage key under which a user-chosen model URL is persisted.
const REMOVE_BG_MODEL_STORAGE_KEY = 'mm-remove-bg-model-url';
// Hosts (or any of their subdomains) a custom model URL may point at.
// Frozen so the allow-list cannot be widened at runtime.
const REMOVE_BG_ALLOWED_HOSTS = Object.freeze(['huggingface.co', 'cdn.jsdelivr.net', 'unpkg.com', 'github.com', 'raw.githubusercontent.com']);

/**
 * Decide whether a user-supplied model URL may be fetched.
 * Accepts only absolute https: URLs whose host is an allow-listed host or a
 * subdomain of one. Anything unparseable (or any other failure) counts as invalid.
 * @param {string} url - candidate model URL
 * @returns {boolean} true when the URL is acceptable
 */
function removeBgIsValidModelUrl(url) {
  try {
    const { protocol, hostname } = new URL(url);
    if (protocol !== 'https:') return false;
    for (const allowed of REMOVE_BG_ALLOWED_HOSTS) {
      // Exact host, or a dotted subdomain — a bare suffix match would let e.g.
      // "evilhuggingface.co" through.
      if (hostname === allowed || hostname.endsWith(`.${allowed}`)) return true;
    }
    return false;
  } catch {
    return false;
  }
}

/**
 * Load onnxruntime-web from the CDN once and return the global `ort` namespace.
 * Concurrent callers share the same in-flight load (window.__ortLoading).
 * On success, points the runtime's .wasm lookup at REMOVE_BG_ORT_BASE (best-effort).
 * On failure, the in-flight promise is cleared and the <script> tag removed, so a
 * later call (e.g. the Retry button) starts a fresh attempt instead of re-awaiting
 * a permanently rejected promise.
 * @returns {Promise<object>} the `ort` namespace
 * @throws {Error} if the script fails to load or does not define window.ort
 */
async function removeBgEnsureOrt() {
  if (window.ort) return window.ort;
  if (window.__ortLoading) return window.__ortLoading;
  window.__ortLoading = new Promise((resolve, reject) => {
    const script = document.createElement('script');
    const fail = (message) => {
      window.__ortLoading = null;
      script.remove();
      reject(new Error(message));
    };
    script.src = REMOVE_BG_ORT_URL;
    script.onload = () => {
      const { ort } = window;
      if (!ort) {
        fail('onnxruntime-web loaded but did not define window.ort');
        return;
      }
      try {
        ort.env.wasm.wasmPaths = REMOVE_BG_ORT_BASE;
      } catch {
        // Best-effort: if env.wasm is unavailable the runtime keeps its default paths.
      }
      resolve(ort);
    };
    script.onerror = () => fail('Failed to load onnxruntime-web');
    document.head.appendChild(script);
  });
  return window.__ortLoading;
}

/**
 * Download a model file, reporting progress when the response is streamable and
 * advertises a Content-Length.
 * @param {string} url - model URL
 * @param {(received: number, total: number) => void} [onProgress] - called after each
 *   chunk; `total` is never smaller than `received`
 * @returns {Promise<Uint8Array>} the complete response body
 * @throws {Error} when the response status is not OK
 */
async function removeBgFetchModel(url, onProgress) {
  const res = await fetch(url);
  if (!res.ok) throw new Error(`Model download failed: ${res.status}`);
  const total = Number.parseInt(res.headers.get('content-length') || '0', 10);
  // No stream or no usable length (missing, 0, NaN): nothing to report progress against.
  if (!res.body || !total) return new Uint8Array(await res.arrayBuffer());

  const reader = res.body.getReader();
  const chunks = [];
  let received = 0;
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    received += value.length;
    // For a compressed response Content-Length is the *encoded* size while the stream
    // yields decoded bytes; never report a total below what has already arrived, so
    // callers never compute more than 100%.
    onProgress?.(received, Math.max(total, received));
  }

  const all = new Uint8Array(received);
  let offset = 0;
  for (const chunk of chunks) {
    all.set(chunk, offset);
    offset += chunk.length;
  }
  return all;
}

// Module-level memo of the most recently created InferenceSession together with
// the model URL it was built from. It lives outside the component, so picking
// another image reuses the loaded model instead of downloading it again.
let removeBgSessionUrl = '';
let removeBgSession = null;

window.TOOL_HANDLERS['remove-bg'] = function RemoveBgTool() {
  const [file, setFile] = React.useState(null);
  const [srcUrl, setSrcUrl] = React.useState('');
  const [outUrl, setOutUrl] = React.useState('');
  const [status, setStatus] = React.useState('');
  const [progress, setProgress] = React.useState(0);
  const [err, setErr] = React.useState('');
  const [busy, setBusy] = React.useState(false);
  const [modelUrl, setModelUrl] = React.useState(() => {
    try {
      const stored = localStorage.getItem(REMOVE_BG_MODEL_STORAGE_KEY);
      return (stored && removeBgIsValidModelUrl(stored)) ? stored : REMOVE_BG_DEFAULT_MODEL;
    } catch { return REMOVE_BG_DEFAULT_MODEL; }
  });

  const ensureSession = async (url) => {
    if (removeBgSession && removeBgSessionUrl === url) return removeBgSession;
    setStatus('Loading ONNX Runtime…');
    const ort = await removeBgEnsureOrt();
    setStatus('Downloading model…');
    const bytes = await removeBgFetchModel(url, (got, total) => {
      setProgress(Math.round((got / total) * 100));
      setStatus(`Downloading model… ${window.fmtBytes(got)} / ${window.fmtBytes(total)}`);
    });
    setStatus('Initializing model…');
    setProgress(0);
    removeBgSession = await ort.InferenceSession.create(bytes, {
      executionProviders: ['wasm'],
    });
    removeBgSessionUrl = url;
    return removeBgSession;
  };

  const run = async (f) => {
    setBusy(true); setErr(''); setOutUrl(''); setProgress(0);
    try {
      const { img } = await window.loadImageFromFile(f);
      const ort = await removeBgEnsureOrt();
      const session = await ensureSession(modelUrl);

      setStatus('Preparing image…');
      // Draw onto a square input canvas.
      const N = REMOVE_BG_MODEL_SIZE;
      const inputCanvas = document.createElement('canvas');
      inputCanvas.width = N; inputCanvas.height = N;
      const ictx = inputCanvas.getContext('2d');
      ictx.drawImage(img, 0, 0, N, N);
      const { data } = ictx.getImageData(0, 0, N, N);

      // Convert to CHW float32 normalized.
      const chw = new Float32Array(3 * N * N);
      const area = N * N;
      for (let i = 0, px = 0; i < data.length; i += 4, px++) {
        chw[0 * area + px] = (data[i + 0] / 255 - REMOVE_BG_MEAN) / REMOVE_BG_STD;
        chw[1 * area + px] = (data[i + 1] / 255 - REMOVE_BG_MEAN) / REMOVE_BG_STD;
        chw[2 * area + px] = (data[i + 2] / 255 - REMOVE_BG_MEAN) / REMOVE_BG_STD;
      }
      const tensor = new ort.Tensor('float32', chw, [1, 3, N, N]);

      setStatus('Running model…');
      const inputName = session.inputNames[0];
      const feeds = { [inputName]: tensor };
      const result = await session.run(feeds);
      const outName = session.outputNames[0];
      const maskRaw = result[outName].data; // Float32Array, 1x1xN×N or N×N

      // Normalize the mask into [0, 1] in case the model returns logits or a wider range.
      let min = Infinity, max = -Infinity;
      for (let i = 0; i < maskRaw.length; i++) {
        const v = maskRaw[i]; if (v < min) min = v; if (v > max) max = v;
      }
      const range = max - min || 1;

      setStatus('Compositing result…');
      // Resize mask to original dimensions and composite against source image.
      const maskCanvas = document.createElement('canvas');
      maskCanvas.width = N; maskCanvas.height = N;
      const mctx = maskCanvas.getContext('2d');
      const maskImg = mctx.createImageData(N, N);
      for (let i = 0, p = 0; i < maskRaw.length; i++, p += 4) {
        const v = Math.round(((maskRaw[i] - min) / range) * 255);
        maskImg.data[p + 0] = 255;
        maskImg.data[p + 1] = 255;
        maskImg.data[p + 2] = 255;
        maskImg.data[p + 3] = v;
      }
      mctx.putImageData(maskImg, 0, 0);

      const outCanvas = document.createElement('canvas');
      outCanvas.width = img.width;
      outCanvas.height = img.height;
      const octx = outCanvas.getContext('2d');
      octx.drawImage(img, 0, 0);
      octx.globalCompositeOperation = 'destination-in';
      octx.drawImage(maskCanvas, 0, 0, img.width, img.height);
      octx.globalCompositeOperation = 'source-over';

      const outBlob = await new Promise((resolve) => outCanvas.toBlob(resolve, 'image/png'));
      setOutUrl(URL.createObjectURL(outBlob));
      setStatus('Done');
    } catch (e) {
      setErr(e.message || String(e));
      setStatus('');
    } finally {
      setBusy(false);
    }
  };

  const handleFile = (f) => {
    if (srcUrl) URL.revokeObjectURL(srcUrl);
    setFile(f);
    setSrcUrl(URL.createObjectURL(f));
    run(f);
  };

  const saveModelUrl = (url) => {
    if (!removeBgIsValidModelUrl(url)) {
      setErr('Invalid model URL. Only HTTPS URLs from HuggingFace, jsDelivr, unpkg, or GitHub are allowed.');
      return;
    }
    setModelUrl(url);
    try { localStorage.setItem(REMOVE_BG_MODEL_STORAGE_KEY, url); } catch {}
    if (removeBgSession) { try { removeBgSession.release?.(); } catch {} }
    removeBgSession = null; removeBgSessionUrl = '';
    if (file) run(file);
  };

  const download = () => {
    if (!outUrl) return;
    const a = document.createElement('a');
    a.href = outUrl;
    a.download = (file?.name || 'image').replace(/\.[^.]+$/, '') + '-nobg.png';
    a.click();
  };
  const reset = () => {
    if (srcUrl) URL.revokeObjectURL(srcUrl);
    if (outUrl) URL.revokeObjectURL(outUrl);
    setFile(null); setSrcUrl(''); setOutUrl(''); setErr(''); setStatus(''); setProgress(0);
  };

  if (!file) return (
    <div>
      <window.Dropzone onFile={handleFile} title="Drop an image here" hint="PNG, JPG, WebP — runs a segmentation model locally" accept="image/*" />
      <div className="cmp-meta" style={{ textAlign: 'center', marginTop: 14 }}>
        First run downloads the model (~44 MB, cached by your browser). Everything runs in your browser — nothing uploads.
      </div>
    </div>
  );

  return (
    <div className="mini-tool">
      <div className="cmp-preview">
        <div className="cmp-side">
          <div className="cmp-ttl">Original</div>
          <img src={srcUrl} alt="" />
        </div>
        <div className="cmp-side after" style={{ background: 'repeating-conic-gradient(var(--id-border) 0% 25%, transparent 0% 50%) 50% / 16px 16px' }}>
          <div className="cmp-ttl">Result</div>
          {busy && (
            <div style={{ padding: 16 }}>
              <div style={{ fontSize: 13, marginBottom: 8 }}>{status || 'Working…'}</div>
              <div className="pw-bar"><div className="pw-fill" style={{ width: progress + '%', background: 'var(--id-brand-blue)' }} /></div>
            </div>
          )}
          {!busy && outUrl && <img src={outUrl} alt="" />}
          {!busy && err && (
            <div style={{ padding: 16 }}>
              <window.ToolError error={err} hint="Check your network or try a different model URL below." onRetry={() => file && run(file)} />
            </div>
          )}
        </div>
      </div>

      <details style={{ marginTop: 14 }}>
        <summary className="cmp-meta" style={{ cursor: 'pointer' }}>Model URL (advanced)</summary>
        <div style={{ marginTop: 8, display: 'flex', gap: 8 }}>
          <input
            className="mini-input"
            style={{ flex: 1 }}
            defaultValue={modelUrl}
            onBlur={(e) => {
              const v = e.target.value.trim();
              if (v && v !== modelUrl) saveModelUrl(v);
            }}
          />
          <button className="btn btn-secondary" onClick={() => saveModelUrl(REMOVE_BG_DEFAULT_MODEL)}>Reset</button>
        </div>
        <div className="cmp-meta" style={{ marginTop: 6 }}>
          Any ONNX segmentation model with a 3×{REMOVE_BG_MODEL_SIZE}×{REMOVE_BG_MODEL_SIZE} input and single-channel mask output works.
        </div>
      </details>

      <div className="cmp-actions">
        <button className="btn btn-secondary" onClick={reset} disabled={busy}>
          <window.Icon name="upload" size={16} /> Another image
        </button>
        <button className="btn btn-primary" onClick={download} disabled={!outUrl || busy}>
          <window.Icon name="download" size={16} /> Download PNG
        </button>
      </div>
    </div>
  );
};
