# Copyright (C) 2026 Embedl AB
# mypy: disable-error-code="no-untyped-call"

"""
Deploying SAM3 as an INT8 TensorRT engine
=========================================

This tutorial shows how to deploy the detector branch of ``facebook/sam3``
end-to-end with Embedl Deploy: from the HuggingFace checkpoint to a
quantized (mixed-precision) TensorRT engine running at ~12 QPS with
924×924 input on an NVIDIA L4 GPU.

Install dependencies:

.. code-block:: bash

    pip install torch transformers onnx opencv-python matplotlib onnxscript
    pip install --upgrade tensorrt-cu12
    pip install "embedl-deploy[tensorrt]"

Pipeline:

1. Export the HF SAM3 detector at 924×924 in ``fp32`` with ``torch.export`` as
   a ``pt2`` file (ATen graph).

2. Apply layer fusions + post-training quantization (PTQ) with mixed precision
   using ``embedl_deploy``, calibrating on real frames from a demo video.
   For better accuracy, calibrate on a larger dataset instead, e.g.,
   `COCO <https://cocodataset.org/>`_ or
   `LVIS <https://www.lvisdataset.org/>`_.

3. Export the quantized model to ONNX.

4. Build a mixed-precision TensorRT engine and run it on the demo video.

The output is a PyTorch model with fake quantization operators, i.e.,
Quantize/Dequantize (QDQ) nodes that can be exported to ONNX for compilation
with TensorRT. Once the quantized model is exported, a TensorRT engine can be
built with mixed precision to run inference.

.. note::
    The TensorRT engine build and the video demo (with its latency
    measurements) require an NVIDIA GPU with TensorRT 10.x (10.16 pip wheel
    recommended). Steps 1–3 (``fp32`` export, INT8 PTQ, ONNX export) also
    run on CPU, but slowly.
"""

# sphinx_gallery_start_ignore
# pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports,useless-suppression,no-member,import-error
# sphinx_gallery_end_ignore

# %%
# Constants
# ---------

import sys
import time
from pathlib import Path

import cv2  # type: ignore[import-not-found]
import numpy as np  # type: ignore[import-not-found]
import torch
from torch import nn

# Raise the recursion limit; presumably needed when exporting/serialising
# the large SAM3 graph — NOTE(review): confirm which step requires it.
sys.setrecursionlimit(20000)

# Every intermediate artifact (.pt2, .onnx, engine, timing cache) goes here.
ARTIFACTS_PATH = Path("artifacts")
ARTIFACTS_PATH.mkdir(parents=True, exist_ok=True)

MODEL_ID = "facebook/sam3"  # HuggingFace checkpoint id
IMAGE_SIZE = 924  # multiple of patch size (14); triggers fusion patterns
PATCH_SIZE = 14
# Prompt token ids are padded/truncated to this length; it is also the
# sequence length of the exported ``tokenized_text`` input.
CONTEXT_LENGTH = 32
PROMPT = "person"  # text prompt tokenized for the detector
N_CALIB = 16  # frames used for PTQ calibration
CONFIDENCE = 0.25  # sigmoid-score threshold for keeping a detection
# Passed as ``config.builder_optimization_level`` to the TensorRT builder.
BUILDER_OPTIMIZATION_LEVEL = 3

# Demo video used both for PTQ calibration and for the inference demo.
VIDEO_URL = (
    "https://huggingface.co/datasets/hf-internal-testing/"
    "sam2-fixtures/resolve/main/bedroom.mp4"
)

# Normalisation — must match the SAM3 training-time preprocessing.
# The torch versions broadcast over CHW tensors, the numpy versions over
# HWC arrays.
MEAN = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
STD = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
MEAN_NP = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD_NP = np.array([0.5, 0.5, 0.5], dtype=np.float32)

