Note
Go to the end to download the full example code.
Deploying vision models with Embedl Deploy#
In this tutorial, we demonstrate an end-to-end workflow for deploying vision
models using embedl_deploy for TensorRT. The models selected are
from the torchvision library, covering a range of architectures: ConvNeXt
Tiny, ResNet-50, and ViT-B/16. We apply post-training static quantization
(PTQ) using the built-in TensorRT pattern set. Embedl Deploy
automatically handles special cases such as depthwise convolutions,
which are memory-bound and left in FP16 to avoid quantization overhead.
After transformation and quantization, the deployment pipeline consists of the following steps:
1. Export the PyTorch model to ONNX and simplify with onnxsim.
2. Build a TensorRT engine (FP16 for the baseline, --best for QDQ models).
3. Run TensorRT inference.
4. Measure latency with the TensorRT Python API.
Note
This tutorial requires an NVIDIA GPU with TensorRT 10.x installed.
Constants#
import sys
from pathlib import Path
import tensorrt as trt
# ConvNeXt-Base can exceed default recursion limit during deepcopy.
sys.setrecursionlimit(5000)

# Standard ImageNet channel statistics used to normalize input images.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Imagenette: a 10-class ImageNet subset hosted by fast.ai (320px variant).
IMAGENETTE_URL = (
    "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
)
# Local locations for the downloaded archive and extracted dataset.
DATA_DIR = Path("artifacts/data")
IMAGENETTE_DIR = DATA_DIR / "imagenette2-320"

# Maps Imagenette's 10 folder indices to the original ImageNet class ids,
# so a 1000-class classifier's logits can be scored directly.
IMAGENETTE_TO_IMAGENET = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]

BATCH_SIZE = 1
NUM_WORKERS = 8
# Number of training batches collected for PTQ calibration.
CALIBRATION_BATCHES = 32

# Shared TensorRT logger reused by builder/runtime objects below.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
Pattern selection#
TENSORRT_PATTERNS is the default pattern set shipped with Embedl
Deploy. It includes structural conversions (e.g. decomposing
MultiheadAttention), operator fusions (Conv-BN-ReLU, Linear-ReLU,
LayerNorm, etc.), and automatic handling of depthwise convolutions.
Depthwise convolutions are memory-bound, so quantizing them adds
TensorRT reformatting overhead that exceeds the compute gain from
INT8 — Embedl Deploy detects these and keeps them in FP16
automatically.
Dataset helpers#
We use ImageNette — a 10-class subset of ImageNet — for fast evaluation.
import tarfile
import urllib.request
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from embedl_deploy.tensorrt import TENSORRT_PATTERNS
def download_imagenette() -> None:
    """Fetch the Imagenette archive and unpack it (no-op if already there)."""
    if IMAGENETTE_DIR.exists():
        print(f"ImageNette already present at {IMAGENETTE_DIR}")
        return
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    archive = DATA_DIR / "imagenette2-320.tgz"
    print(f"Downloading ImageNette to {archive} ...")
    urllib.request.urlretrieve(IMAGENETTE_URL, str(archive))  # noqa: S310
    print("Extracting ...")
    with tarfile.open(archive) as tar:
        tar.extractall(DATA_DIR)  # noqa: S202
    # Remove the tarball once extracted to save disk space.
    archive.unlink()
    print("Done.")
