Note
Go to the end to download the full example code.
Deploying SAM3 as an INT8 TensorRT engine#
This tutorial shows how to deploy the detector branch of facebook/sam3
end-to-end with Embedl Deploy: from the HuggingFace checkpoint to a
quantized (mixed-precision) TensorRT engine running at 12 QPS at 924×924
input on an NVIDIA GPU (L4).
Install dependencies:
pip install torch transformers onnx opencv-python matplotlib onnxscript
pip install --upgrade tensorrt-cu12
pip install "embedl-deploy[tensorrt]"
Pipeline:
Export the HF SAM3 detector at 924×924 in
fp32 with torch.export as a pt2 file (ATen graph).
Apply layer fusions + post-training quantization (PTQ) with mixed precision using
embedl_deploy, calibrating on real frames from a demo video. For better accuracy, calibrate on a larger dataset, e.g., COCO or LVIS.
The output is a PyTorch model with fake quantization operators, i.e., Quantize/Dequantize (QDQ) nodes that can be exported to ONNX for compilation with TensorRT. Once the quantized model is exported, a TensorRT engine can be built with mixed precision to run inference.
Note
The full pipeline through the TensorRT engine build, sanity check,
benchmark and video demo requires an NVIDIA GPU with TensorRT 10.x
(10.16 pip wheel recommended). Steps 1–3 (fp32 export, INT8 PTQ,
ONNX export) and the QDQ-vs-fp32 sanity check run on CPU too but
slowly.
Constants#
import sys
import time
from pathlib import Path
import cv2 # type: ignore[import-not-found]
import numpy as np # type: ignore[import-not-found]
import torch
from torch import nn
# Raise the interpreter recursion limit well above the default.
# NOTE(review): presumably needed by the deep graph traversals during
# export/quantization — confirm before lowering.
sys.setrecursionlimit(20000)
# Every generated file (pt2, ONNX, engine, timing cache, output video) is
# written under this directory; it is created on import if missing.
ARTIFACTS_PATH = Path("artifacts")
ARTIFACTS_PATH.mkdir(parents=True, exist_ok=True)
# HuggingFace checkpoint used for the config, the weights and the tokenizer.
MODEL_ID = "facebook/sam3"
IMAGE_SIZE = 924 # multiple of patch size (14); triggers fusion patterns
# Patch edge in pixels; IMAGE_SIZE // PATCH_SIZE is the token grid per side.
PATCH_SIZE = 14
# Fixed token length the text prompt is padded/truncated to.
CONTEXT_LENGTH = 32
# Text prompt tokenized for calibration.
PROMPT = "person"
N_CALIB = 16 # frames used for PTQ calibration
# Sigmoid-score threshold above which a detection is kept in post-processing.
CONFIDENCE = 0.25
# Value assigned to the TensorRT builder config's builder_optimization_level.
BUILDER_OPTIMIZATION_LEVEL = 3
# Demo clip used both for calibration frames and for the final video run.
VIDEO_URL = (
"https://huggingface.co/datasets/hf-internal-testing/"
"sam2-fixtures/resolve/main/bedroom.mp4"
)
# Normalisation — must match the SAM3 training-time preprocessing.
# Torch versions are shaped (3, 1, 1) to broadcast over CHW tensors.
MEAN = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
STD = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
# NumPy versions are shaped (3,) to broadcast over HWC arrays.
MEAN_NP = np.array([0.5, 0.5, 0.5], dtype=np.float32)
STD_NP = np.array([0.5, 0.5, 0.5], dtype=np.float32)
# Artifact paths, in the order the pipeline produces them.
FP32_PT2 = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}.pt2"
QDQ_PT2 = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq.pt2"
QDQ_ONNX = ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq.onnx"
# NOTE(review): QDQ_ONNX_FIXED is not referenced anywhere in the visible part
# of this script — confirm whether a graph-fixing step was dropped.
QDQ_ONNX_FIXED = (
ARTIFACTS_PATH / f"sam3_resized_{IMAGE_SIZE}_int8_qdq_fixed.onnx"
)
ENGINE = ARTIFACTS_PATH / "sam3.engine"
TIMING_CACHE = ARTIFACTS_PATH / "trt_timing.cache"
# Use the GPU when available; otherwise everything runs on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Step 1 — fp32 torch.export of the SAM3 detector#
We override the image-size fields in the model config so that the encoder runs
at 924×924 (a multiple of the 14-px patch size), which lets the intended fusion patterns fire.
The detector model is wrapped in a thin nn.Module exposing a clean
(image, tokens) → (masks, logits, boxes) signature for export since we
require Torch tensors.
Everything stays fp32: the ViT backbone has internal fp32 regions
(attention softmax) that emit aten._to_copy(dtype=float32) nodes
when surrounded by fp16, which then conflict with fused conv/BN
weights. The precision drop happens at the TRT-engine building stage.
from transformers import ( # type: ignore[import-not-found]
AutoConfig,
AutoTokenizer,
Sam3VideoModel,
)
from transformers.video_utils import ( # type: ignore[import-not-found]
load_video,
)
def patch_image_size(cfg: AutoConfig, image_size: int) -> AutoConfig:
    """Rewrite the resolution-dependent fields of ``cfg`` for ``image_size``.

    The token grid per side is ``image_size // PATCH_SIZE``. The top-level
    config, the detector sub-config and the tracker sub-config are all
    updated in place so they agree on the new resolution.

    Args:
        cfg: Model config to mutate.
        image_size: Square input edge length in pixels.

    Returns:
        The same ``cfg`` object, after mutation.
    """
    tokens_per_side = image_size // PATCH_SIZE
    # Three feature levels at 4x, 2x and 1x the token grid.
    pyramid = [
        [tokens_per_side * factor, tokens_per_side * factor]
        for factor in (4, 2, 1)
    ]

    cfg.image_size = image_size
    cfg.low_res_mask_size = tokens_per_side * 4

    for branch in (cfg.detector_config, cfg.tracker_config):
        branch.image_size = image_size
        vision = branch.vision_config
        vision.backbone_feature_sizes = pyramid
        vision.backbone_config.image_size = image_size

    cfg.tracker_config.memory_attention_rope_feat_sizes = [
        tokens_per_side,
        tokens_per_side,
    ]
    return cfg
