TensorRT inference and evaluation helpers#

Shared utilities for building TensorRT engines, running inference, and measuring accuracy and latency. The other tutorials import from this module so they can focus on the Embedl Deploy workflow rather than TensorRT boilerplate.

Note

This tutorial requires an NVIDIA GPU with TensorRT 10.x installed.

Constants#

import statistics
import sys
import time
from pathlib import Path

import numpy as np  # type: ignore[import-not-found]
import tensorrt as trt
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# ConvNeXt-Base can exceed default recursion limit during deepcopy.
sys.setrecursionlimit(5000)

# Per-channel mean / standard deviation handed to ``T.Normalize`` by the
# train and validation transforms defined in this module.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# URL of the ImageNette archive.  Nothing in this part of the module
# downloads it -- presumably the dataset is fetched and extracted elsewhere;
# confirm before relying on IMAGENETTE_DIR being populated.
IMAGENETTE_URL = (
    "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
)
# Root folder for data, and the dataset folder whose "train" and "val"
# sub-directories are read by ``make_loaders``.
DATA_DIR = Path("artifacts/data")
IMAGENETTE_DIR = DATA_DIR / "imagenette2-320"

# Lookup table applied to every dataset target in ``make_loaders``: the
# ImageFolder class index selects the replacement label.  Judging by the
# name these are the matching ImageNet class indices -- TODO confirm the
# order agrees with ImageFolder's class ordering.
IMAGENETTE_TO_IMAGENET = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]

# DataLoader settings used for both the train and the validation loader.
BATCH_SIZE = 1
NUM_WORKERS = 8
# Not referenced in the visible part of this module; presumably consumed by
# the tutorials that import it (number of batches used for calibration).
CALIBRATION_BATCHES = 32

# Shared TensorRT logger (WARNING severity) passed to the builder, the ONNX
# parser and the runtime created by the helpers in this module.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Type alias for data loaders yielding ``(images, targets)`` tensor pairs.
ImageLoader = DataLoader[tuple[torch.Tensor, torch.Tensor]]

Dataset helpers#

We use ImageNette — a 10-class subset of ImageNet — for fast evaluation and calibration.

def _val_transform(crop_size: int = 224, resize_size: int = 256) -> T.Compose:
    """Build the deterministic preprocessing pipeline used for validation.

    :param crop_size:
        Size passed to ``T.CenterCrop``.
    :param resize_size:
        Size passed to ``T.Resize`` before cropping.
    :returns:
        Composed transform: resize, center crop, tensor conversion, then
        normalization with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    steps = [
        T.Resize(resize_size),
        T.CenterCrop(crop_size),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def _train_transform(crop_size: int = 224) -> T.Compose:
    """Build the randomized preprocessing pipeline used for training data.

    :param crop_size:
        Output size passed to ``T.RandomResizedCrop``.
    :returns:
        Composed transform: random resized crop, random horizontal flip,
        tensor conversion, then normalization with ``IMAGENET_MEAN`` /
        ``IMAGENET_STD``.
    """
    steps = [
        T.RandomResizedCrop(crop_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def _remap_target(class_index: int) -> int:
    """Translate a dataset class index via ``IMAGENETTE_TO_IMAGENET``.

    Defined at module level (rather than nested inside ``make_loaders``) so
    the datasets that reference it stay picklable.  DataLoader worker
    processes started with the ``spawn`` or ``forkserver`` methods receive
    a pickled copy of the dataset, and locally defined functions cannot be
    pickled.
    """
    return IMAGENETTE_TO_IMAGENET[class_index]


def make_loaders(
    crop_size: int = 224, resize_size: int = 256
) -> tuple[ImageLoader, ImageLoader]:
    """Create train and validation data loaders for ImageNette.

    Both datasets are read with ``ImageFolder`` from ``IMAGENETTE_DIR``
    (``train`` and ``val`` sub-folders) and have their targets remapped
    through ``IMAGENETTE_TO_IMAGENET``.

    :param crop_size:
        Crop size used by both the train and the validation transform.
    :param resize_size:
        Resize size applied before the center crop of the validation
        transform.
    :returns:
        ``(train_loader, val_loader)``.  The train loader shuffles and drops
        the last incomplete batch; the validation loader keeps dataset
        order.
    """
    train_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "train"),
        transform=_train_transform(crop_size),
        target_transform=_remap_target,
    )
    val_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "val"),
        transform=_val_transform(crop_size, resize_size),
        target_transform=_remap_target,
    )
    train = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )
    val = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    return train, val

TensorRT engine build#

Parse an ONNX model and compile it into a TensorRT engine with optional FP16 / INT8 precision flags.

def _parse_onnx(onnx_path: Path) -> tuple[trt.Builder, trt.INetworkDefinition]:
    """Read an ONNX file and populate a fresh TensorRT network from it.

    :param onnx_path:
        Path to the ONNX model file.
    :returns:
        The builder and the network definition it created.
    :raises RuntimeError:
        If the ONNX parser rejects the model; each parser error is printed
        first.
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    model_bytes = onnx_path.read_bytes()
    if parser.parse(model_bytes):
        return builder, network

    # Parsing failed: report every recorded error before giving up.
    error_count = parser.num_errors
    for position in range(error_count):
        print(f"  ONNX parse error: {parser.get_error(position)}")
    raise RuntimeError("Failed to parse ONNX model")