def _val_transform(crop_size: int = 224, resize_size: int = 256) -> T.Compose:
    """Eval preprocessing: resize, center-crop, tensorize, normalize."""
    steps = [
        T.Resize(resize_size),
        T.CenterCrop(crop_size),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return T.Compose(steps)
def _train_transform(crop_size: int = 224) -> T.Compose:
    """Training preprocessing: random crop + flip, tensorize, normalize."""
    steps = [
        T.RandomResizedCrop(crop_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return T.Compose(steps)
def make_loaders(
    crop_size: int = 224, resize_size: int = 256
) -> tuple[DataLoader, DataLoader]:
    """Create train and validation data loaders for ImageNette."""

    def remap(label):
        # Translate an Imagenette folder index into its ImageNet class id.
        return IMAGENETTE_TO_IMAGENET[label]

    train_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "train"),
        transform=_train_transform(crop_size),
        target_transform=remap,
    )
    val_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "val"),
        transform=_val_transform(crop_size, resize_size),
        target_transform=remap,
    )
    # drop_last keeps every training batch at exactly BATCH_SIZE.
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    return train_loader, val_loader
ONNX export#
Export the model to ONNX and simplify with onnxsim.
import onnx
import onnxsim
import torch
def export_and_simplify(
    model: nn.Module, onnx_path: Path, crop_size: int = 224
) -> Path:
    """Export *model* to ONNX, simplify with onnxsim, return the usable path."""
    model = model.cpu().eval()
    dummy = torch.randn(1, 3, crop_size, crop_size)
    torch.onnx.export(
        model,
        (dummy,),
        str(onnx_path),
        opset_version=20,
        input_names=["input"],
        output_names=["output"],
        dynamo=False,
    )
    simplified, ok = onnxsim.simplify(onnx.load(str(onnx_path)))
    if not ok:
        # Fall back to the raw export rather than failing the pipeline.
        print(" onnxsim failed; using raw export.")
        return onnx_path
    simp_path = onnx_path.with_name(onnx_path.stem + "_sim.onnx")
    onnx.save(simplified, str(simp_path))
    print(f" Simplified ONNX: {simp_path}")
    return simp_path
TensorRT engine build and inference#
Build a TensorRT engine from an ONNX file and run inference.
import statistics
import time
def _parse_onnx(onnx_path: Path) -> tuple[trt.Builder, trt.INetworkDefinition]:
    """Parse an ONNX model into a TensorRT network."""
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    model_bytes = onnx_path.read_bytes()
    if not parser.parse(model_bytes):
        # Surface every parser error before aborting.
        for i in range(parser.num_errors):
            print(f" ONNX parse error: {parser.get_error(i)}")
        raise RuntimeError("Failed to parse ONNX model")
    return builder, network
def _load_timing_cache(config: trt.IBuilderConfig, cache_path: Path) -> None:
    """Seed *config* with a previously saved timing cache, if one exists."""
    serialized = b""
    if cache_path.exists():
        serialized = cache_path.read_bytes()
    cache = config.create_timing_cache(serialized)
    config.set_timing_cache(cache, ignore_mismatch=True)
def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    fp16: bool = True,
    int8: bool = False,
    best: bool = False,
) -> Path:
    """Build a TensorRT engine from an ONNX file.

    Precision-flag precedence: ``best`` (FP16 + INT8) wins over ``int8``,
    which wins over ``fp16``.
    """
    builder, network = _parse_onnx(onnx_path)
    config = builder.create_builder_config()
    config.builder_optimization_level = 5

    # Reuse kernel-timing results across builds to speed up iteration.
    cache_path = engine_path.parent / "timing.cache"
    _load_timing_cache(config, cache_path)

    flags = (
        [trt.BuilderFlag.FP16, trt.BuilderFlag.INT8]
        if best
        else [trt.BuilderFlag.INT8]
        if int8
        else [trt.BuilderFlag.FP16]
        if fp16
        else []
    )
    for flag in flags:
        config.set_flag(flag)

    print(" Building TensorRT engine ...")
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")

    # Persist the (possibly updated) timing cache for the next build.
    cache_path.write_bytes(
        bytes(memoryview(config.get_timing_cache().serialize()))
    )
    engine_blob = bytes(serialized)
    engine_path.write_bytes(engine_blob)
    print(f" Engine: {engine_path} ({len(engine_blob) / 1e6:.1f} MB)")
    return engine_path
