// Upscale Image — runs a Real-ESRGAN-style ONNX model in the browser to
// upscale images (default 4×). Large images are processed in overlapping
// tiles so we don't blow up GPU/WASM memory.

// CDN script for onnxruntime-web (pinned to 1.17.1). UPSCALE_ORT_BASE is the same
// dist/ directory; upscaleEnsureOrt assigns it to ort.env.wasm.wasmPaths.
const UPSCALE_ORT_URL     = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/ort.min.js';
const UPSCALE_ORT_BASE    = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.1/dist/';
// Model used when the user has not stored a custom URL, and the upscale factor
// the tool assumes by default (the default model is an x4 Real-ESRGAN export).
const UPSCALE_DEFAULT_MODEL = 'https://huggingface.co/Xenova/real-esrgan-x4/resolve/main/onnx/model.onnx';
const UPSCALE_DEFAULT_SCALE = 4;
const UPSCALE_TILE = 128;       // tile side (input pixels)
const UPSCALE_OVERLAP = 16;     // overlap between neighbouring tiles (input pixels)
// localStorage key under which a user-chosen model URL is persisted.
const UPSCALE_MODEL_STORAGE = 'mm-upscale-model';
// Hosts (or any of their subdomains) accepted for HTTPS model URLs by
// upscaleIsValidModelUrl.
// NOTE(review): several of these hosts serve user-uploaded files, so this list
// restricts where a model may be fetched from, not what the model contains.
const UPSCALE_ALLOWED_HOSTS = ['huggingface.co', 'cdn.jsdelivr.net', 'unpkg.com', 'github.com', 'raw.githubusercontent.com'];

/**
 * Checks whether `url` may be used as a model download location.
 *
 * Accepts only parseable `https:` URLs whose hostname is one of
 * UPSCALE_ALLOWED_HOSTS or a subdomain of one (e.g. `cdn-lfs.huggingface.co`).
 * Anything unparseable yields `false` rather than throwing.
 *
 * @param {string} url - candidate model URL.
 * @returns {boolean} true when the URL is allowed.
 */
function upscaleIsValidModelUrl(url) {
  let parsed;
  try {
    parsed = new URL(url);
  } catch {
    return false;
  }
  if (parsed.protocol !== 'https:') return false;

  const host = parsed.hostname;
  for (const allowed of UPSCALE_ALLOWED_HOSTS) {
    if (host === allowed || host.endsWith(`.${allowed}`)) return true;
  }
  return false;
}

/**
 * Lazily loads onnxruntime-web from UPSCALE_ORT_URL as a classic <script> and
 * returns the `window.ort` namespace it defines.
 *
 * Concurrent callers share a single in-flight load through
 * `window.__ortLoading`. If the load fails (network error, or the script ran
 * but did not define `window.ort`), the shared promise is cleared and the
 * failed <script> removed, so a later call — e.g. the tool's Retry button —
 * starts a fresh load instead of re-receiving the same cached rejection.
 *
 * @returns {Promise<object>} the onnxruntime-web namespace.
 * @throws {Error} when the runtime script cannot be loaded.
 */
async function upscaleEnsureOrt() {
  if (window.ort) return window.ort;
  if (window.__ortLoading) return window.__ortLoading;
  window.__ortLoading = new Promise((resolve, reject) => {
    const s = document.createElement('script');
    // Forget this attempt entirely so the next call can retry.
    const fail = (message) => {
      window.__ortLoading = null;
      s.remove();
      reject(new Error(message));
    };
    s.src = UPSCALE_ORT_URL;
    s.onload = () => {
      const ort = window.ort;
      if (!ort) {
        fail('onnxruntime-web loaded but did not define window.ort');
        return;
      }
      // Best-effort: point the runtime at the CDN copy of its .wasm files. If
      // this shape ever changes, session creation reports the real problem.
      try { ort.env.wasm.wasmPaths = UPSCALE_ORT_BASE; } catch {}
      resolve(ort);
    };
    s.onerror = () => fail('Failed to load onnxruntime-web');
    document.head.appendChild(s);
  });
  return window.__ortLoading;
}

/**
 * Downloads the model at `url` into a single Uint8Array.
 *
 * When the response has a readable body and a non-zero `content-length`, the
 * body is streamed chunk by chunk and `onProgress(bytesSoFar, contentLength)`
 * is invoked after each chunk. Otherwise the whole body is read at once and
 * `onProgress` is never called.
 *
 * @param {string} url - model location.
 * @param {(received: number, total: number) => void} [onProgress] - progress hook.
 * @returns {Promise<Uint8Array>} the raw model bytes.
 * @throws {Error} when the HTTP status is not OK.
 */