def _load_timing_cache(config: trt.IBuilderConfig, cache_path: Path) -> None:
    """Attach a timing cache to ``config``, seeded from disk when available.

    :param config:
        Builder config that receives the timing cache.
    :param cache_path:
        File holding a previously serialized cache.  When it does not exist
        the cache is created from an empty byte string.
    """
    if cache_path.exists():
        blob = cache_path.read_bytes()
    else:
        blob = b""
    timing_cache = config.create_timing_cache(blob)
    config.set_timing_cache(timing_cache, ignore_mismatch=True)


def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    fp16: bool = True,
    int8: bool = False,
    best: bool = False,
) -> Path:
    """Build a TensorRT engine from an ONNX file and write it to disk.

    The precision flags are evaluated in priority order and only the first
    one that is set takes effect:

    * ``best`` enables both FP16 and INT8,
    * otherwise ``int8`` enables INT8 only (``fp16`` is then ignored),
    * otherwise ``fp16`` enables FP16,
    * otherwise no precision flag is set.

    A timing cache named ``timing.cache`` next to the engine file is loaded
    before the build (if present) and rewritten afterwards.

    :param onnx_path:
        Path to the ONNX model to compile.
    :param engine_path:
        Destination of the serialized engine.  Its parent directory is
        created if it does not exist.
    :param fp16:
        Enable FP16 when neither ``best`` nor ``int8`` is set.
    :param int8:
        Enable INT8 when ``best`` is not set.
    :param best:
        Enable both FP16 and INT8.
    :returns:
        ``engine_path``.
    :raises RuntimeError:
        If ONNX parsing or the engine build fails.
    """
    builder, network = _parse_onnx(onnx_path)

    config = builder.create_builder_config()
    config.builder_optimization_level = 5

    # Create the output directory up front: discovering that it is missing
    # only when writing the result would throw away the whole build.
    engine_path.parent.mkdir(parents=True, exist_ok=True)

    cache_path = engine_path.parent / "timing.cache"
    _load_timing_cache(config, cache_path)

    if best:
        config.set_flag(trt.BuilderFlag.FP16)
        config.set_flag(trt.BuilderFlag.INT8)
    elif int8:
        config.set_flag(trt.BuilderFlag.INT8)
    elif fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    print("  Building TensorRT engine ...")
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")

    # Persist the (possibly extended) timing cache for later builds.
    cache_path.write_bytes(
        bytes(memoryview(config.get_timing_cache().serialize()))
    )

    # Copy the engine out of the TensorRT buffer once and reuse the copy
    # for both the file write and the size report.
    engine_bytes = bytes(serialized)
    engine_path.write_bytes(engine_bytes)
    print(f"  Engine: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)")
    return engine_path