def _prepare_trt_context(
    engine_path: Path,
) -> tuple[trt.IExecutionContext, torch.cuda.Stream]:
    """Load engine, allocate GPU buffers, return ready context.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :returns:
        A ``(context, stream)`` pair ready for ``execute_async_v3``.

    TensorRT only stores the raw device pointers passed to
    ``set_tensor_address`` — it does not hold references to the torch
    tensors backing them. In the original code the buffers were locals
    and would be garbage-collected after return, leaving the engine
    executing against freed/reused device memory. References are now
    pinned on the returned stream object so the buffers (and engine)
    stay alive for as long as the caller holds the pair.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
    context = engine.create_execution_context()
    # NOTE(review): assumes tensor 0 is the single input and tensor 1 the
    # single output — holds for the models exported in this tutorial, but
    # confirm if reused with multi-input/output graphs.
    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    input_shape = tuple(engine.get_tensor_shape(input_name))
    output_shape = tuple(engine.get_tensor_shape(output_name))
    inp = torch.empty(input_shape, dtype=torch.float32, device="cuda")
    out = torch.empty(output_shape, dtype=torch.float32, device="cuda")
    context.set_tensor_address(input_name, inp.data_ptr())
    context.set_tensor_address(output_name, out.data_ptr())
    stream = torch.cuda.Stream()
    # Pin the buffers, engine, and runtime; without this they can be freed
    # while TensorRT still uses their addresses.
    stream._trt_keepalive = (runtime, engine, inp, out)
    return context, stream
def measure_latency(
    engine_path: Path,
    warmup_ms: int = 1000,
    duration: int = 30,
) -> dict:
    """Measure inference latency using the TensorRT Python API.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param warmup_ms:
        Warm-up time in milliseconds before recording.
    :param duration:
        Benchmarking duration in seconds.
    :returns:
        Dict with mean/median/p99 latency (ms) and throughput (qps).
    """
    context, stream = _prepare_trt_context(engine_path)

    # Warm up: launch back-to-back until the budget elapses, sync once.
    print(" Warming up ...")
    stop_at = time.perf_counter() + warmup_ms / 1000
    while time.perf_counter() < stop_at:
        context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()

    # Timed iterations: CUDA events bracket each launch so we record
    # device-side latency rather than host wall-clock time.
    print(f" Benchmarking ({duration}s) ...")
    samples: list[float] = []
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    stop_at = time.perf_counter() + duration
    while time.perf_counter() < stop_at:
        tic.record(stream)
        context.execute_async_v3(stream.cuda_stream)
        toc.record(stream)
        toc.synchronize()
        samples.append(tic.elapsed_time(toc))

    mean_ms = statistics.mean(samples)
    median_ms = statistics.median(samples)
    p99_ms = statistics.quantiles(samples, n=100)[-1]
    throughput = len(samples) / (sum(samples) / 1000)
    print(
        f" Latency: mean={mean_ms:.3f} ms, "
        f"median={median_ms:.3f} ms, "
        f"p99={p99_ms:.3f} ms, throughput={throughput:.1f} qps"
    )
    return {
        "mean_latency_ms": mean_ms,
        "median_latency_ms": median_ms,
        "p99_latency_ms": p99_ms,
        "throughput_qps": throughput,
    }
class TRTInferencer:
    """TensorRT engine wrapper for batch inference."""

    def __init__(self, engine_path: Path):
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        self.engine = runtime.deserialize_cuda_engine(
            engine_path.read_bytes()
        )
        self.context = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()

    def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference on a single batch.

        :param input_tensor:
            NCHW image batch on CUDA.
        :returns:
            Model output tensor.
        """
        # NOTE(review): assumes tensor 0 is the input and tensor 1 the
        # output — true for the single-input/single-output exports here.
        in_name = self.engine.get_tensor_name(0)
        out_name = self.engine.get_tensor_name(1)
        self.context.set_input_shape(in_name, tuple(input_tensor.shape))
        result = torch.empty(
            tuple(self.context.get_tensor_shape(out_name)),
            dtype=torch.float32,
            device="cuda",
        )
        self.context.set_tensor_address(in_name, input_tensor.data_ptr())
        self.context.set_tensor_address(out_name, result.data_ptr())
        self.context.execute_async_v3(self.stream.cuda_stream)
        self.stream.synchronize()
        return result
def evaluate_trt(
    engine_path: Path, val_loader: DataLoader
) -> dict[str, float]:
    """Top-1/Top-5 accuracy on the validation set."""
    inferencer = TRTInferencer(engine_path)
    hits1 = hits5 = seen = 0
    for images, labels in val_loader:
        images, labels = images.cuda(), labels.cuda()
        logits = inferencer.infer(images)
        # (batch, 5) class indices of the five largest logits, best first.
        _, top5 = logits.topk(5, dim=1, largest=True, sorted=True)
        matches = top5.eq(labels.view(-1, 1))
        hits1 += matches[:, :1].float().sum().item()
        hits5 += matches.float().sum().item()
        seen += images.size(0)
    return {
        "top1": hits1 / seen * 100,
        "top5": hits5 / seen * 100,
    }
