{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# TensorRT inference and evaluation helpers\n\nShared utilities for building TensorRT engines, running inference, and\nmeasuring accuracy and latency. The other tutorials import from this module so\nthey can focus on the Embedl Deploy workflow rather than TensorRT boilerplate.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>This tutorial requires an NVIDIA GPU with TensorRT 10.x installed.</p></div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Constants\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import statistics\nimport sys\nimport time\nfrom pathlib import Path\n\nimport numpy as np  # type: ignore[import-not-found]\nimport tensorrt as trt\nimport torch\nimport torchvision\nimport torchvision.transforms as T\nfrom torch.utils.data import DataLoader\n\n# ConvNeXt-Base can exceed default recursion limit during deepcopy.\nsys.setrecursionlimit(5000)\n\n# Per-channel mean and standard deviation handed to T.Normalize by the\n# transform helpers.\nIMAGENET_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_STD = (0.229, 0.224, 0.225)\n\n# Archive location of the ImageNette dataset (not fetched by this cell).\nIMAGENETTE_URL = (\n    \"https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz\"\n)\n# make_loaders reads the \"train\" and \"val\" sub-folders of IMAGENETTE_DIR.\nDATA_DIR = Path(\"artifacts/data\")\nIMAGENETTE_DIR = DATA_DIR / \"imagenette2-320\"\n\n# Lookup table applied to every target by make_loaders: entry i replaces\n# ImageFolder class index i. Presumably the ImageNet-1k ids of the ten\n# ImageNette classes in sorted-folder order -- TODO confirm.\nIMAGENETTE_TO_IMAGENET = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]\n\n# DataLoader settings used by make_loaders.\nBATCH_SIZE = 1\nNUM_WORKERS = 8\n# Not referenced by the helpers in this notebook; presumably read by the\n# tutorials that import this module -- TODO confirm.\nCALIBRATION_BATCHES = 32\n\n# Logger (WARNING severity) passed to the TensorRT builder, parser and runtime.\nTRT_LOGGER = trt.Logger(trt.Logger.WARNING)\n\n# Type alias for loaders whose items are (tensor, tensor) pairs.\nImageLoader = DataLoader[tuple[torch.Tensor, torch.Tensor]]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Dataset helpers\n\nWe use [ImageNette](https://github.com/fastai/imagenette) \u2014 a\n10-class subset of ImageNet \u2014 for fast evaluation and calibration.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _val_transform(crop_size: int = 224, resize_size: int = 256) -> T.Compose:\n    \"\"\"Build the deterministic preprocessing used for validation.\n\n    :param crop_size:\n        Side length of the center crop.\n    :param resize_size:\n        Size passed to ``T.Resize`` before cropping.\n    :returns:\n        Resize, center crop, tensor conversion and normalization with\n        ``IMAGENET_MEAN`` / ``IMAGENET_STD``.\n    \"\"\"\n    return T.Compose(\n        [\n            T.Resize(resize_size),\n            T.CenterCrop(crop_size),\n            T.ToTensor(),\n            T.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n        ]\n    )\n\n\ndef _train_transform(crop_size: int = 224) -> T.Compose:\n    \"\"\"Build the randomized preprocessing used for training.\n\n    :param crop_size:\n        Output side length of the random resized crop.\n    :returns:\n        Random resized crop, random horizontal flip, tensor conversion and\n        normalization with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.\n    \"\"\"\n    return T.Compose(\n        [\n            T.RandomResizedCrop(crop_size),\n            T.RandomHorizontalFlip(),\n            T.ToTensor(),\n            T.Normalize(IMAGENET_MEAN, IMAGENET_STD),\n        ]\n    )\n\n\ndef _remap_target(class_index: int) -> int:\n    \"\"\"Translate an ``ImageFolder`` class index via ``IMAGENETTE_TO_IMAGENET``.\n\n    Kept at module level instead of nested inside ``make_loaders``: a nested\n    function cannot be pickled, which breaks ``DataLoader`` workers on\n    platforms that start them with ``spawn``.\n    \"\"\"\n    return IMAGENETTE_TO_IMAGENET[class_index]\n\n\ndef make_loaders(\n    crop_size: int = 224, resize_size: int = 256\n) -> tuple[ImageLoader, ImageLoader]:\n    \"\"\"Create train and validation data loaders for ImageNette.\n\n    :param crop_size:\n        Side length of the images produced by both pipelines.\n    :param resize_size:\n        Resize applied before the validation center crop.\n    :returns:\n        ``(train_loader, val_loader)``; targets are remapped through\n        ``IMAGENETTE_TO_IMAGENET``.\n    :raises FileNotFoundError:\n        If the ``train`` or ``val`` folder is missing under\n        ``IMAGENETTE_DIR``.\n    \"\"\"\n    train_dir = IMAGENETTE_DIR / \"train\"\n    val_dir = IMAGENETTE_DIR / \"val\"\n    # Fail early with an actionable message instead of a bare path error.\n    for split_dir in (train_dir, val_dir):\n        if not split_dir.is_dir():\n            raise FileNotFoundError(\n                f\"ImageNette folder not found: {split_dir}. Download and \"\n                f\"extract {IMAGENETTE_URL} into {DATA_DIR}.\"\n            )\n\n    train_ds = torchvision.datasets.ImageFolder(\n        str(train_dir),\n        transform=_train_transform(crop_size),\n        target_transform=_remap_target,\n    )\n    val_ds = torchvision.datasets.ImageFolder(\n        str(val_dir),\n        transform=_val_transform(crop_size, resize_size),\n        target_transform=_remap_target,\n    )\n    train = DataLoader(\n        train_ds,\n        batch_size=BATCH_SIZE,\n        shuffle=True,\n        num_workers=NUM_WORKERS,\n        pin_memory=True,\n        drop_last=True,\n    )\n    val = DataLoader(\n        val_ds,\n        batch_size=BATCH_SIZE,\n        shuffle=False,\n        num_workers=NUM_WORKERS,\n        pin_memory=True,\n    )\n    print(f\"Train: {len(train_ds)} images, Val: {len(val_ds)} images\")\n    return train, val"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## TensorRT engine build\n\nParse an ONNX model and compile it into a TensorRT engine with\noptional FP16 / INT8 precision flags.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _parse_onnx(onnx_path: Path) -> tuple[trt.Builder, trt.INetworkDefinition]:\n    \"\"\"Parse an ONNX model into a TensorRT network.\n\n    :param onnx_path:\n        Path to the ONNX model file.\n    :returns:\n        The builder and the network definition populated by the parser.\n    :raises RuntimeError:\n        If parsing fails; every parser error is printed first.\n    \"\"\"\n    builder = trt.Builder(TRT_LOGGER)\n    network = builder.create_network()\n    parser = trt.OnnxParser(network, TRT_LOGGER)\n    # The second argument is the model's location on disk, which the parser\n    # needs to find weights stored in external data files.\n    if not parser.parse(onnx_path.read_bytes(), str(onnx_path)):\n        for err_idx in range(parser.num_errors):\n            print(f\"  ONNX parse error: {parser.get_error(err_idx)}\")\n        raise RuntimeError(\"Failed to parse ONNX model\")\n    return builder, network\n\n\ndef _load_timing_cache(config: trt.IBuilderConfig, cache_path: Path) -> None:\n    \"\"\"Load a timing cache from disk into a builder config.\n\n    :param config:\n        Builder config that receives the cache.\n    :param cache_path:\n        Cache file to read; when it does not exist the cache is created from\n        an empty buffer.\n    \"\"\"\n    data = cache_path.read_bytes() if cache_path.exists() else b\"\"\n    cache = config.create_timing_cache(data)\n    config.set_timing_cache(cache, ignore_mismatch=True)\n\n\ndef build_trt_engine(\n    onnx_path: Path,\n    engine_path: Path,\n    fp16: bool = True,\n    int8: bool = False,\n    best: bool = False,\n) -> Path:\n    \"\"\"Build a TensorRT engine from an ONNX file.\n\n    Only one precision branch is applied, in this order of precedence:\n    ``best`` (FP16 and INT8 flags), then ``int8`` (INT8 flag only), then\n    ``fp16`` (FP16 flag only). With all three false no precision flag is set.\n\n    :param onnx_path:\n        Path to the ONNX model file.\n    :param engine_path:\n        Destination of the serialized engine. Its directory is created if\n        needed, and a ``timing.cache`` file in that directory is read before\n        and rewritten after the build.\n    :param fp16:\n        Set the FP16 builder flag (ignored when ``int8`` or ``best`` is set).\n    :param int8:\n        Set the INT8 builder flag (ignored when ``best`` is set).\n    :param best:\n        Set both the FP16 and the INT8 builder flags.\n    :returns:\n        ``engine_path``.\n    :raises RuntimeError:\n        If ONNX parsing or the engine build fails.\n    \"\"\"\n    builder, network = _parse_onnx(onnx_path)\n\n    config = builder.create_builder_config()\n    config.builder_optimization_level = 5\n\n    # Create the output directory up front so that a missing folder cannot\n    # make the writes below fail after the slow build has already finished.\n    engine_path.parent.mkdir(parents=True, exist_ok=True)\n    cache_path = engine_path.parent / \"timing.cache\"\n    _load_timing_cache(config, cache_path)\n\n    if best:\n        config.set_flag(trt.BuilderFlag.FP16)\n        config.set_flag(trt.BuilderFlag.INT8)\n    elif int8:\n        config.set_flag(trt.BuilderFlag.INT8)\n    elif fp16:\n        config.set_flag(trt.BuilderFlag.FP16)\n\n    print(\"  Building TensorRT engine ...\")\n    serialized = builder.build_serialized_network(network, config)\n    if serialized is None:\n        raise RuntimeError(\"TensorRT engine build failed\")\n\n    cache_path.write_bytes(\n        bytes(memoryview(config.get_timing_cache().serialize()))\n    )\n    # Copy the engine buffer once and reuse it for the file and the report.\n    engine_bytes = bytes(serialized)\n    engine_path.write_bytes(engine_bytes)\n    print(f\"  Engine: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)\")\n    return engine_path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Latency measurement\n\nWarm up the engine, then time individual inference calls using CUDA\nevents to get accurate GPU-side latency without host overhead.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _prepare_trt_context(\n    engine_path: Path,\n) -> tuple[\n    trt.IExecutionContext, torch.cuda.Stream, torch.Tensor, torch.Tensor\n]:\n    \"\"\"Load engine, allocate GPU buffers, return ready context.\n\n    The tensors at engine indices 0 and 1 are treated as the input and the\n    output, and both buffers are allocated as float32 with the shapes the\n    engine reports.\n\n    :param engine_path:\n        Path to a serialized TensorRT engine.\n    :returns:\n        The execution context with both tensor addresses bound, a new CUDA\n        stream, the input buffer and the output buffer. The context only\n        stores the raw device addresses, so callers must keep both buffers\n        referenced while they use the context.\n    :raises RuntimeError:\n        If the engine cannot be deserialized.\n    \"\"\"\n    runtime = trt.Runtime(TRT_LOGGER)\n    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())\n    if engine is None:\n        raise RuntimeError(\n            f\"Failed to deserialize TensorRT engine: {engine_path}\"\n        )\n    context = engine.create_execution_context()\n\n    input_name = engine.get_tensor_name(0)\n    output_name = engine.get_tensor_name(1)\n    input_shape = tuple(engine.get_tensor_shape(input_name))\n    output_shape = tuple(engine.get_tensor_shape(output_name))\n\n    inp = torch.empty(input_shape, dtype=torch.float32, device=\"cuda\")\n    out = torch.empty(output_shape, dtype=torch.float32, device=\"cuda\")\n    context.set_tensor_address(input_name, inp.data_ptr())\n    context.set_tensor_address(output_name, out.data_ptr())\n\n    return context, torch.cuda.Stream(), inp, out\n\n\ndef measure_latency(\n    engine_path: Path,\n    warmup_ms: int = 1000,\n    duration: int = 30,\n) -> dict[str, float]:\n    \"\"\"Measure inference latency using the TensorRT Python API.\n\n    :param engine_path:\n        Path to a serialized TensorRT engine.\n    :param warmup_ms:\n        Warm-up time in milliseconds before recording.\n    :param duration:\n        Benchmarking duration in seconds. At least one inference is always\n        timed, even when this is zero.\n    :returns:\n        Dict with mean/median/p99 latency (ms) and throughput (qps).\n    :raises RuntimeError:\n        If the engine cannot be deserialized or an inference cannot be\n        enqueued.\n    \"\"\"\n    # _inp and _out are never read here, but they own the device memory the\n    # context points at and therefore have to stay referenced.\n    context, stream, _inp, _out = _prepare_trt_context(engine_path)\n\n    # Warm up.\n    print(\"  Warming up ...\")\n    warmup_until = time.perf_counter() + warmup_ms / 1000\n    while time.perf_counter() < warmup_until:\n        if not context.execute_async_v3(stream.cuda_stream):\n            raise RuntimeError(\"TensorRT failed to enqueue inference\")\n    stream.synchronize()\n\n    # Timed iterations.\n    print(f\"  Benchmarking ({duration}s) ...\")\n    timings: list[float] = []\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    deadline = time.perf_counter() + duration\n    # Always record one sample so the statistics below are defined even for\n    # a zero or very short duration.\n    while not timings or time.perf_counter() < deadline:\n        start.record(stream)\n        enqueued = context.execute_async_v3(stream.cuda_stream)\n        end.record(stream)\n        end.synchronize()\n        if not enqueued:\n            raise RuntimeError(\"TensorRT failed to enqueue inference\")\n        timings.append(start.elapsed_time(end))\n\n    mean_ms = statistics.mean(timings)\n    median_ms = statistics.median(timings)\n    # statistics.quantiles needs at least two data points.\n    if len(timings) >= 2:\n        p99_ms = statistics.quantiles(timings, n=100)[-1]\n    else:\n        p99_ms = timings[0]\n    throughput = len(timings) / (sum(timings) / 1000)\n\n    print(\n        f\"  Latency: mean={mean_ms:.3f} ms, \"\n        f\"median={median_ms:.3f} ms, \"\n        f\"p99={p99_ms:.3f} ms, throughput={throughput:.1f} qps\"\n    )\n    return {\n        \"mean_latency_ms\": mean_ms,\n        \"median_latency_ms\": median_ms,\n        \"p99_latency_ms\": p99_ms,\n        \"throughput_qps\": throughput,\n    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inference and accuracy evaluation\n\n``TRTInferencer`` wraps an engine for batch inference.\n``evaluate_trt`` runs it over a validation loader and computes\nTop-1 / Top-5 accuracy.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class TRTInferencer:\n    \"\"\"TensorRT engine wrapper for batch inference.\n\n    The tensors at engine indices 0 and 1 are treated as the input and the\n    output; the output buffer is always allocated as float32.\n    \"\"\"\n\n    def __init__(self, engine_path: Path):\n        \"\"\"Initialize the TensorRT inferencer.\n\n        :param engine_path:\n            Path to a serialized TensorRT engine.\n        :raises RuntimeError:\n            If the engine cannot be deserialized.\n        \"\"\"\n        # Share the module-wide logger with the other helpers instead of\n        # handing TensorRT a second logger instance.\n        runtime = trt.Runtime(TRT_LOGGER)\n        self.engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())\n        if self.engine is None:\n            raise RuntimeError(\n                f\"Failed to deserialize TensorRT engine: {engine_path}\"\n            )\n        self.context = self.engine.create_execution_context()\n        self.stream = torch.cuda.Stream()\n        # Looked up once here rather than on every infer() call.\n        self._input_name = self.engine.get_tensor_name(0)\n        self._output_name = self.engine.get_tensor_name(1)\n\n    def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:\n        \"\"\"Run inference on a single batch.\n\n        :param input_tensor:\n            NCHW image batch on CUDA.\n\n        :returns:\n            Model output tensor.\n        :raises ValueError:\n            If ``input_tensor`` is not on CUDA or the engine rejects its\n            shape.\n        :raises RuntimeError:\n            If the inference cannot be enqueued.\n        \"\"\"\n        if not input_tensor.is_cuda:\n            raise ValueError(\"input_tensor must be a CUDA tensor\")\n        # TensorRT only receives the raw address, so the data has to be laid\n        # out densely; this is a no-op for tensors that already are.\n        input_tensor = input_tensor.contiguous()\n\n        input_shape = tuple(input_tensor.shape)\n        if not self.context.set_input_shape(self._input_name, input_shape):\n            raise ValueError(\n                f\"Engine rejected input shape {input_shape} \"\n                f\"for tensor {self._input_name!r}\"\n            )\n        output_shape = tuple(self.context.get_tensor_shape(self._output_name))\n        output_tensor = torch.empty(\n            output_shape, dtype=torch.float32, device=\"cuda\"\n        )\n        self.context.set_tensor_address(\n            self._input_name, input_tensor.data_ptr()\n        )\n        self.context.set_tensor_address(\n            self._output_name, output_tensor.data_ptr()\n        )\n        # The input may still be in flight on the caller's stream; make the\n        # private stream wait for it before TensorRT reads the buffer.\n        self.stream.wait_stream(torch.cuda.current_stream())\n        if not self.context.execute_async_v3(self.stream.cuda_stream):\n            raise RuntimeError(\"TensorRT failed to enqueue inference\")\n        self.stream.synchronize()\n        return output_tensor\n\n\ndef evaluate_trt(\n    engine_path: Path, val_loader: ImageLoader\n) -> dict[str, float]:\n    \"\"\"Top-1/Top-5 accuracy on the validation set.\n\n    :param engine_path:\n        Path to a serialized TensorRT engine.\n    :param val_loader:\n        Loader yielding ``(batch, targets)`` pairs; each batch is moved to\n        CUDA before inference.\n    :returns:\n        Dict with ``top1`` and ``top5`` accuracy in percent, as Python floats.\n    :raises ZeroDivisionError:\n        If ``val_loader`` yields no samples.\n    \"\"\"\n    inferencer = TRTInferencer(engine_path)\n    top1_hits = top5_hits = total = 0\n\n    for batch, targets in val_loader:\n        output_np = inferencer.infer(batch.cuda()).cpu().numpy()\n        targets_np = targets.numpy()\n        # Indices of the five highest scores per sample, best first.\n        top5 = np.argsort(output_np, axis=1)[:, ::-1][:, :5]\n        hits = top5 == targets_np[:, None]\n        # int() keeps the counters, and so the results, plain Python numbers.\n        top1_hits += int(hits[:, 0].sum())\n        top5_hits += int(hits.any(axis=1).sum())\n        total += len(targets_np)\n\n    return {\n        \"top1\": top1_hits / total * 100,\n        \"top5\": top5_hits / total * 100,\n    }"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}