Latency measurement#

Warm up the engine, then time individual inference calls using CUDA events to get accurate GPU-side latency without host overhead.

def _prepare_trt_context(
    engine_path: Path,
) -> tuple[
    trt.IExecutionContext, torch.cuda.Stream, torch.Tensor, torch.Tensor
]:
    """Load engine, allocate GPU buffers, return ready context.

    The engine's tensor 0 is treated as the input and tensor 1 as the
    output.  Both buffers are allocated as ``float32`` CUDA tensors with
    the shapes reported by the engine, so this helper assumes a
    single-input / single-output engine with static shapes and float32
    I/O -- TODO confirm for engines built elsewhere.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :returns:
        ``(context, stream, input_buffer, output_buffer)``.  The buffers'
        device addresses are registered with the context, so callers must
        keep the returned tensors referenced while the context is in use.
    :raises RuntimeError:
        If the engine cannot be deserialized or no execution context can be
        created.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
    # TensorRT reports failure by returning None; fail with a clear message
    # instead of an AttributeError on the next line.
    if engine is None:
        raise RuntimeError(
            f"Failed to deserialize TensorRT engine: {engine_path}"
        )
    context = engine.create_execution_context()
    if context is None:
        raise RuntimeError(
            f"Failed to create TensorRT execution context: {engine_path}"
        )

    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    input_shape = tuple(engine.get_tensor_shape(input_name))
    output_shape = tuple(engine.get_tensor_shape(output_name))

    inp = torch.empty(input_shape, dtype=torch.float32, device="cuda")
    out = torch.empty(output_shape, dtype=torch.float32, device="cuda")
    context.set_tensor_address(input_name, inp.data_ptr())
    context.set_tensor_address(output_name, out.data_ptr())

    return context, torch.cuda.Stream(), inp, out


def measure_latency(
    engine_path: Path,
    warmup_ms: int = 1000,
    duration: int = 30,
) -> dict[str, float]:
    """Measure inference latency using the TensorRT Python API.

    Each timed iteration is bracketed by a pair of CUDA events recorded on
    the inference stream; the per-iteration latency is the elapsed time
    between them.  At least one timed iteration is always executed, so the
    statistics are defined even for a very short (or zero) ``duration``.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param warmup_ms:
        Warm-up time in milliseconds before recording.
    :param duration:
        Benchmarking duration in seconds.
    :returns:
        Dict with mean/median/p99 latency (ms) and throughput (qps).
        Throughput is the number of samples divided by the summed measured
        latencies.  With a single sample, p99 equals that sample.
    """
    # ``_inp`` / ``_out`` are unused here but must stay referenced: their
    # device addresses are registered with the execution context.
    context, stream, _inp, _out = _prepare_trt_context(engine_path)

    # Warm up.
    print("  Warming up ...")
    warmup_until = time.perf_counter() + warmup_ms / 1000
    while time.perf_counter() < warmup_until:
        context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()

    # Timed iterations.
    print(f"  Benchmarking ({duration}s) ...")
    timings: list[float] = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    deadline = time.perf_counter() + duration
    # The deadline is checked *after* each iteration so that at least one
    # sample is recorded; otherwise the loop matches a plain
    # ``while now < deadline`` loop.
    while True:
        start.record(stream)
        context.execute_async_v3(stream.cuda_stream)
        end.record(stream)
        end.synchronize()
        timings.append(start.elapsed_time(end))
        if time.perf_counter() >= deadline:
            break

    mean_ms = statistics.mean(timings)
    median_ms = statistics.median(timings)
    # ``statistics.quantiles`` needs at least two data points.
    if len(timings) > 1:
        p99_ms = statistics.quantiles(timings, n=100)[-1]
    else:
        p99_ms = timings[0]
    throughput = len(timings) / (sum(timings) / 1000)

    print(
        f"  Latency: mean={mean_ms:.3f} ms, "
        f"median={median_ms:.3f} ms, "
        f"p99={p99_ms:.3f} ms, throughput={throughput:.1f} qps"
    )
    return {
        "mean_latency_ms": mean_ms,
        "median_latency_ms": median_ms,
        "p99_latency_ms": p99_ms,
        "throughput_qps": throughput,
    }

Inference and accuracy evaluation#

`TRTInferencer` wraps an engine for batch inference. `evaluate_trt` runs it over a validation loader and computes Top-1 / Top-5 accuracy.

class TRTInferencer:
    """TensorRT engine wrapper for batch inference.

    The engine's tensor 0 is treated as the input and tensor 1 as the
    output, and the output buffer is allocated as ``float32`` -- i.e. a
    single-input / single-output engine with float32 output is assumed
    (TODO confirm for engines built elsewhere).
    """

    def __init__(self, engine_path: Path):
        """Initialize the TensorRT inferencer.

        :param engine_path:
            Path to a serialized TensorRT engine.
        :raises RuntimeError:
            If the engine cannot be deserialized.
        """
        runtime = trt.Runtime(TRT_LOGGER)
        self.engine = runtime.deserialize_cuda_engine(
            engine_path.read_bytes()
        )
        # TensorRT reports failure by returning None; fail with a clear
        # message instead of an AttributeError further down.
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine: {engine_path}"
            )
        self.context = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()
        # Tensor names are fixed for an engine; look them up once rather
        # than on every ``infer`` call.
        self._input_name = self.engine.get_tensor_name(0)
        self._output_name = self.engine.get_tensor_name(1)

    def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference on a single batch.

        :param input_tensor:
            NCHW image batch on CUDA.

        :returns:
            Model output tensor.
        """
        # TensorRT reads the raw device pointer, so the data must be laid
        # out contiguously (no-op when it already is).
        input_tensor = input_tensor.contiguous()

        self.context.set_input_shape(
            self._input_name, tuple(input_tensor.shape)
        )
        output_shape = tuple(
            self.context.get_tensor_shape(self._output_name)
        )
        output_tensor = torch.empty(
            output_shape, dtype=torch.float32, device="cuda"
        )
        self.context.set_tensor_address(
            self._input_name, input_tensor.data_ptr()
        )
        self.context.set_tensor_address(
            self._output_name, output_tensor.data_ptr()
        )

        # Inference runs on a private stream; order it after any work
        # still queued on the caller's current stream that produces
        # ``input_tensor``.
        self.stream.wait_stream(torch.cuda.current_stream())
        self.context.execute_async_v3(self.stream.cuda_stream)
        # Block until the result is complete so the caller can use it
        # from any stream.
        self.stream.synchronize()
        return output_tensor


def evaluate_trt(
    engine_path: Path, val_loader: ImageLoader
) -> dict[str, float]:
    """Compute Top-1 and Top-5 accuracy of an engine on a validation loader.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param val_loader:
        Loader yielding ``(images, targets)`` batches.
    :returns:
        Dict with ``top1`` and ``top5`` accuracy in percent.
    """
    runner = TRTInferencer(engine_path)
    seen = 0.0
    hits_top1 = 0.0
    hits_top5 = 0.0

    for images, labels in val_loader:
        images_gpu = images.cuda()
        scores = runner.infer(images_gpu).cpu().numpy()
        labels_np = labels.numpy()
        labels_col = labels_np[:, None]

        # Class indices sorted by descending score; keep the best five.
        ranked = np.argsort(scores, axis=1)[:, ::-1]
        best_five = ranked[:, :5]

        hits_top1 += (best_five[:, :1] == labels_col).sum()
        hits_top5 += (best_five == labels_col).any(axis=1).sum()
        seen += len(labels_np)

    top1_pct = hits_top1 / seen * 100
    top5_pct = hits_top5 / seen * 100
    return {
        "top1": top1_pct,
        "top5": top5_pct,
    }

Gallery generated by Sphinx-Gallery