Quantization with Embedl Deploy#
Fuse layers, insert QDQ stubs with the selective pattern list, and calibrate on training data.
from torch.fx.passes.shape_prop import ShapeProp
from embedl_deploy import transform
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy.quantize import (
QuantConfig,
QuantStub,
TensorQuantConfig,
WeightFakeQuantize,
quantize,
)
def quantize_embedl(
    pretrained_model: nn.Module,
    calib_batches: list[torch.Tensor],
    crop_size: int = 224,
) -> nn.Module:
    """Fuse + quantize + calibrate using Embedl Deploy.

    :param pretrained_model:
        Float model to quantize; moved to CPU and set to eval mode.
    :param calib_batches:
        CPU image batches driven through the model during calibration.
    :param crop_size:
        Square input resolution used for tracing and shape propagation.
    :returns:
        The quantized model with QDQ stubs inserted.
    """
    print("\n=== Embedl Deploy PTQ ===")
    # Trace to an FX graph and record per-node shapes for the transforms.
    gm = symbolic_trace(pretrained_model.cpu().eval())
    ShapeProp(gm).propagate(torch.randn(1, 3, crop_size, crop_size))
    fused_model = transform(gm, patterns=TENSORRT_PATTERNS).model
    # Verify lossless fusion.
    with torch.no_grad():
        x = torch.randn(1, 3, crop_size, crop_size)
        max_diff = (pretrained_model.cpu()(x) - fused_model(x)).abs().max()
    assert max_diff < 1e-4, f"Fusion diverged: {max_diff:.2e}"
    print(f" Fusion check passed (max diff = {max_diff:.2e})")
    print(f" Calibrating ({len(calib_batches)} batches) ...")

    def forward_loop(model):
        # Drive the model over the calibration batches so the inserted
        # observers can record activation statistics.
        with torch.no_grad():
            for batch in calib_batches:
                model(batch)

    quantized = quantize(
        fused_model,
        (torch.randn(1, 3, crop_size, crop_size),),
        config=QuantConfig(
            activation=TensorQuantConfig(n_bits=8, symmetric=True),
            weight=TensorQuantConfig(
                n_bits=8,
                symmetric=True,
                per_channel=True,
            ),
        ),
        forward_loop=forward_loop,
    )
    # Count the inserted fake-quant modules as a sanity check.
    n_act = sum(1 for m in quantized.modules() if isinstance(m, QuantStub))
    n_wt = sum(
        1 for m in quantized.modules() if isinstance(m, WeightFakeQuantize)
    )
    print(f" QuantStubs: {n_act}, WeightFakeQuantize: {n_wt}")
    print(" Calibration complete.")
    return quantized
Benchmark runner#
Helper that runs the full export -> build -> evaluate -> latency pipeline for a single model variant.
def run_variant(
    tag: str,
    model: nn.Module,
    *,
    dest: Path,
    val_loader: DataLoader | None,
    crop_size: int,
    best: bool = False,
    fp16: bool = True,
) -> tuple[dict, dict]:
    """Export, build, evaluate, and measure one model variant."""
    banner = "=" * 60
    print(f"\n{banner}\n{tag}\n{banner}")
    onnx_path = export_and_simplify(model, dest / f"{tag}.onnx", crop_size)
    engine_path = build_trt_engine(
        onnx_path, dest / f"{tag}.engine", fp16=fp16, best=best
    )
    if val_loader is None:
        # No loader given: skip accuracy and report zeros.
        acc = {"top1": 0.0, "top5": 0.0}
    else:
        acc = evaluate_trt(engine_path, val_loader)
    latency = measure_latency(engine_path)
    if val_loader is not None:
        print(f" Top-1: {acc['top1']:.2f}% Top-5: {acc['top5']:.2f}%")
    print(
        f" Latency: {latency['mean_latency_ms']:.3f} ms "
        f"Throughput: {latency['throughput_qps']:.1f} qps"
    )
    return acc, latency
Configuration#
Choose the model and image sizes. Change MODEL_NAME to benchmark
a different architecture (e.g. "convnext_base", "resnet50",
"vit_b_16").
# Model under test; any torchvision classification model name works here.
MODEL_NAME = "convnext_tiny"
CROP_SIZE = 224
RESIZE_SIZE = 256

# All ONNX files, engines, and result files for this run land here.
BENCHMARK_DIR = Path(f"artifacts/ptq/{MODEL_NAME}_benchmark")
BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TensorRT engines only run on NVIDIA GPUs, so bail out early otherwise.
assert device.type == "cuda", "This benchmark requires a CUDA GPU."
print(f"Model: {MODEL_NAME} crop={CROP_SIZE} resize={RESIZE_SIZE}")
print(f"Output: {BENCHMARK_DIR}")
Prepare data and model#
Download ImageNette and load the pretrained torchvision model.
# Fetch the dataset and build the loaders.
download_imagenette()
train_dl, val_dl = make_loaders(CROP_SIZE, RESIZE_SIZE)