async function upscaleFetchModel(url, onProgress) {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Model download failed: ${response.status}`);
  }

  const expected = parseInt(response.headers.get('content-length') || '0', 10);
  const canStream = Boolean(response.body) && Boolean(expected);
  if (!canStream) {
    const whole = await response.arrayBuffer();
    return new Uint8Array(whole);
  }

  const parts = [];
  let byteCount = 0;
  const reader = response.body.getReader();
  for (let chunk = await reader.read(); !chunk.done; chunk = await reader.read()) {
    parts.push(chunk.value);
    byteCount += chunk.value.length;
    if (onProgress) onProgress(byteCount, expected);
  }

  // Stitch the streamed chunks into one contiguous buffer.
  const merged = new Uint8Array(byteCount);
  let cursor = 0;
  for (const part of parts) {
    merged.set(part, cursor);
    cursor += part.length;
  }
  return merged;
}

// Module-level cache of the most recently created ONNX InferenceSession and the
// model URL it was built from. Because these live at module scope they outlive
// any single mounting of the tool; a request for the same URL reuses the session.
let upscaleSession = null;
let upscaleSessionUrl = '';

/**
 * Converts RGBA pixel data into a planar (CHW) Float32Array with channel
 * values scaled to [0, 1] — the Real-ESRGAN default preprocessing.
 * The alpha channel is dropped.
 *
 * @param {{data: ArrayLike<number>, width: number, height: number}} imageData
 *        RGBA bytes, 4 per pixel, row-major (e.g. a canvas ImageData).
 * @returns {Float32Array} length 3 * width * height: all R, then all G, then all B.
 */
function upscaleTileToTensor(imageData) {
  const { data, width, height } = imageData;
  const planeSize = width * height;
  const tensor = new Float32Array(planeSize * 3);
  const [rPlane, gPlane, bPlane] = [0, planeSize, planeSize * 2];
  for (let pixel = 0; pixel * 4 < data.length; pixel++) {
    const src = pixel * 4;
    tensor[rPlane + pixel] = data[src] / 255;
    tensor[gPlane + pixel] = data[src + 1] / 255;
    tensor[bPlane + pixel] = data[src + 2] / 255;
  }
  return tensor;
}

/**
 * Converts a planar (CHW) float tensor with values nominally in [0, 1] back
 * into an opaque RGBA ImageData of the given size.
 *
 * Values are scaled by 255 and rounded; storing into a Uint8ClampedArray then
 * clamps anything outside [0, 255] (and turns NaN into 0). Alpha is always 255.
 *
 * @param {ArrayLike<number>} chw - all R, then all G, then all B values.
 * @param {number} width
 * @param {number} height
 * @returns {ImageData}
 */
function upscaleTensorToImageData(chw, width, height) {
  const planeSize = width * height;
  const rgba = new Uint8ClampedArray(planeSize * 4);
  for (let pixel = 0; pixel < planeSize; pixel++) {
    const dst = pixel * 4;
    for (let channel = 0; channel < 3; channel++) {
      rgba[dst + channel] = Math.round(chw[channel * planeSize + pixel] * 255);
    }
    rgba[dst + 3] = 255;
  }
  return new ImageData(rgba, width, height);
}

window.TOOL_HANDLERS['upscale-image'] = function UpscaleImageTool() {
  const [file, setFile] = React.useState(null);
  const [srcUrl, setSrcUrl] = React.useState('');
  const [outUrl, setOutUrl] = React.useState('');
  const [outDims, setOutDims] = React.useState({ w: 0, h: 0 });
  const [srcDims, setSrcDims] = React.useState({ w: 0, h: 0 });
  const [status, setStatus] = React.useState('');
  const [progress, setProgress] = React.useState(0);
  const [err, setErr] = React.useState('');
  const [busy, setBusy] = React.useState(false);
  const [modelUrl, setModelUrl] = React.useState(() => {
    try {
      const stored = localStorage.getItem(UPSCALE_MODEL_STORAGE);
      return (stored && upscaleIsValidModelUrl(stored)) ? stored : UPSCALE_DEFAULT_MODEL;
    } catch { return UPSCALE_DEFAULT_MODEL; }
  });
  const [scale] = React.useState(UPSCALE_DEFAULT_SCALE);

  const ensureSession = async (url) => {
    if (upscaleSession && upscaleSessionUrl === url) return upscaleSession;
    setStatus('Loading ONNX Runtime…');
    const ort = await upscaleEnsureOrt();
    setStatus('Downloading model…');
    const bytes = await upscaleFetchModel(url, (got, total) => {
      setProgress(Math.round((got / total) * 100));
      setStatus(`Downloading model… ${window.fmtBytes(got)} / ${window.fmtBytes(total)}`);
    });
    setStatus('Initializing model…');
    setProgress(0);
    upscaleSession = await ort.InferenceSession.create(bytes, {
      executionProviders: ['wasm'],
    });
    upscaleSessionUrl = url;
    return upscaleSession;
  };

  const run = async (f) => {
    setBusy(true); setErr(''); setOutUrl(''); setProgress(0);
    try {
      const { img } = await window.loadImageFromFile(f);
      setSrcDims({ w: img.width, h: img.height });

      const ort = await upscaleEnsureOrt();
      const session = await ensureSession(modelUrl);

      // Draw to an input canvas for pixel access.
      const srcCanvas = document.createElement('canvas');
      srcCanvas.width = img.width;
      srcCanvas.height = img.height;
      srcCanvas.getContext('2d').drawImage(img, 0, 0);

      const outW = img.width * scale;
      const outH = img.height * scale;
      const outCanvas = document.createElement('canvas');
      outCanvas.width = outW;
      outCanvas.height = outH;
      const octx = outCanvas.getContext('2d');

      const inputName = session.inputNames[0];
      const step = UPSCALE_TILE - UPSCALE_OVERLAP;
      const tilesX = Math.max(1, Math.ceil((img.width - UPSCALE_OVERLAP) / step));
      const tilesY = Math.max(1, Math.ceil((img.height - UPSCALE_OVERLAP) / step));
      const totalTiles = tilesX * tilesY;
      let done = 0;

      for (let ty = 0; ty < tilesY; ty++) {
        for (let tx = 0; tx < tilesX; tx++) {
          const sx = Math.min(tx * step, Math.max(0, img.width - UPSCALE_TILE));
          const sy = Math.min(ty * step, Math.max(0, img.height - UPSCALE_TILE));
          const sw = Math.min(UPSCALE_TILE, img.width - sx);
          const sh = Math.min(UPSCALE_TILE, img.height - sy);

          const tileCanvas = document.createElement('canvas');
          tileCanvas.width = UPSCALE_TILE;
          tileCanvas.height = UPSCALE_TILE;
          const tctx = tileCanvas.getContext('2d');
          // Pad by edge-replicating to UPSCALE_TILE — models need consistent input size.
          tctx.drawImage(srcCanvas, sx, sy, sw, sh, 0, 0, sw, sh);
          if (sw < UPSCALE_TILE || sh < UPSCALE_TILE) {
            // Replicate the last row/column out to fill the tile so the model
            // doesn't see a black edge.
            if (sw < UPSCALE_TILE) {
              tctx.drawImage(tileCanvas, sw - 1, 0, 1, sh, sw, 0, UPSCALE_TILE - sw, sh);
            }
            if (sh < UPSCALE_TILE) {
              tctx.drawImage(tileCanvas, 0, sh - 1, UPSCALE_TILE, 1, 0, sh, UPSCALE_TILE, UPSCALE_TILE - sh);
            }
          }
          const tileData = tctx.getImageData(0, 0, UPSCALE_TILE, UPSCALE_TILE);
          const chw = upscaleTileToTensor(tileData);
          const tensor = new ort.Tensor('float32', chw, [1, 3, UPSCALE_TILE, UPSCALE_TILE]);

          setStatus(`Upscaling tile ${done + 1} / ${totalTiles}…`);
          const result = await session.run({ [inputName]: tensor });
          const outTensor = result[session.outputNames[0]];
          const [, , tileOutH, tileOutW] = outTensor.dims;
          const outImage = upscaleTensorToImageData(outTensor.data, tileOutW, tileOutH);

          const tmp = document.createElement('canvas');
          tmp.width = tileOutW; tmp.height = tileOutH;
          tmp.getContext('2d').putImageData(outImage, 0, 0);

          // Destination region (keep only the valid scaled area).
          const dx = sx * scale;
          const dy = sy * scale;
          const dw = sw * scale;
          const dh = sh * scale;
          octx.drawImage(tmp, 0, 0, sw * scale, sh * scale, dx, dy, dw, dh);

          done++;
          setProgress(Math.round((done / totalTiles) * 100));
        }
      }

      const blob = await new Promise((resolve) => outCanvas.toBlob(resolve, 'image/png'));
      setOutUrl(URL.createObjectURL(blob));
      setOutDims({ w: outW, h: outH });
      setStatus('Done');
    } catch (e) {
      setErr(e.message || String(e));
      setStatus('');
    } finally {
      setBusy(false);
    }
  };

  const handleFile = (f) => {
    if (srcUrl) URL.revokeObjectURL(srcUrl);
    if (outUrl) URL.revokeObjectURL(outUrl);
    setFile(f);
    setSrcUrl(URL.createObjectURL(f));
    setOutUrl('');
    run(f);
  };
  const saveModelUrl = (url) => {
    if (!upscaleIsValidModelUrl(url)) {
      setErr('Invalid model URL. Only HTTPS URLs from HuggingFace, jsDelivr, unpkg, or GitHub are allowed.');
      return;
    }
    setModelUrl(url);
    try { localStorage.setItem(UPSCALE_MODEL_STORAGE, url); } catch {}
    if (upscaleSession) { try { upscaleSession.release?.(); } catch {} }
    upscaleSession = null; upscaleSessionUrl = '';
    if (file) run(file);
  };
  const download = () => {
    if (!outUrl) return;
    const a = document.createElement('a');
    a.href = outUrl;
    a.download = (file?.name || 'image').replace(/\.[^.]+$/, '') + `-${scale}x.png`;
    a.click();
  };
  const reset = () => {
    if (srcUrl) URL.revokeObjectURL(srcUrl);
    if (outUrl) URL.revokeObjectURL(outUrl);
    setFile(null); setSrcUrl(''); setOutUrl('');
    setErr(''); setStatus(''); setProgress(0);
  };

  if (!file) return (
    <div>
      <window.Dropzone onFile={handleFile} title="Drop an image here" hint={`PNG, JPG — upscales ${scale}× with Real-ESRGAN`} accept="image/*" />
      <div className="cmp-meta" style={{ textAlign: 'center', marginTop: 14 }}>
        First run downloads the model (~65 MB, cached by your browser). Large images are tiled — expect several seconds per megapixel.
      </div>
    </div>
  );

  return (
    <div className="mini-tool">
      <div className="cmp-preview">
        <div className="cmp-side">
          <div className="cmp-ttl">Original · {srcDims.w}×{srcDims.h}</div>
          <img src={srcUrl} alt="" />
        </div>
        <div className="cmp-side after">
          <div className="cmp-ttl">Upscaled · {outDims.w}×{outDims.h || '…'}</div>
          {busy && (
            <div style={{ padding: 16 }}>
              <div style={{ fontSize: 13, marginBottom: 8 }}>{status || 'Working…'}</div>
              <div className="pw-bar"><div className="pw-fill" style={{ width: progress + '%', background: 'var(--id-brand-blue)' }} /></div>
              <div className="cmp-meta" style={{ marginTop: 8 }}>{progress}%</div>
            </div>
          )}
          {!busy && outUrl && <img src={outUrl} alt="" />}
          {!busy && err && (
            <div style={{ padding: 16 }}>
              <window.ToolError error={err} hint="Try a different model URL below if this one is unreachable." onRetry={() => file && run(file)} />
            </div>
          )}
        </div>
      </div>

      <details style={{ marginTop: 14 }}>
        <summary className="cmp-meta" style={{ cursor: 'pointer' }}>Model URL (advanced)</summary>
        <div style={{ marginTop: 8, display: 'flex', gap: 8 }}>
          <input className="mini-input" style={{ flex: 1 }}
                 defaultValue={modelUrl}
                 onBlur={(e) => {
                   const v = e.target.value.trim();
                   if (v && v !== modelUrl) saveModelUrl(v);
                 }} />
          <button className="btn btn-secondary" onClick={() => saveModelUrl(UPSCALE_DEFAULT_MODEL)}>Reset</button>
        </div>
        <div className="cmp-meta" style={{ marginTop: 6 }}>
          Any Real-ESRGAN-style ONNX model with 1×3×H×W float32 input (range [0,1]) works.
        </div>
      </details>

      <div className="cmp-actions">
        <button className="btn btn-secondary" onClick={reset} disabled={busy}>
          <window.Icon name="upload" size={16} /> Another image
        </button>
        <button className="btn btn-primary" onClick={download} disabled={!outUrl || busy}>
          <window.Icon name="download" size={16} /> Download PNG
        </button>
      </div>
    </div>
  );
};