# Artifact paths, one per pipeline stage.
FP32_PT2 = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}.pt2"
QDQ_PT2 = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq.pt2"
QDQ_ONNX = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq.onnx"
# NOTE(review): not referenced anywhere else in this script.
QDQ_ONNX_FIXED = (
    ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq_fixed.onnx"
)
ENGINE = ARTIFACTS_PATH / "sam3.engine"
TIMING_CACHE = ARTIFACTS_PATH / "trt_timing.cache"
# CUDA when available; export and PTQ fall back to CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %%
# Step 1 — fp32 ``torch.export`` of the SAM3 detector
# ---------------------------------------------------
#
# We override the image-size fields in the model config so that the encoder
# runs at 924×924 (a multiple of the 14-px patch size), which lets the
# intended fusion patterns fire. The detector model is wrapped in a thin
# ``nn.Module`` exposing a clean ``(image, tokens) → (masks, logits, boxes)``
# signature, since ``torch.export`` needs plain tensor inputs and outputs.
#
# Everything stays ``fp32``: the ViT backbone has internal ``fp32`` regions
# (attention softmax) that emit ``aten._to_copy(dtype=float32)`` nodes
# when surrounded by ``fp16``, which then conflict with fused conv/BN
# weights. Precision is only lowered when the TensorRT engine is built.

from transformers import (  # type: ignore[import-not-found]
    AutoConfig,
    AutoTokenizer,
    Sam3VideoModel,
)
from transformers.video_utils import (  # type: ignore[import-not-found]
    load_video,
)


def patch_image_size(cfg: AutoConfig, image_size: int) -> AutoConfig:
    """Set every image-size-dependent field of a SAM3 config, in place.

    With ``image_size // PATCH_SIZE`` patches per side, the top-level config,
    the detector and tracker sub-configs and their vision backbones are all
    updated. The backbone feature sizes become a three-level pyramid of
    4×, 2× and 1× the patch grid. The same (mutated) ``cfg`` is returned.
    """
    per_side = image_size // PATCH_SIZE
    # One list object shared by both branches, highest resolution first.
    pyramid = [[per_side * k, per_side * k] for k in (4, 2, 1)]

    cfg.image_size = image_size
    cfg.low_res_mask_size = per_side * 4
    for branch in (cfg.detector_config, cfg.tracker_config):
        branch.image_size = image_size
        vision = branch.vision_config
        vision.backbone_feature_sizes = pyramid
        vision.backbone_config.image_size = image_size
    cfg.tracker_config.memory_attention_rope_feat_sizes = [per_side] * 2
    return cfg