# Collect the first CALIBRATION_BATCHES training batches for calibration.
calib_data: list[torch.Tensor] = []
for imgs, _ in train_dl:
    if len(calib_data) >= CALIBRATION_BATCHES:
        break
    calib_data.append(imgs)
print(f"Collected {len(calib_data)} calibration batches.")

# Load the pretrained torchvision model to quantize.
pretrained = torchvision.models.get_model(MODEL_NAME, weights="DEFAULT").eval()
print(
    f"Loaded {MODEL_NAME} "
    f"({sum(p.numel() for p in pretrained.parameters()):,} params)"
)
Baseline FP16#
Build and benchmark the unquantized model with FP16 precision.
# Unquantized reference: FP16 engine with full accuracy + latency sweep.
baseline_acc, baseline_lat = run_variant(
    "baseline_fp16",
    pretrained,
    dest=BENCHMARK_DIR,
    val_loader=val_dl,
    crop_size=CROP_SIZE,
)
Embedl Deploy mixed-precision#
Apply TENSORRT_PATTERNS. Depthwise convolutions are
automatically kept in FP16 while compute-bound operators are
quantized to INT8.
# Quantize with Embedl Deploy, then build with best=True (FP16 + INT8)
# so TensorRT can honor the inserted QDQ nodes.
embedl_model = quantize_embedl(pretrained, calib_data, CROP_SIZE)
embedl_acc, embedl_lat = run_variant(
    "embedl_mixed_precision",
    embedl_model,
    dest=BENCHMARK_DIR,
    val_loader=val_dl,
    crop_size=CROP_SIZE,
    best=True,
)
Summary#
Compare accuracy and latency across all variants.
def _format_row(label: str, acc: dict, lat: dict) -> str:
    """Render one summary-table row (shared by console and file output)."""
    return (
        f"{label:<25s} {acc['top1']:6.2f}% {acc['top5']:6.2f}% "
        f"{lat['mean_latency_ms']:11.3f} "
        f"{lat['throughput_qps']:10.1f} qps"
    )


HEADER = (
    f"{'Variant':<25s} {'Top-1':>7s} {'Top-5':>7s} "
    f"{'Latency(ms)':>12s} {'Throughput':>12s}"
)
rows = [
    ("Baseline (FP16)", baseline_acc, baseline_lat),
    ("Embedl Deploy (best)", embedl_acc, embedl_lat),
]

# Console summary.
print(f"\n{'=' * 80}")
print(f"BENCHMARK SUMMARY - {MODEL_NAME} PTQ on ImageNette")
print("=" * 80)
print(HEADER)
print("-" * len(HEADER))
for label, row_acc, row_lat in rows:
    print(_format_row(label, row_acc, row_lat))

# Speedup and accuracy drop versus the FP16 baseline; the max() guard
# avoids division by zero on degenerate timings.
speedup_e = baseline_lat["mean_latency_ms"] / max(
    embedl_lat["mean_latency_ms"], 1e-6
)
drop_e = baseline_acc["top1"] - embedl_acc["top1"]
print()
print(
    f" Embedl Deploy - Top-1 drop: {drop_e:+.2f}pp, speedup: {speedup_e:.2f}x"
)
print("=" * 80)

# Save results.
RESULTS_PATH = BENCHMARK_DIR / f"{MODEL_NAME}_results.txt"
lines = [
    f"TRT {trt.__version__}",
    "=" * 80,
    f"BENCHMARK SUMMARY - {MODEL_NAME} PTQ on ImageNette",
    "=" * 80,
    HEADER,
    "-" * len(HEADER),
]
# Reuse the exact same formatting as the console table.
lines.extend(
    _format_row(label, row_acc, row_lat) for label, row_acc, row_lat in rows
)
lines += [
    "",
    f" Embedl Deploy - Top-1 drop: {drop_e:+.2f}pp, "
    f"speedup: {speedup_e:.2f}x",
    "=" * 80,
]
RESULTS_PATH.write_text("\n".join(lines) + "\n")
print(f"\nResults saved to {RESULTS_PATH}")