class Sam3DetectorWrapper(nn.Module):
    """Adapter that gives the detector a tensors-in / tensors-out signature.

    The wrapped detector is called with keyword arguments and returns an
    object with named fields; this module turns that into positional tensor
    inputs and a plain tuple of tensors.
    """

    def __init__(self, detector: nn.Module) -> None:
        """Keep a reference to the detector sub-module being wrapped."""
        super().__init__()
        self.detector = detector

    def forward(
        self, image: torch.Tensor, tokenized_text: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Call the detector and unpack its result.

        Args:
            image: Passed to the detector as ``pixel_values``.
            tokenized_text: Passed to the detector as ``input_ids``.

        Returns:
            ``(pred_masks, pred_logits, pred_boxes)`` read from the detector
            output, in that order.
        """
        result = self.detector(pixel_values=image, input_ids=tokenized_text)
        masks = result.pred_masks
        logits = result.pred_logits
        boxes = result.pred_boxes
        return masks, logits, boxes
def export_sam3_fp32_pt2(path: Path) -> None:
    """Trace the SAM3 detector in fp32 with ``torch.export`` and save it.

    Loads ``MODEL_ID`` with a config patched to ``IMAGE_SIZE``, wraps the
    detector sub-model in :class:`Sam3DetectorWrapper`, exports it on
    ``DEVICE`` with random example inputs, and writes the result to ``path``.
    The file size is printed afterwards.

    Args:
        path: Destination of the exported ``.pt2`` file.
    """
    config = patch_image_size(AutoConfig.from_pretrained(MODEL_ID), IMAGE_SIZE)
    video_model = Sam3VideoModel.from_pretrained(
        MODEL_ID, config=config, torch_dtype=torch.float32
    )
    video_model = video_model.eval().to(DEVICE)
    detector = Sam3DetectorWrapper(video_model.detector_model)

    # Random example inputs: the export only needs tensors of the right
    # shape and dtype to trace with.
    example_image = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    example_ids = torch.randint(
        0, 32000, (1, CONTEXT_LENGTH), dtype=torch.long, device=DEVICE
    )

    with torch.no_grad():
        program = torch.export.export(
            detector, (example_image, example_ids), strict=False
        )
    torch.export.save(program, str(path))

    size_gb = path.stat().st_size / 1e9
    print(f" {size_gb:.2f} GB")

    # Drop every reference to the model before asking CUDA to release
    # its cached blocks.
    del video_model, detector, example_image, example_ids, program
    torch.cuda.empty_cache()
# Step 1 entry point: write the fp32 export of the detector to FP32_PT2.
export_sam3_fp32_pt2(path=FP32_PT2)
Step 2 — Fuse and INT8-quantize with Embedl Deploy#
This is the main part of the recipe that uses Embedl Deploy to fuse and quantize the model. In addition to the standard post-training static quantization we apply a few customizations to maximize performance and accuracy:
Mixed-precision. We quantize the encoder (INT8) and skip one node from quantization to preserve accuracy - the first 3-channel convolution with a ≥7×7 kernel.
SmoothQuant. We skip smooth quant for layer norms on the full model.
from embedl_deploy import transform
from embedl_deploy._internal.core.quantize.config import (
CalibrationMethod,
ModulesToSkip,
QuantConfig,
TensorQuantConfig,
)
from embedl_deploy._internal.core.quantize.main import quantize
from embedl_deploy.tensorrt import TENSORRT_PATTERNS
def load_video_frames(n: int) -> list[torch.Tensor]:
    """Return up to ``n`` evenly-spaced, preprocessed frames of the demo video.

    Each selected frame is resized to ``IMAGE_SIZE`` × ``IMAGE_SIZE``, moved
    to channel-first layout, scaled to ``[0, 1]`` and normalised with
    ``MEAN`` / ``STD``. When the video has ``n`` frames or fewer, every frame
    is returned.

    Args:
        n: Requested number of frames.

    Returns:
        One float tensor per selected frame, in video order.
    """
    frames, _ = load_video(VIDEO_URL)
    total = len(frames)

    if n < total:
        picks = np.linspace(0, total - 1, n, dtype=int).tolist()
    else:
        picks = list(range(total))

    def _prepare(frame: np.ndarray) -> torch.Tensor:
        # HWC frame -> normalised CHW float tensor.
        scaled = cv2.resize(
            frame, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR
        )
        chw = torch.from_numpy(scaled).permute(2, 0, 1).float() / 255.0
        return (chw - MEAN) / STD

    return [_prepare(frames[index]) for index in picks]
def _find_patch_embed_conv(model: nn.Module) -> nn.Conv2d | None:
# The 3-channel Conv with a ≥7×7 kernel is the patch-embed stem.
for m in model.modules():
if (
isinstance(m, nn.Conv2d)
and m.in_channels == 3
and max(m.kernel_size) >= 7
):
return m
return None
def quantize_to_qdq(gm: torch.fx.GraphModule) -> None:
    """Fuse ``gm``, quantize it to INT8 QDQ form and save it to ``QDQ_PT2``.

    Steps, in order:

    1. Apply ``TENSORRT_PATTERNS`` with ``transform`` and move the fused
       model to ``DEVICE`` in fp32.
    2. Build the quantization config: symmetric 8-bit min/max calibration,
       per-tensor for activations and per-channel for weights. The
       patch-embed stem convolution (if found) is excluded from both stub
       and weight quantization, and ``nn.LayerNorm`` is excluded from
       smoothing.
    3. Calibrate on ``N_CALIB`` frames of the demo video with the tokenized
       ``PROMPT``.
    4. Re-export the quantized model with ``torch.export`` and save it.

    Args:
        gm: The fp32 graph module to fuse and quantize.
    """
    fused_model = transform(gm, TENSORRT_PATTERNS).model
    fused_model = fused_model.eval().to(device=DEVICE, dtype=torch.float32)

    # The same set object is handed to both the stub and the weight skip list.
    unquantized: set[nn.Module] = set()
    stem = _find_patch_embed_conv(fused_model)
    if stem is not None:
        unquantized.add(stem)

    def _int8_minmax(per_channel: bool) -> TensorQuantConfig:
        # Activations and weights differ only in quantization granularity.
        return TensorQuantConfig(
            n_bits=8,
            symmetric=True,
            per_channel=per_channel,
            calibration_method=CalibrationMethod.MINMAX,
        )

    quant_cfg = QuantConfig(
        activation=_int8_minmax(per_channel=False),
        weight=_int8_minmax(per_channel=True),
        skip=ModulesToSkip(
            stub=unquantized,  # type: ignore[arg-type]
            weight=unquantized,  # type: ignore[arg-type]
            smooth={nn.LayerNorm},
        ),
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    encoded = tokenizer(
        PROMPT,
        padding="max_length",
        max_length=CONTEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    prompt_ids = encoded["input_ids"].to(device=DEVICE)

    print(f"Loading {N_CALIB} calibration frames from the demo video")
    calibration_frames = load_video_frames(N_CALIB)

    def run_calibration(model_fn: nn.Module) -> None:
        """Feed every calibration frame through ``model_fn`` once."""
        model_fn.eval()
        with torch.no_grad():
            for frame in calibration_frames:
                batch = frame.unsqueeze(0).to(
                    device=DEVICE, dtype=torch.float32
                )
                model_fn(batch, prompt_ids)

    example_image = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    quantized = quantize(
        fused_model,
        args=(example_image, prompt_ids),
        config=quant_cfg,
        forward_loop=run_calibration,
        freeze_weights=True,
    )

    # The frames are only needed for calibration; the prompt ids are still
    # used as an example input for the re-export below.
    del calibration_frames
    torch.cuda.empty_cache()

    print(f"Re-exporting QDQ → {QDQ_PT2.name}")
    quantized.eval()
    with torch.no_grad():
        program = torch.export.export(
            quantized, (example_image, prompt_ids), strict=False
        )
    torch.export.save(program, str(QDQ_PT2))
    size_gb = QDQ_PT2.stat().st_size / 1e9
    print(f" {size_gb:.2f} GB")

    # Drop the remaining model references before emptying the CUDA cache.
    del fused_model, quantized, program, example_image, prompt_ids
    torch.cuda.empty_cache()
# Step 2 entry point: reload the fp32 export as a module, then fuse and
# quantize it (the QDQ result is saved by quantize_to_qdq itself).
graph_module: torch.fx.GraphModule = torch.export.load(FP32_PT2).module()
quantize_to_qdq(graph_module)
# Cleanup GPU memory after quantization: drop the last reference to the
# fp32 module first so its memory can actually be released.
del graph_module
torch.cuda.empty_cache()
Export quantized model to ONNX#
We use the dynamo path of torch.onnx.export (the classic
torchscript path calls model.train(False) which
torch.export-loaded modules do not support).
Quantize/DequantizeLinear pairs carry the calibrated scales in
the resulting ONNX model.
def export_to_onnx(qdq_model_path: Path, onnx_model_path: Path) -> None:
    """Convert the saved QDQ ``.pt2`` program to ONNX via the dynamo exporter.

    The graph gets two named inputs (``image``, ``tokenized_text``) and three
    named outputs (``pred_masks``, ``pred_logits``, ``pred_boxes``). No opset
    version or external-data option is passed, so both are left at the
    exporter's defaults.

    Args:
        qdq_model_path: Quantized program saved with ``torch.export.save``.
        onnx_model_path: Destination ``.onnx`` file.
    """
    qdq_module: torch.fx.GraphModule = torch.export.load(
        qdq_model_path
    ).module()

    # Random example inputs with the shapes/dtypes used at export time.
    sample_inputs = (
        torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE),
        torch.randint(
            0, 32000, (1, CONTEXT_LENGTH), dtype=torch.long, device=DEVICE
        ),
    )

    print(f"Exporting to {onnx_model_path.name}")
    with torch.no_grad():
        torch.onnx.export(
            qdq_module,
            sample_inputs,
            str(onnx_model_path),
            input_names=["image", "tokenized_text"],
            output_names=["pred_masks", "pred_logits", "pred_boxes"],
            do_constant_folding=True,
            dynamo=True,
        )

    # Release the module and example tensors before emptying the CUDA cache.
    del qdq_module, sample_inputs
    torch.cuda.empty_cache()
# Step 3 entry point: convert the saved QDQ program (QDQ_PT2) to QDQ_ONNX.
export_to_onnx(qdq_model_path=QDQ_PT2, onnx_model_path=QDQ_ONNX)
Build the TensorRT engine#
INT8 + FP16 hybrid precision — TRT picks per-layer.
Shared timing cache file — kernel-timing results persist across runs of this script and subsequent variants of the same ONNX.
This cell and every cell below it requires TensorRT.
import tensorrt as trt
def build_engine(onnx_path: Path, engine_path: Path) -> None:
    """Build a TensorRT INT8+FP16 engine from a QDQ ONNX graph.

    The ONNX file is parsed, a builder config with a 4 GiB workspace limit,
    ``BUILDER_OPTIMIZATION_LEVEL`` and the FP16 + INT8 flags is created, and
    the serialized engine is written to ``engine_path``. Kernel timings are
    loaded from and saved back to ``TIMING_CACHE`` so later builds can reuse
    them.

    Args:
        onnx_path: QDQ ONNX model to compile.
        engine_path: Destination of the serialized engine.

    Raises:
        RuntimeError: If the ONNX file cannot be parsed or the build fails.
    """
    logger = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(logger, "")
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    print(f"Parsing {onnx_path.name}")
    # ``path`` is passed alongside the bytes so the parser knows where the
    # model file lives on disk.
    if not parser.parse(onnx_path.read_bytes(), path=str(onnx_path)):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parse failed")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * 1024**3)
    config.builder_optimization_level = BUILDER_OPTIMIZATION_LEVEL
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)

    cache = config.create_timing_cache(
        TIMING_CACHE.read_bytes() if TIMING_CACHE.exists() else b""
    )
    if cache is None or not config.set_timing_cache(
        cache, ignore_mismatch=False
    ):
        # A cache recorded on another device/TensorRT build is refused; fall
        # back to an empty one so this build still records fresh timings.
        print(f" {TIMING_CACHE.name} was rejected; starting an empty cache")
        cache = config.create_timing_cache(b"")
        config.set_timing_cache(cache, ignore_mismatch=False)

    print(
        f"Building INT8+FP16 engine (opt-level {BUILDER_OPTIMIZATION_LEVEL}). "
        "It can take 15-30 min for the first build without a timing cache."
    )
    t0 = time.perf_counter()
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError("build_serialized_network returned None")
    dt = time.perf_counter() - t0

    # Persist the engine first: it is the expensive artifact, so a problem
    # while saving the timing cache must not be able to lose it.
    plan_bytes = bytes(plan)
    engine_path.write_bytes(plan_bytes)

    timing_cache = config.get_timing_cache()
    if timing_cache is not None:
        TIMING_CACHE.write_bytes(bytes(timing_cache.serialize()))

    print(
        f" built in {dt:.0f} s → {engine_path.name} "
        f"({len(plan_bytes) / 1e6:.0f} MB)"
    )
# Step 4 entry point: compile QDQ_ONNX into the serialized engine at ENGINE.
build_engine(onnx_path=QDQ_ONNX, engine_path=ENGINE)
# Cleanup GPU memory after engine build (TensorRT may hold resources)
torch.cuda.empty_cache()
Run the engine on a video#
We wrap the TRT engine in a simple class that manages the I/O buffers as
CUDA tensors and exposes a clean infer() method. The demo video is the
same one used for calibration; the output is written to a new video file with
the predicted masks.
# TensorRT tensor dtype -> torch dtype, used to allocate I/O buffers with the
# element type the engine reports. An engine tensor whose dtype is not listed
# here fails with a KeyError at lookup time.
_TRT_TORCH = {
trt.float32: torch.float32,
trt.float16: torch.float16,
trt.int32: torch.int32,
trt.int64: torch.int64,
trt.int8: torch.int8,
trt.bool: torch.bool,
}
class TrtRunner:
    """TRT engine wrapper backed by torch.cuda tensors as I/O buffers.

    One CUDA tensor is allocated per engine I/O tensor and its address is
    bound to the execution context once, at construction. Buffers therefore
    persist between calls: an input that is not passed to :meth:`infer`
    keeps whatever was last written to it (zeros right after construction).
    """

    def __init__(self, engine_path: Path) -> None:
        """Deserialize the engine at ``engine_path`` and bind its I/O buffers.

        Raises:
            RuntimeError: If TensorRT cannot deserialize the engine file.
        """
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, "")
        # Hold the runtime on the instance so it stays referenced for as long
        # as the engine does.
        self.runtime = trt.Runtime(logger)
        self.engine = self.runtime.deserialize_cuda_engine(
            engine_path.read_bytes()
        )
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine: {engine_path}"
            )
        self.ctx = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()
        self.outputs: list[str] = []
        self.bufs: dict[str, torch.Tensor] = {}
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = tuple(self.engine.get_tensor_shape(name))
            dtype = _TRT_TORCH[self.engine.get_tensor_dtype(name)]
            # Zero-filled rather than uninitialised: an input that is never
            # fed is then well-defined instead of arbitrary device memory.
            self.bufs[name] = torch.zeros(shape, dtype=dtype, device="cuda")
            self.ctx.set_tensor_address(name, self.bufs[name].data_ptr())
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                self.outputs.append(name)

    @property
    def input_size(self) -> int:
        """Return dimension 2 of the engine's ``image`` input shape."""
        return int(self.engine.get_tensor_shape("image")[2])

    def infer(
        self, feeds: dict[str, np.ndarray | torch.Tensor]
    ) -> dict[str, np.ndarray]:
        """Copy inputs into device buffers, execute, return outputs on host.

        Args:
            feeds: Maps an engine tensor name to the data to copy into its
                buffer. Names that are not engine tensors raise ``KeyError``.

        Returns:
            One NumPy array per engine output, keyed by tensor name.

        Raises:
            RuntimeError: If TensorRT fails to enqueue the inference.
        """
        # Order the copies and the execution on the runner's own stream, and
        # make that stream wait for work already queued on the caller's
        # stream, so device-resident inputs cannot race with the engine.
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            for name, x in feeds.items():
                self.bufs[name].copy_(torch.as_tensor(x))
            enqueued = self.ctx.execute_async_v3(self.stream.cuda_stream)
        self.stream.synchronize()
        if not enqueued:
            raise RuntimeError("TensorRT execute_async_v3 failed")
        return {n: self.bufs[n].cpu().numpy() for n in self.outputs}
def _sigmoid_np(x: np.ndarray) -> np.ndarray:
return 1 / (1 + np.exp(-np.clip(x, -50, 50)))
def _postprocess(
    pred_masks: np.ndarray, pred_logits: np.ndarray, h: int, w: int
) -> np.ndarray:
    """Turn raw engine outputs into full-resolution boolean masks.

    Only batch element 0 is used. A detection is kept when the sigmoid of
    its logit exceeds ``CONFIDENCE``; each kept mask is resized to
    ``(h, w)`` and thresholded at a sigmoid value of 0.5.

    Args:
        pred_masks: Mask logits, indexed as ``[batch][detection]``.
        pred_logits: Detection logits; when 3-D, channel 0 of the last axis
            is used.
        h: Output mask height.
        w: Output mask width.

    Returns:
        Boolean array of shape ``(kept, h, w)``; ``kept`` may be 0.
    """
    if pred_logits.ndim == 3:
        scores = pred_logits[0, :, 0]
    else:
        scores = pred_logits[0]

    selected = pred_masks[0][_sigmoid_np(scores) > CONFIDENCE]
    if len(selected) == 0:
        return np.zeros((0, h, w), dtype=bool)

    binary_masks = []
    for low_res in selected:
        # cv2 takes the target size as (width, height).
        upsampled = cv2.resize(low_res, (w, h), interpolation=cv2.INTER_LINEAR)
        binary_masks.append(_sigmoid_np(upsampled) > 0.5)
    return np.stack(binary_masks)
def _overlay(
    frame_rgb: np.ndarray, masks: np.ndarray, hud: str, colors: np.ndarray
) -> np.ndarray:
    """Draw tinted masks and a HUD text box on a copy of ``frame_rgb``.

    Each mask is blended 50/50 with a colour taken from ``colors`` (cycled
    by mask index). The HUD string is drawn in white on a filled black
    rectangle in the top-left corner. The input frame is not modified.

    Returns:
        The annotated ``uint8`` image.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    stroke = 2

    canvas = frame_rgb.copy().astype(np.uint8)
    for index, mask in enumerate(masks):
        region = mask.astype(bool)
        tint = colors[index % len(colors)]
        blended = canvas[region] * 0.5 + tint * 0.5
        canvas[region] = blended.astype(np.uint8)

    # Size the background box from the rendered text extent.
    (text_w, text_h), _ = cv2.getTextSize(hud, font, font_scale, stroke)
    cv2.rectangle(canvas, (6, 6), (14 + text_w, 18 + text_h), (0, 0, 0), -1)
    cv2.putText(
        canvas,
        hud,
        (10, 10 + text_h),
        font,
        font_scale,
        (255, 255, 255),
        stroke,
        cv2.LINE_AA,
    )
    return canvas
def _process_frame(  # pylint: disable=too-many-positional-arguments
    runner: TrtRunner,
    frame: np.ndarray,
    height: int,
    width: int,
    colors: np.ndarray,
    writer: cv2.VideoWriter,
) -> tuple[float, np.ndarray]:
    """Run the engine on one frame, annotate it and append it to ``writer``.

    The frame is resized to the engine input size, normalised with
    ``MEAN_NP`` / ``STD_NP`` and fed as a contiguous ``(1, C, H, W)`` array.
    Only the ``runner.infer`` call is timed.

    Returns:
        ``(elapsed_ms, canvas)``: the measured inference time in
        milliseconds and the annotated frame.
    """
    side = runner.input_size
    scaled = cv2.resize(frame, (side, side), interpolation=cv2.INTER_LINEAR)
    normalised = (scaled.astype(np.float32) / 255.0 - MEAN_NP) / STD_NP
    # HWC -> CHW, then add the batch axis.
    batch = np.ascontiguousarray(normalised.transpose(2, 0, 1)[np.newaxis])

    started = time.perf_counter()
    outputs = runner.infer({"image": batch})
    elapsed_ms = (time.perf_counter() - started) * 1000

    masks = _postprocess(
        outputs["pred_masks"], outputs["pred_logits"], height, width
    )
    hud = f"TensorRT INT8 | {elapsed_ms:.1f} ms | {1000 / elapsed_ms:.1f} FPS"
    canvas = _overlay(frame, masks, hud, colors)
    # The canvas is converted from RGB to BGR channel order before writing.
    writer.write(cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
    return elapsed_ms, canvas
def run_engine(  # pylint: disable=too-many-locals
    engine: Path, output_path: Path
) -> tuple[Path, list[np.ndarray]]:
    """Run the engine on the demo video and write an annotated copy.

    Every frame of ``VIDEO_URL`` is segmented with the text prompt
    ``PROMPT``, overlaid with the predicted masks and a timing HUD, and
    written to ``output_path``. Latency statistics are printed at the end.

    Args:
        engine: Path to the serialized TensorRT engine.
        output_path: Destination video file.

    Returns:
        ``(output_path, sample_overlays)`` where ``sample_overlays`` holds
        the annotated frames at indices 0, 1/3 and 2/3 of the video.

    Raises:
        ValueError: If no frames could be decoded from the video.
    """
    out_path = output_path
    frames, _ = load_video(VIDEO_URL)
    if len(frames) == 0:
        raise ValueError(f"No frames decoded from {VIDEO_URL}")
    h, w = frames[0].shape[:2]
    runner = TrtRunner(engine)

    # The per-frame loop only feeds "image", and the runner's buffers persist
    # between calls, so the text prompt must be written into the engine's
    # token input once, up front. Without this the engine would run on
    # whatever happens to be in that buffer.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    prompt_ids = tokenizer(
        PROMPT,
        padding="max_length",
        max_length=CONTEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )["input_ids"]
    warmup_feeds: dict[str, np.ndarray | torch.Tensor] = {
        "image": np.zeros(tuple(runner.bufs["image"].shape), dtype=np.float32)
    }
    if "tokenized_text" in runner.bufs:
        warmup_feeds["tokenized_text"] = prompt_ids
    for _ in range(3):  # warm up; also loads the prompt into its buffer
        runner.infer(warmup_feeds)

    colors = np.random.default_rng(42).integers(64, 256, size=(32, 3))
    writer = cv2.VideoWriter(
        str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), 30, (w, h)
    )
    fwd_ms: list[float] = []
    sample_overlays: list[np.ndarray] = []
    sample_idxs = {0, len(frames) // 3, 2 * len(frames) // 3}
    print(f"Inference on {len(frames)} frames")
    try:
        for i, frame in enumerate(frames):
            fwd, canvas = _process_frame(runner, frame, h, w, colors, writer)
            fwd_ms.append(fwd)
            if i in sample_idxs:
                sample_overlays.append(canvas)
    finally:
        # Always finalise the output file, even if a frame fails.
        writer.release()

    a = np.array(fwd_ms)
    print(
        f" {len(frames)} frames | mean {a.mean():.1f} ms | "
        f"p95 {np.percentile(a, 95):.1f} ms | {1000 / a.mean():.1f} FPS"
    )
    print(f" Saved: {out_path}")
    # Cleanup GPU memory and TensorRT resources
    del runner, frames, colors, fwd_ms, a
    torch.cuda.empty_cache()
    return out_path, sample_overlays
# Final step: run the engine over the demo video; keeps the output path and
# the sampled annotated frames.
video_path, samples = run_engine(ENGINE, ARTIFACTS_PATH / "output_trt.mp4")
Summary#
Pipeline: HuggingFace (fp32 .pt2) ─▶ Embedl Deploy (QDQ .pt2) ─▶
ONNX ─▶ TensorRT engine
On an NVIDIA L4 with TensorRT 10.16:
Throughput: ~12 QPS / ~85 ms latency at 924×924.