class Sam3DetectorWrapper(nn.Module):
    """Tensors-in / tensors-out adapter around the SAM3 detector.

    ``torch.export`` is handed positional ``(image, tokenized_text)``
    tensors and receives a plain ``(pred_masks, pred_logits, pred_boxes)``
    tuple. The wrapped detector is instead called with keyword arguments
    and returns an object whose attributes are unpacked here.
    """

    def __init__(self, detector: nn.Module) -> None:
        """Register ``detector`` as the ``detector`` sub-module."""
        super().__init__()
        self.detector = detector

    def forward(
        self, image: torch.Tensor, tokenized_text: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Call the detector and unpack its masks, logits and boxes."""
        result = self.detector(input_ids=tokenized_text, pixel_values=image)
        masks = result.pred_masks
        logits = result.pred_logits
        boxes = result.pred_boxes
        return masks, logits, boxes


def export_sam3_fp32_pt2(path: Path) -> None:
    """Trace the SAM3 detector with ``torch.export`` and save it as ``.pt2``.

    The checkpoint is loaded in ``float32`` with its config re-sized to
    ``IMAGE_SIZE`` and wrapped in ``Sam3DetectorWrapper``. It is then traced
    non-strictly on a random image and random token ids placed on
    ``DEVICE``.

    Args:
        path: Destination of the exported program; its size in GB is printed.
    """
    base_config = AutoConfig.from_pretrained(MODEL_ID)
    config = patch_image_size(base_config, IMAGE_SIZE)
    video_model = Sam3VideoModel.from_pretrained(
        MODEL_ID, config=config, torch_dtype=torch.float32
    )
    video_model.eval()
    video_model.to(DEVICE)

    wrapper = Sam3DetectorWrapper(video_model.detector_model)
    example_image = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    example_ids = torch.randint(
        0, 32000, (1, CONTEXT_LENGTH), dtype=torch.long, device=DEVICE
    )
    example_inputs = (example_image, example_ids)
    with torch.no_grad():
        program = torch.export.export(wrapper, example_inputs, strict=False)
    torch.export.save(program, str(path))
    size_gb = path.stat().st_size / 1e9
    print(f"  {size_gb:.2f} GB")

    # Drop every reference to the model and example tensors so the CUDA
    # caching allocator can hand the memory back.
    del video_model, wrapper, example_inputs, program
    del example_image, example_ids
    torch.cuda.empty_cache()


# Run Step 1: writes the fp32 exported program to FP32_PT2.
export_sam3_fp32_pt2(path=FP32_PT2)

# %%
# Step 2 — Fuse and INT8-quantize with Embedl Deploy
# ---------------------------------------------------
#
# This is the main part of the recipe that uses Embedl Deploy to fuse and
# quantize the model. In addition to the standard post-training static
# quantization we apply a few customizations to maximize performance and
# accuracy:
#
# **Mixed precision.** We quantize to INT8 but keep one node out of
# quantization to preserve accuracy: the patch-embedding stem, i.e. the
# first 3-channel convolution with a ≥7×7 kernel.
#
# **SmoothQuant.** We skip SmoothQuant for all ``LayerNorm`` modules in the
# model.

from embedl_deploy import transform
from embedl_deploy._internal.core.quantize.config import (
    CalibrationMethod,
    ModulesToSkip,
    QuantConfig,
    TensorQuantConfig,
)
from embedl_deploy._internal.core.quantize.main import quantize
from embedl_deploy.tensorrt import TENSORRT_PATTERNS


def load_video_frames(n: int) -> list[torch.Tensor]:
    """Return up to ``n`` evenly spaced, preprocessed frames of the demo video.

    Each frame is resized to ``IMAGE_SIZE``×``IMAGE_SIZE``, scaled to
    [0, 1], normalised with ``MEAN``/``STD`` and returned as a CPU
    ``(3, H, W)`` float tensor. When ``n`` is not smaller than the number
    of frames, every frame is returned.
    """
    frames, _ = load_video(VIDEO_URL)
    total = len(frames)
    if n < total:
        picks = np.linspace(0, total - 1, n, dtype=int).tolist()
    else:
        picks = list(range(total))

    def _preprocess(frame: np.ndarray) -> torch.Tensor:
        square = cv2.resize(
            frame, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR
        )
        chw = torch.from_numpy(square).permute(2, 0, 1).float() / 255.0
        return (chw - MEAN) / STD

    return [_preprocess(frames[k]) for k in picks]


def _find_patch_embed_conv(model: nn.Module) -> nn.Conv2d | None:
    # The 3-channel Conv with a ≥7×7 kernel is the patch-embed stem.
    for m in model.modules():
        if (
            isinstance(m, nn.Conv2d)
            and m.in_channels == 3
            and max(m.kernel_size) >= 7
        ):
            return m
    return None


def _build_quant_config(model: nn.Module) -> "QuantConfig":
    """Return the INT8 PTQ config with the patch-embed stem left unquantized.

    Activations use per-tensor and weights per-channel scales, both 8-bit
    symmetric and calibrated with ``CalibrationMethod.MINMAX``. SmoothQuant
    is skipped for every ``nn.LayerNorm``.
    """
    # The stem goes into both the ``stub`` and ``weight`` skip sets to
    # preserve accuracy (mixed precision).
    stem_skip: set[nn.Module] = set()
    if (patch := _find_patch_embed_conv(model)) is not None:
        stem_skip.add(patch)

    return QuantConfig(
        activation=TensorQuantConfig(
            n_bits=8,
            symmetric=True,
            per_channel=False,
            calibration_method=CalibrationMethod.MINMAX,
        ),
        weight=TensorQuantConfig(
            n_bits=8,
            symmetric=True,
            per_channel=True,
            calibration_method=CalibrationMethod.MINMAX,
        ),
        skip=ModulesToSkip(
            stub=stem_skip,  # type: ignore[arg-type]
            weight=stem_skip,  # type: ignore[arg-type]
            smooth={nn.LayerNorm},
        ),
    )


def _tokenize_prompt() -> torch.Tensor:
    """Tokenize ``PROMPT`` padded/truncated to ``CONTEXT_LENGTH``, on DEVICE."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    encoded = tokenizer(
        PROMPT,
        padding="max_length",
        max_length=CONTEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    return encoded["input_ids"].to(device=DEVICE)


def quantize_to_qdq(
    gm: torch.fx.GraphModule, output_path: Path | None = None
) -> None:
    """Fuse and INT8-quantize the fp32 graph with Embedl Deploy.

    The graph is fused with ``TENSORRT_PATTERNS`` and calibrated on
    ``N_CALIB`` demo-video frames paired with the ``PROMPT`` tokens. The
    quantized (QDQ) model is then re-exported with ``torch.export`` and
    saved.

    Args:
        gm: fp32 graph module loaded from the Step-1 ``.pt2`` file.
        output_path: Where to save the QDQ ``.pt2`` program. Defaults to
            ``QDQ_PT2``.
    """
    if output_path is None:
        output_path = QDQ_PT2

    fused = (
        transform(gm, TENSORRT_PATTERNS)
        .model.eval()
        .to(device=DEVICE, dtype=torch.float32)
    )
    quant_cfg = _build_quant_config(fused)
    input_ids = _tokenize_prompt()

    print(f"Loading {N_CALIB} calibration frames from the demo video")
    calib_imgs = load_video_frames(N_CALIB)

    def calib_fn(model_fn: nn.Module) -> None:
        """Run the model on the calibration frames to collect quant stats."""
        model_fn.eval()
        with torch.no_grad():
            for img in calib_imgs:
                model_fn(
                    img.unsqueeze(0).to(device=DEVICE, dtype=torch.float32),
                    input_ids,
                )

    # Example inputs passed to ``quantize`` and reused for the re-export.
    dummy_img = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    qmodel = quantize(
        fused,
        args=(dummy_img, input_ids),
        config=quant_cfg,
        forward_loop=calib_fn,
        freeze_weights=True,
    )

    # Calibration frames are no longer needed; ``dummy_img`` and
    # ``input_ids`` still are, for the export below.
    del calib_imgs
    torch.cuda.empty_cache()

    print(f"Re-exporting QDQ → {output_path.name}")
    qmodel.eval()
    with torch.no_grad():
        exported = torch.export.export(
            qmodel, (dummy_img, input_ids), strict=False
        )
    torch.export.save(exported, str(output_path))
    print(f"  {output_path.stat().st_size / 1e9:.2f} GB")

    # Cleanup GPU memory
    del fused, qmodel, exported, dummy_img, input_ids
    torch.cuda.empty_cache()


# Load the Step-1 program back as a graph module and run Step 2 on it;
# ``quantize_to_qdq`` saves the QDQ program to disk itself.
graph_module: torch.fx.GraphModule = torch.export.load(FP32_PT2).module()
quantize_to_qdq(graph_module)

# Cleanup GPU memory after quantization
del graph_module
torch.cuda.empty_cache()

# %%
# Export quantized model to ONNX
# ------------------------------
#
# We use the dynamo path of ``torch.onnx.export`` (the classic
# ``torchscript`` path calls ``model.train(False)`` which
# ``torch.export``-loaded modules do not support).
# ``Quantize``/``DequantizeLinear`` pairs carry the calibrated scales in
# the resulting ONNX model.


def export_to_onnx(qdq_model_path: Path, onnx_model_path: Path) -> None:
    """Export the QDQ ``.pt2`` program to ONNX with the dynamo exporter.

    The opset version and external-data handling are left at the defaults
    of the installed ``torch.onnx.export``. Random example inputs with the
    deployment shapes are used. The ONNX inputs are named ``image`` and
    ``tokenized_text``; the outputs are ``pred_masks``, ``pred_logits`` and
    ``pred_boxes``.

    Args:
        qdq_model_path: ``.pt2`` file written by the quantization step.
        onnx_model_path: Destination ``.onnx`` file.
    """
    model: torch.fx.GraphModule = torch.export.load(qdq_model_path).module()
    img = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    ids = torch.randint(
        0, 32000, (1, CONTEXT_LENGTH), dtype=torch.long, device=DEVICE
    )
    print(f"Exporting to {onnx_model_path.name}")
    with torch.no_grad():
        # ``dynamo=True``: the TorchScript path calls ``model.train(False)``,
        # which ``torch.export``-loaded modules do not support.
        torch.onnx.export(
            model,
            (img, ids),
            str(onnx_model_path),
            input_names=["image", "tokenized_text"],
            output_names=["pred_masks", "pred_logits", "pred_boxes"],
            do_constant_folding=True,
            dynamo=True,
        )

    # Cleanup GPU memory
    del model, img, ids
    torch.cuda.empty_cache()


# Writes QDQ_ONNX (plus any external-data files the exporter emits).
export_to_onnx(qdq_model_path=QDQ_PT2, onnx_model_path=QDQ_ONNX)

# %%
# Build the TensorRT engine
# -----------------------------------
#
# * INT8 + FP16 mixed precision — TensorRT picks the precision per layer.
# * Shared timing cache file — kernel-timing results persist across
#   runs of this script and subsequent variants of the same ONNX.
#
# This cell and every cell below it requires TensorRT.
import tensorrt as trt


def build_engine(onnx_path: Path, engine_path: Path) -> None:
    """Build a TensorRT INT8+FP16 engine from the QDQ ONNX graph.

    Kernel timings are loaded from and saved back to ``TIMING_CACHE``, so
    rebuilding the same (or a similar) network is faster.

    Args:
        onnx_path: QDQ ONNX model. Its path is also handed to the parser so
            externally stored weights next to it can be found.
        engine_path: Destination of the serialized engine.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed or the build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(logger, "")
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    print(f"Parsing {onnx_path.name}")
    if not parser.parse(onnx_path.read_bytes(), path=str(onnx_path)):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError(f"ONNX parse failed: {onnx_path}")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * 1024**3)
    config.builder_optimization_level = BUILDER_OPTIMIZATION_LEVEL
    # Allow both reduced precisions; the QDQ nodes mark the INT8 regions.
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    cache = config.create_timing_cache(
        TIMING_CACHE.read_bytes() if TIMING_CACHE.exists() else b""
    )
    config.set_timing_cache(cache, ignore_mismatch=False)

    print(
        f"Building INT8+FP16 engine (opt-level {BUILDER_OPTIMIZATION_LEVEL}). "
        "It can take 15-30 min for the first build without a timing cache."
    )
    t0 = time.perf_counter()
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError("build_serialized_network returned None")
    dt = time.perf_counter() - t0

    TIMING_CACHE.write_bytes(bytes(config.get_timing_cache().serialize()))
    # Copy the serialized plan out of TensorRT's host memory only once;
    # the engine may be large.
    plan_bytes = bytes(plan)
    engine_path.write_bytes(plan_bytes)
    print(
        f"  built in {dt:.0f} s → {engine_path.name} "
        f"({len(plan_bytes) / 1e6:.0f} MB)"
    )


# Writes ENGINE and updates the shared TIMING_CACHE file.
build_engine(onnx_path=QDQ_ONNX, engine_path=ENGINE)

# Cleanup GPU memory after engine build (TensorRT may hold resources)
torch.cuda.empty_cache()

# %%
# Run the engine on a video
# ---------------------------------------
# We wrap the TRT engine in a simple class that manages the I/O buffers as
# CUDA tensors and exposes a clean ``infer()`` method. The demo video is the
# same one used for calibration; the predicted masks are overlaid on each
# frame and written to a new video file.

# TensorRT tensor dtype -> torch dtype used to allocate the matching I/O
# buffer; engine I/O of any other dtype is unsupported (the lookup raises
# ``KeyError``).
_TRT_TORCH = dict(
    (
        (trt.float32, torch.float32),
        (trt.float16, torch.float16),
        (trt.int32, torch.int32),
        (trt.int64, torch.int64),
        (trt.int8, torch.int8),
        (trt.bool, torch.bool),
    )
)


class TrtRunner:
    """TRT engine wrapper backed by torch.cuda tensors as I/O buffers.

    One CUDA buffer per engine I/O tensor is allocated up front from the
    engine's tensor shapes and dtypes, and bound to the execution context
    once. ``infer`` copies inputs in, executes on a dedicated CUDA stream
    and returns the outputs as host numpy arrays.
    """

    def __init__(self, engine_path: Path) -> None:
        """Deserialize ``engine_path`` and bind one CUDA buffer per I/O tensor.

        Raises:
            RuntimeError: If TensorRT cannot deserialize the engine.
            KeyError: If an I/O tensor's dtype is missing from ``_TRT_TORCH``.
        """
        self.logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(self.logger, "")
        # Keep the runtime referenced for as long as the engine is in use.
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.runtime.deserialize_cuda_engine(
            engine_path.read_bytes()
        )
        if self.engine is None:
            raise RuntimeError(f"Failed to deserialize engine {engine_path}")
        self.ctx = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()
        self.outputs: list[str] = []
        self.bufs: dict[str, torch.Tensor] = {}
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            # ``torch.empty`` needs a fully static shape (no -1 dims).
            shape = tuple(self.engine.get_tensor_shape(name))
            dtype = _TRT_TORCH[self.engine.get_tensor_dtype(name)]
            self.bufs[name] = torch.empty(shape, dtype=dtype, device="cuda")
            self.ctx.set_tensor_address(name, self.bufs[name].data_ptr())
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.outputs.append(name)

    @property
    def input_size(self) -> int:
        """Spatial size (dim 2) of the engine's ``image`` input."""
        return int(self.engine.get_tensor_shape("image")[2])

    def infer(
        self, feeds: dict[str, np.ndarray | torch.Tensor]
    ) -> dict[str, np.ndarray]:
        """Copy inputs into device buffers, execute, return outputs on host.

        Args:
            feeds: Input name → array/tensor. Each is copied (with dtype
                conversion) into the matching pre-allocated buffer. Inputs
                not listed keep their previous buffer contents.

        Returns:
            Output name → numpy copy of that output buffer.

        Raises:
            RuntimeError: If TensorRT fails to enqueue the execution.
        """
        # Enqueue the input copies on the engine's stream: ``self.stream``
        # is a separate stream, so copies issued on the current stream
        # would not be ordered before the engine reads the buffers.
        with torch.cuda.stream(self.stream):
            for name, x in feeds.items():
                self.bufs[name].copy_(torch.as_tensor(x))
            if not self.ctx.execute_async_v3(self.stream.cuda_stream):
                raise RuntimeError("TensorRT execute_async_v3 failed")
        self.stream.synchronize()
        return {n: self.bufs[n].cpu().numpy() for n in self.outputs}


def _sigmoid_np(x: np.ndarray) -> np.ndarray:
    return 1 / (1 + np.exp(-np.clip(x, -50, 50)))


def _postprocess(
    pred_masks: np.ndarray, pred_logits: np.ndarray, h: int, w: int
) -> np.ndarray:
    """Binarize the masks of every query whose score clears ``CONFIDENCE``.

    Args:
        pred_masks: Mask logits. Batch element 0 is used, indexed by query
            along its first axis.
        pred_logits: Query logits of rank 2 or 3. Only batch element 0 is
            read, and in the rank-3 case only the first entry of the last
            axis.
        h: Output mask height.
        w: Output mask width.

    Returns:
        Boolean array ``(K, h, w)`` with one mask per kept query (``K`` may
        be 0).
    """
    scores = pred_logits[0]
    if pred_logits.ndim == 3:
        scores = scores[:, 0]
    kept = pred_masks[0][_sigmoid_np(scores) > CONFIDENCE]
    if len(kept) == 0:
        return np.zeros((0, h, w), dtype=bool)

    binary_masks = []
    for mask_logits in kept:
        upsampled = cv2.resize(
            mask_logits, (w, h), interpolation=cv2.INTER_LINEAR
        )
        binary_masks.append(_sigmoid_np(upsampled) > 0.5)
    return np.stack(binary_masks)


def _overlay(
    frame_rgb: np.ndarray, masks: np.ndarray, hud: str, colors: np.ndarray
) -> np.ndarray:
    """Return a copy of ``frame_rgb`` with ``masks`` tinted in and a HUD label.

    Each mask is blended 50/50 with a colour from ``colors``, cycled by mask
    index. ``hud`` is drawn in white on a filled black box in the top-left
    corner. ``frame_rgb`` itself is left untouched.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.7, 2

    canvas = frame_rgb.copy().astype(np.uint8)
    n_colors = len(colors)
    for idx, mask in enumerate(masks):
        region = mask.astype(bool)
        tint = colors[idx % n_colors]
        canvas[region] = (canvas[region] * 0.5 + tint * 0.5).astype(np.uint8)

    (text_w, text_h), _ = cv2.getTextSize(hud, font, font_scale, thickness)
    cv2.rectangle(canvas, (6, 6), (14 + text_w, 18 + text_h), (0, 0, 0), -1)
    cv2.putText(
        canvas,
        hud,
        (10, 10 + text_h),
        font,
        font_scale,
        (255, 255, 255),
        thickness,
        cv2.LINE_AA,
    )
    return canvas


def _process_frame(  # pylint: disable=too-many-positional-arguments
    runner: TrtRunner,
    frame: np.ndarray,
    height: int,
    width: int,
    colors: np.ndarray,
    writer: cv2.VideoWriter,
) -> tuple[float, np.ndarray]:
    """Run one RGB video frame through the engine and write its overlay.

    The frame is resized to the engine's input size and normalised with
    ``MEAN_NP``/``STD_NP``. Masks are drawn at ``height``×``width``, and the
    overlay is written to ``writer`` after conversion to BGR.

    Returns:
        The wall time of ``runner.infer`` in milliseconds (host↔device
        copies included) and the RGB overlay frame.
    """
    side = runner.input_size
    net_input = cv2.resize(
        frame, (side, side), interpolation=cv2.INTER_LINEAR
    )
    net_input = (net_input.astype(np.float32) / 255.0 - MEAN_NP) / STD_NP
    chw_batch = np.ascontiguousarray(net_input.transpose(2, 0, 1)[np.newaxis])

    start = time.perf_counter()
    outputs = runner.infer({"image": chw_batch})
    latency_ms = 1000 * (time.perf_counter() - start)

    masks = _postprocess(
        outputs["pred_masks"], outputs["pred_logits"], height, width
    )
    label = (
        f"TensorRT INT8 | {latency_ms:.1f} ms | "
        f"{1000 / latency_ms:.1f} FPS"
    )
    overlay = _overlay(frame, masks, label, colors)
    writer.write(cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    return latency_ms, overlay


def run_engine(  # pylint: disable=too-many-locals
    engine: Path, output_path: Path
) -> tuple[Path, list[np.ndarray]]:
    """Run the engine on every demo-video frame and write an overlay video.

    ``PROMPT`` is tokenized once, with the same settings as for calibration,
    and written into the engine's ``tokenized_text`` input buffer. Each
    frame then only updates the ``image`` input.

    Args:
        engine: Serialized TensorRT engine file.
        output_path: Destination ``.mp4`` for the annotated video.

    Returns:
        ``output_path`` and the overlays of the sampled frames (first, one
        third and two thirds into the video).

    Raises:
        ValueError: If no frames could be decoded from ``VIDEO_URL``.
    """
    frames, _ = load_video(VIDEO_URL)
    if len(frames) == 0:
        raise ValueError(f"No frames decoded from {VIDEO_URL}")
    h, w = frames[0].shape[:2]
    runner = TrtRunner(engine)

    # ``TrtRunner`` allocates every input buffer with ``torch.empty`` and
    # per-frame inference only feeds ``image``. The prompt tokens must be
    # written explicitly, otherwise the engine reads uninitialised ids.
    input_ids = AutoTokenizer.from_pretrained(MODEL_ID)(
        PROMPT,
        padding="max_length",
        max_length=CONTEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )["input_ids"]
    # Write on the runner's stream so these copies are ordered before the
    # warm-up executions enqueued on that same stream.
    with torch.cuda.stream(runner.stream):
        runner.bufs["tokenized_text"].copy_(input_ids)
        runner.bufs["image"].zero_()  # deterministic warm-up input

    for _ in range(3):  # warm-up runs, excluded from the timings below
        runner.ctx.execute_async_v3(runner.stream.cuda_stream)
    runner.stream.synchronize()

    colors = np.random.default_rng(42).integers(64, 256, size=(32, 3))
    writer = cv2.VideoWriter(
        str(output_path), cv2.VideoWriter_fourcc(*"mp4v"), 30, (w, h)
    )
    fwd_ms: list[float] = []
    sample_overlays: list[np.ndarray] = []
    sample_idxs = {0, len(frames) // 3, 2 * len(frames) // 3}

    print(f"Inference on {len(frames)} frames")
    try:
        for i, frame in enumerate(frames):
            fwd, canvas = _process_frame(runner, frame, h, w, colors, writer)
            fwd_ms.append(fwd)
            if i in sample_idxs:
                sample_overlays.append(canvas)
    finally:
        # Close the output file even if a frame fails mid-way.
        writer.release()

    a = np.array(fwd_ms)
    print(
        f"  {len(frames)} frames | mean {a.mean():.1f} ms | "
        f"p95 {np.percentile(a, 95):.1f} ms | {1000 / a.mean():.1f} FPS"
    )
    print(f"  Saved: {output_path}")

    # Cleanup GPU memory and TensorRT resources
    del runner, frames, colors, fwd_ms, a, input_ids
    torch.cuda.empty_cache()

    return output_path, sample_overlays


# Writes the annotated video; ``samples`` holds up to three overlay frames.
video_path, samples = run_engine(ENGINE, ARTIFACTS_PATH / "output_trt.mp4")

# %%
# Summary
# -------
#
# ::
#
#     Pipeline: HuggingFace (fp32 .pt2) ─▶ Embedl Deploy (QDQ .pt2)
#               ─▶ ONNX ─▶ TensorRT engine
#
# On an NVIDIA L4 with TensorRT 10.16:
#
# Throughput: ~12 QPS / ~85 ms latency at 924×924.
