# Copyright (C) 2026 Embedl AB

"""
Deploying vision models with Embedl Deploy
==========================================

In this tutorial, we demonstrate an end-to-end workflow for deploying vision
models using :mod:`embedl_deploy` for TensorRT. The models come from the
torchvision library and cover a range of architectures, such as ConvNeXt
Tiny, ResNet-50, and ViT-B/16; each run benchmarks a single model, selected
with ``MODEL_NAME``.  We apply post-training static quantization
(PTQ) using the built-in TensorRT pattern set.  Embedl Deploy
automatically handles special cases such as depthwise convolutions,
which are memory-bound and are therefore left in FP16 to avoid
quantization overhead.

After transformation and quantization, deploying a model consists of the
following steps:

1. Export the PyTorch model to ONNX and simplify it with ``onnxsim``.
2. Build a TensorRT engine (FP16 for the baseline; ``best`` mode, which
   enables both FP16 and INT8, for QDQ models).
3. Run TensorRT inference to measure Top-1/Top-5 accuracy on the
   validation set.
4. Measure latency with the TensorRT Python API.

.. note::

   This tutorial requires an NVIDIA GPU with TensorRT 10.x installed.
"""

# sphinx_gallery_start_ignore
# pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports,useless-suppression
# sphinx_gallery_end_ignore

# %%
# Constants
# ---------

import sys
from pathlib import Path

import tensorrt as trt

# ConvNeXt-Base can exceed default recursion limit during deepcopy.
# Use ``max`` so that a limit the environment has already raised further is
# never lowered by this script.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 5000))

# Per-channel normalization statistics passed to ``T.Normalize``.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# ImageNette archive location and the local directory it is extracted into.
IMAGENETTE_URL = (
    "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
)
DATA_DIR = Path("artifacts/data")
IMAGENETTE_DIR = DATA_DIR / "imagenette2-320"

# ImageFolder label (index into the 10 ImageNette classes) -> ImageNet class
# index. Applied as ``target_transform`` so targets can be compared directly
# with the predictions of the pretrained ImageNet classifiers.
IMAGENETTE_TO_IMAGENET = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]

# NOTE(review): export_and_simplify traces with a fixed batch of 1 and
# declares no dynamic axes, so the engines are presumably built for batch
# size 1 only; confirm before raising BATCH_SIZE.
BATCH_SIZE = 1
NUM_WORKERS = 8
# Number of training batches run through the model during PTQ calibration.
CALIBRATION_BATCHES = 32

# TensorRT logger used by the builder/runtime helpers; reports messages of
# severity WARNING and above.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# %%
# Pattern selection
# -----------------
#
# ``TENSORRT_PATTERNS`` is the default pattern set shipped with Embedl
# Deploy.  It includes structural conversions (e.g. decomposing
# ``MultiheadAttention``), operator fusions (Conv-BN-ReLU, Linear-ReLU,
# LayerNorm, etc.), and **automatic handling of depthwise convolutions**.
# Depthwise convolutions are memory-bound, so quantizing them adds
# TensorRT reformatting overhead that exceeds the compute gain from
# INT8 — Embedl Deploy detects these and keeps them in FP16
# automatically.

# %%
# Dataset helpers
# ---------------
#
# We use `ImageNette <https://github.com/fastai/imagenette>`_ — a
# 10-class subset of ImageNet — for fast evaluation.
import tarfile
import urllib.request

import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from embedl_deploy.tensorrt import TENSORRT_PATTERNS


def download_imagenette() -> None:
    """Download and extract ImageNette if not already present.

    The archive at ``IMAGENETTE_URL`` is downloaded into ``DATA_DIR``,
    extracted there and then deleted. Nothing happens when
    ``IMAGENETTE_DIR`` already exists.
    """
    if IMAGENETTE_DIR.exists():
        print(f"ImageNette already present at {IMAGENETTE_DIR}")
        return
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    tgz_path = DATA_DIR / "imagenette2-320.tgz"
    print(f"Downloading ImageNette to {tgz_path} ...")
    urllib.request.urlretrieve(IMAGENETTE_URL, str(tgz_path))  # noqa: S310
    print("Extracting ...")
    with tarfile.open(tgz_path) as tar:
        # The archive comes from the network, so refuse members that would
        # land outside DATA_DIR (absolute paths, "..", escaping links,
        # special files). The "data" filter implements exactly that and is
        # available on Python 3.12+ and on patched 3.8-3.11 releases; fall
        # back to a plain extraction on interpreters that lack it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(DATA_DIR, filter="data")
        else:
            tar.extractall(DATA_DIR)  # noqa: S202
    tgz_path.unlink()
    print("Done.")


def _val_transform(crop_size: int = 224, resize_size: int = 256) -> T.Compose:
    """Return the deterministic evaluation preprocessing pipeline.

    Resize to ``resize_size``, centre-crop to ``crop_size``, convert to a
    tensor and normalize with the ImageNet statistics.
    """
    steps = [
        T.Resize(resize_size),
        T.CenterCrop(crop_size),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return T.Compose(steps)


def _train_transform(crop_size: int = 224) -> T.Compose:
    """Return the augmenting training preprocessing pipeline.

    Random resized crop to ``crop_size`` and random horizontal flip,
    followed by tensor conversion and ImageNet normalization.
    """
    augment = [T.RandomResizedCrop(crop_size), T.RandomHorizontalFlip()]
    to_model_input = [T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]
    return T.Compose(augment + to_model_input)


def make_loaders(
    crop_size: int = 224, resize_size: int = 256
) -> tuple[DataLoader, DataLoader]:
    """Build the ImageNette train and validation data loaders.

    Folder labels are remapped to ImageNet class indices through
    ``IMAGENETTE_TO_IMAGENET``. The train loader shuffles, augments and
    drops the last incomplete batch; the validation loader keeps a fixed
    order and deterministic preprocessing.

    :param crop_size:
        Side length of the square crop fed to the model.
    :param resize_size:
        Resize size applied before the validation centre crop.
    :returns:
        ``(train_loader, val_loader)``.
    """

    def to_imagenet(label):
        return IMAGENETTE_TO_IMAGENET[label]

    common = {
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
        "pin_memory": True,
    }
    train_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "train"),
        transform=_train_transform(crop_size),
        target_transform=to_imagenet,
    )
    val_ds = torchvision.datasets.ImageFolder(
        str(IMAGENETTE_DIR / "val"),
        transform=_val_transform(crop_size, resize_size),
        target_transform=to_imagenet,
    )
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_ds, shuffle=False, **common)
    print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    return train_loader, val_loader


# %%
# ONNX export
# ------------
#
# Export the model to ONNX and simplify with ``onnxsim``.

import onnx
import onnxsim
import torch


def export_and_simplify(
    model: nn.Module, onnx_path: Path, crop_size: int = 224
) -> Path:
    """Export *model* to ONNX and try to simplify it with onnxsim.

    The model is moved to CPU and switched to eval mode, then traced with a
    random ``(1, 3, crop_size, crop_size)`` input (legacy, non-dynamo
    exporter, opset 20).

    :returns:
        Path of the simplified model (``<stem>_sim.onnx`` next to
        ``onnx_path``), or ``onnx_path`` itself if simplification failed.
    """
    cpu_model = model.cpu().eval()
    dummy_input = torch.randn(1, 3, crop_size, crop_size)
    torch.onnx.export(
        cpu_model,
        (dummy_input,),
        str(onnx_path),
        opset_version=20,
        input_names=["input"],
        output_names=["output"],
        dynamo=False,
    )
    simplified, ok = onnxsim.simplify(onnx.load(str(onnx_path)))
    if not ok:
        print("  onnxsim failed; using raw export.")
        return onnx_path
    simp_path = onnx_path.with_name(f"{onnx_path.stem}_sim.onnx")
    onnx.save(simplified, str(simp_path))
    print(f"  Simplified ONNX: {simp_path}")
    return simp_path


# %%
# TensorRT engine build and inference
# ------------------------------------
#
# Build a TensorRT engine from an ONNX file and run inference.

import statistics
import time


def _parse_onnx(onnx_path: Path) -> tuple[trt.Builder, trt.INetworkDefinition]:
    """Read an ONNX file and populate a fresh TensorRT network from it.

    :param onnx_path:
        ONNX model to parse.
    :returns:
        ``(builder, network)``; the builder is returned so the caller can
        create a builder config and build ``network``.
    :raises RuntimeError:
        If the ONNX parser rejects the model (every parser error is printed
        first).
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    # NOTE(review): TensorRT's developer guide states that the ONNX parser
    # owns the weight memory the network refers to. Confirm the Python
    # bindings keep ``onnx_parser`` alive after this function returns;
    # otherwise it should be returned alongside the network.
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, "rb") as model_file:
        model_bytes = model_file.read()
    if onnx_parser.parse(model_bytes):
        return builder, network
    for idx in range(onnx_parser.num_errors):
        print(f"  ONNX parse error: {onnx_parser.get_error(idx)}")
    raise RuntimeError("Failed to parse ONNX model")


def _load_timing_cache(config: trt.IBuilderConfig, cache_path: Path) -> None:
    """Attach a timing cache to *config*, seeded from *cache_path* if present.

    A missing file yields an empty cache. A cache that TensorRT reports as
    mismatched is still accepted (``ignore_mismatch=True``).
    """
    seed = b""
    if cache_path.exists():
        seed = cache_path.read_bytes()
    config.set_timing_cache(
        config.create_timing_cache(seed), ignore_mismatch=True
    )


def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    fp16: bool = True,
    int8: bool = False,
    best: bool = False,
) -> Path:
    """Build a TensorRT engine from an ONNX file.

    The precision flags are checked in priority order and only the first
    one that is set takes effect:

    * ``best`` -- enable both FP16 and INT8;
    * ``int8`` -- enable INT8 only (FP16 is *not* enabled, even when
      ``fp16`` is left at its default of ``True``);
    * ``fp16`` -- enable FP16 only.

    With all three ``False`` no reduced-precision flag is set.

    A timing cache named ``timing.cache`` next to ``engine_path`` is loaded
    before the build (when present) and rewritten afterwards, so later
    builds can reuse it.

    :param onnx_path:
        ONNX model to build from.
    :param engine_path:
        Destination of the serialized engine; its parent directory is
        created if missing.
    :returns:
        ``engine_path``.
    :raises RuntimeError:
        If ONNX parsing or the engine build fails.
    """
    builder, network = _parse_onnx(onnx_path)

    config = builder.create_builder_config()
    config.builder_optimization_level = 5

    # Both the timing cache and the engine are written to this directory.
    engine_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path = engine_path.parent / "timing.cache"
    _load_timing_cache(config, cache_path)

    if best:
        config.set_flag(trt.BuilderFlag.FP16)
        config.set_flag(trt.BuilderFlag.INT8)
    elif int8:
        config.set_flag(trt.BuilderFlag.INT8)
    elif fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    print("  Building TensorRT engine ...")
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")

    cache_path.write_bytes(
        bytes(memoryview(config.get_timing_cache().serialize()))
    )
    # Materialize the engine bytes once and reuse them for both the write
    # and the size report, instead of copying the whole engine twice.
    engine_bytes = bytes(serialized)
    engine_path.write_bytes(engine_bytes)
    print(f"  Engine: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)")
    return engine_path


def _prepare_trt_context(
    engine_path: Path,
) -> tuple[trt.IExecutionContext, torch.cuda.Stream]:
    """Load engine, allocate GPU buffers, return ready context.

    The engine's first I/O tensor is treated as the input and the second as
    the output. Both buffers are float32 CUDA tensors with the engine's
    tensor shapes.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :returns:
        ``(context, stream)``, ready for
        ``context.execute_async_v3(stream.cuda_stream)``. Keep the stream
        referenced while using the context: it owns the I/O buffers.
    :raises RuntimeError:
        If the engine cannot be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
    if engine is None:
        raise RuntimeError(
            f"Failed to deserialize TensorRT engine: {engine_path}"
        )
    context = engine.create_execution_context()

    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    input_shape = tuple(engine.get_tensor_shape(input_name))
    output_shape = tuple(engine.get_tensor_shape(output_name))

    inp = torch.empty(input_shape, dtype=torch.float32, device="cuda")
    out = torch.empty(output_shape, dtype=torch.float32, device="cuda")
    context.set_tensor_address(input_name, inp.data_ptr())
    context.set_tensor_address(output_name, out.data_ptr())

    stream = torch.cuda.Stream()
    # ``set_tensor_address`` only receives bare integer pointers, so nothing
    # would keep ``inp``/``out`` alive once this function returns: PyTorch
    # would hand their memory back to its allocator while the engine keeps
    # writing there. Park the buffers (plus runtime and engine) on the
    # returned stream object, which the caller holds while it runs the
    # context.
    stream._trt_keepalive = (runtime, engine, inp, out)
    return context, stream


def measure_latency(
    engine_path: Path,
    warmup_ms: int = 1000,
    duration: int = 30,
) -> dict:
    """Measure inference latency using the TensorRT Python API.

    Each timed iteration is bracketed by CUDA events recorded on the
    inference stream. The reported latency is therefore the stream time of
    one ``execute_async_v3`` call. Throughput is
    ``iterations / total measured time``, so host overhead between
    iterations is not included.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param warmup_ms:
        Warm-up time in milliseconds before recording.
    :param duration:
        Benchmarking duration in seconds. At least two iterations are always
        timed, even when ``duration`` is zero or very short.
    :returns:
        Dict with mean/median/p99 latency (ms) and throughput (qps).
    """
    context, stream = _prepare_trt_context(engine_path)

    # Warm up. Synchronizing every iteration makes ``warmup_ms`` bound the
    # GPU work actually executed, rather than only the time the host spent
    # enqueueing launches that could still be pending afterwards.
    print("  Warming up ...")
    warmup_until = time.perf_counter() + warmup_ms / 1000
    while time.perf_counter() < warmup_until:
        context.execute_async_v3(stream.cuda_stream)
        stream.synchronize()

    # Timed iterations.
    print(f"  Benchmarking ({duration}s) ...")
    timings: list[float] = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    deadline = time.perf_counter() + duration
    # Collect at least two samples: ``statistics.mean`` rejects an empty
    # list and ``statistics.quantiles`` needs two points before Python 3.13.
    while len(timings) < 2 or time.perf_counter() < deadline:
        start.record(stream)
        context.execute_async_v3(stream.cuda_stream)
        end.record(stream)
        end.synchronize()
        timings.append(start.elapsed_time(end))

    mean_ms = statistics.mean(timings)
    median_ms = statistics.median(timings)
    # n=100 yields the 99 percentile cut points; the last one is p99.
    p99_ms = statistics.quantiles(timings, n=100)[-1]
    throughput = len(timings) / (sum(timings) / 1000)

    print(
        f"  Latency: mean={mean_ms:.3f} ms, "
        f"median={median_ms:.3f} ms, "
        f"p99={p99_ms:.3f} ms, throughput={throughput:.1f} qps"
    )
    return {
        "mean_latency_ms": mean_ms,
        "median_latency_ms": median_ms,
        "p99_latency_ms": p99_ms,
        "throughput_qps": throughput,
    }


class TRTInferencer:
    """TensorRT engine wrapper for batch inference.

    Deserializes an engine once and reuses a single execution context and
    CUDA stream for every :meth:`infer` call. The engine's first I/O tensor
    is used as the input and the second as the output.
    """

    def __init__(self, engine_path: Path):
        with open(engine_path, "rb") as f:
            engine_data = f.read()
        # Reuse the module-wide TRT_LOGGER, as the other TensorRT helpers in
        # this file do; TensorRT keeps one process-wide logger and may warn
        # when a different logger instance is supplied later. The runtime is
        # stored so it lives as long as the engine deserialized from it.
        self.runtime = trt.Runtime(TRT_LOGGER)
        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine: {engine_path}"
            )
        self.context = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()
        # I/O tensor names are fixed for a given engine; look them up once.
        self.input_name = self.engine.get_tensor_name(0)
        self.output_name = self.engine.get_tensor_name(1)

    def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference on a single batch.

        :param input_tensor:
            NCHW image batch on CUDA. NOTE(review): assumed float32 to match
            the engine input; the dtype is not checked here.

        :returns:
            Model output tensor (float32, on CUDA).

        :raises ValueError:
            If the engine does not accept the input's shape.
        :raises RuntimeError:
            If TensorRT fails to enqueue the inference.
        """
        # TensorRT reads raw memory at ``data_ptr()``, so the buffer must be
        # dense; ``contiguous()`` is a no-op for already-contiguous input.
        input_tensor = input_tensor.contiguous()
        input_shape = tuple(input_tensor.shape)
        if not self.context.set_input_shape(self.input_name, input_shape):
            raise ValueError(
                f"Input shape {input_shape} is not supported by the "
                "TensorRT engine"
            )
        output_shape = tuple(self.context.get_tensor_shape(self.output_name))
        output_tensor = torch.empty(
            output_shape, dtype=torch.float32, device="cuda"
        )
        self.context.set_tensor_address(
            self.input_name, input_tensor.data_ptr()
        )
        self.context.set_tensor_address(
            self.output_name, output_tensor.data_ptr()
        )
        if not self.context.execute_async_v3(self.stream.cuda_stream):
            raise RuntimeError("TensorRT inference failed to enqueue")
        # Block until the engine is done so the returned tensor is complete
        # and the input buffer can be released by the caller.
        self.stream.synchronize()
        return output_tensor


def evaluate_trt(
    engine_path: Path, val_loader: DataLoader
) -> dict[str, float]:
    """Top-1/Top-5 accuracy on the validation set.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param val_loader:
        Loader yielding ``(images, targets)`` batches.
    :returns:
        ``{"top1": ..., "top5": ...}`` as percentages.
    :raises ValueError:
        If ``val_loader`` yields no samples.
    """
    inferencer = TRTInferencer(engine_path)
    top1_hits = top5_hits = total = 0

    for batch, targets in val_loader:
        batch, targets = batch.cuda(), targets.cuda()
        output = inferencer.infer(batch)
        # (B, 5) class indices, best first.
        top5 = output.topk(5, dim=1, largest=True, sorted=True).indices
        # (B, 5) boolean; each row has at most one hit (indices are distinct).
        hits = top5.eq(targets.view(-1, 1))
        top1_hits += hits[:, 0].sum().item()
        top5_hits += hits.sum().item()
        total += batch.size(0)

    if total == 0:
        raise ValueError(
            "val_loader yielded no samples; cannot compute accuracy"
        )
    return {
        "top1": top1_hits / total * 100,
        "top5": top5_hits / total * 100,
    }


# %%
# Quantization with Embedl Deploy
# ---------------------------------
#
# Fuse layers with ``TENSORRT_PATTERNS``, insert quantization (QDQ) stubs,
# and calibrate on a few training batches.

from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy import transform
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy.quantize import (
    QuantConfig,
    QuantStub,
    TensorQuantConfig,
    WeightFakeQuantize,
    quantize,
)


def quantize_embedl(
    pretrained_model: nn.Module,
    calib_batches: list[torch.Tensor],
    crop_size: int = 224,
) -> nn.Module:
    """Fuse, quantize and calibrate a model with Embedl Deploy.

    The model is traced on CPU, shape-propagated and fused with
    ``TENSORRT_PATTERNS``. The fused model is checked against the original
    on a random input (max abs diff < 1e-4). It is then quantized with
    symmetric 8-bit activations and symmetric per-channel 8-bit weights,
    and calibrated by running every batch in ``calib_batches``.

    :returns:
        The quantized, calibrated model.
    """
    print("\n=== Embedl Deploy PTQ ===")
    input_shape = (1, 3, crop_size, crop_size)

    traced = symbolic_trace(pretrained_model.cpu().eval())
    ShapeProp(traced).propagate(torch.randn(*input_shape))
    fused_model = transform(traced, patterns=TENSORRT_PATTERNS).model

    # Fusion must be numerically lossless before quantizing.
    with torch.no_grad():
        probe = torch.randn(*input_shape)
        reference = pretrained_model.cpu()(probe)
        max_diff = (reference - fused_model(probe)).abs().max()
    assert max_diff < 1e-4, f"Fusion diverged: {max_diff:.2e}"
    print(f"  Fusion check passed (max diff = {max_diff:.2e})")

    print(f"  Calibrating ({len(calib_batches)} batches) ...")

    def run_calibration(model):
        with torch.no_grad():
            for images in calib_batches:
                model(images)

    example_inputs = (torch.randn(*input_shape),)
    int8_symmetric = {"n_bits": 8, "symmetric": True}
    qconfig = QuantConfig(
        activation=TensorQuantConfig(**int8_symmetric),
        weight=TensorQuantConfig(**int8_symmetric, per_channel=True),
    )
    quantized = quantize(
        fused_model,
        example_inputs,
        config=qconfig,
        forward_loop=run_calibration,
    )

    all_modules = list(quantized.modules())
    n_act = sum(isinstance(m, QuantStub) for m in all_modules)
    n_wt = sum(isinstance(m, WeightFakeQuantize) for m in all_modules)
    print(f"  QuantStubs: {n_act}, WeightFakeQuantize: {n_wt}")
    print("  Calibration complete.")
    return quantized


# %%
# Benchmark runner
# ----------------
#
# Helper that runs the full export -> build -> evaluate -> latency
# pipeline for a single model variant.


def run_variant(
    tag: str,
    model: nn.Module,
    *,
    dest: Path,
    val_loader: DataLoader | None,
    crop_size: int,
    best: bool = False,
    fp16: bool = True,
) -> tuple[dict, dict]:
    """Export, build, evaluate and benchmark a single model variant.

    Artifacts are written to ``dest`` as ``<tag>.onnx`` (plus any simplified
    copy) and ``<tag>.engine``. Without a ``val_loader`` the accuracy is
    reported as zeros and not printed.

    :returns:
        ``(accuracy_dict, latency_dict)``.
    """
    banner = "=" * 60
    print(f"\n{banner}\n{tag}\n{banner}")
    onnx_file = export_and_simplify(model, dest / f"{tag}.onnx", crop_size)
    engine_file = dest / f"{tag}.engine"
    build_trt_engine(onnx_file, engine_file, fp16=fp16, best=best)

    has_val = val_loader is not None
    if has_val:
        acc = evaluate_trt(engine_file, val_loader)
    else:
        acc = {"top1": 0.0, "top5": 0.0}
    latency = measure_latency(engine_file)

    if has_val:
        print(f"  Top-1: {acc['top1']:.2f}%  Top-5: {acc['top5']:.2f}%")
    print(
        f"  Latency: {latency['mean_latency_ms']:.3f} ms  "
        f"Throughput: {latency['throughput_qps']:.1f} qps"
    )
    return acc, latency


# %%
# Configuration
# -------------
#
# Choose the model and image sizes.  Change ``MODEL_NAME`` to benchmark
# a different architecture (e.g. ``"convnext_base"``, ``"resnet50"``,
# ``"vit_b_16"``).

MODEL_NAME = "convnext_tiny"
CROP_SIZE = 224
RESIZE_SIZE = 256

# Fail fast, before any output directory is created, when no GPU is present.
# An explicit check (rather than ``assert``) still fires under ``python -O``.
if not torch.cuda.is_available():
    raise RuntimeError("This benchmark requires a CUDA GPU.")
device = torch.device("cuda")

BENCHMARK_DIR = Path(f"artifacts/ptq/{MODEL_NAME}_benchmark")
BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)

print(f"Model: {MODEL_NAME}  crop={CROP_SIZE}  resize={RESIZE_SIZE}")
print(f"Output: {BENCHMARK_DIR}")

# %%
# Prepare data and model
# -----------------------
#
# Download ImageNette and load the pretrained torchvision model.

import itertools

download_imagenette()
train_dl, val_dl = make_loaders(CROP_SIZE, RESIZE_SIZE)

# Calibration uses the first CALIBRATION_BATCHES (shuffled, augmented)
# training batches. ``islice`` stops after exactly that many, so no extra
# batch is loaded only to be discarded.
calib_data: list[torch.Tensor] = [
    imgs for imgs, _ in itertools.islice(train_dl, CALIBRATION_BATCHES)
]
print(f"Collected {len(calib_data)} calibration batches.")

pretrained = torchvision.models.get_model(MODEL_NAME, weights="DEFAULT").eval()
print(
    f"Loaded {MODEL_NAME} "
    f"({sum(p.numel() for p in pretrained.parameters()):,} params)"
)

# %%
# Baseline FP16
# -------------
#
# Build and benchmark the unquantized model with FP16 precision.

# Unquantized model, FP16 engine, evaluated on the ImageNette val split.
baseline_acc, baseline_lat = run_variant(
    "baseline_fp16",
    pretrained,
    crop_size=CROP_SIZE,
    val_loader=val_dl,
    dest=BENCHMARK_DIR,
)

# %%
# Embedl Deploy mixed-precision
# -------------------------------
#
# Apply ``TENSORRT_PATTERNS``.  Depthwise convolutions are
# automatically kept in FP16 while compute-bound operators are
# quantized to INT8.

# PTQ with Embedl Deploy, then a ``best`` (FP16 + INT8) engine.
embedl_model = quantize_embedl(pretrained, calib_data, crop_size=CROP_SIZE)
embedl_acc, embedl_lat = run_variant(
    "embedl_mixed_precision",
    embedl_model,
    best=True,
    crop_size=CROP_SIZE,
    val_loader=val_dl,
    dest=BENCHMARK_DIR,
)

# %%
# Summary
# -------
#
# Compare accuracy and latency across all variants.

HEADER = (
    f"{'Variant':<25s} {'Top-1':>7s} {'Top-5':>7s} "
    f"{'Latency(ms)':>12s} {'Throughput':>12s}"
)
rows = [
    ("Baseline (FP16)", baseline_acc, baseline_lat),
    ("Embedl Deploy (best)", embedl_acc, embedl_lat),
]

# Guard against division by zero for a (pathological) zero latency.
speedup_e = baseline_lat["mean_latency_ms"] / max(
    embedl_lat["mean_latency_ms"], 1e-6
)
drop_e = baseline_acc["top1"] - embedl_acc["top1"]


def _format_row(label: str, row_acc: dict, row_lat: dict) -> str:
    """Format one summary-table row (shared by console and file output)."""
    return (
        f"{label:<25s} {row_acc['top1']:6.2f}% {row_acc['top5']:6.2f}% "
        f"{row_lat['mean_latency_ms']:11.3f} "
        f"{row_lat['throughput_qps']:10.1f} qps"
    )


# Build the report once so the console summary and the saved file cannot
# drift apart.
report = [
    "=" * 80,
    f"BENCHMARK SUMMARY - {MODEL_NAME} PTQ on ImageNette",
    "=" * 80,
    HEADER,
    "-" * len(HEADER),
    *(_format_row(*row) for row in rows),
    "",
    f"  Embedl Deploy - Top-1 drop: {drop_e:+.2f}pp, "
    f"speedup: {speedup_e:.2f}x",
    "=" * 80,
]
print("\n" + "\n".join(report))

# Save results, prefixed with the TensorRT version used for the builds.
RESULTS_PATH = BENCHMARK_DIR / f"{MODEL_NAME}_results.txt"
lines = [f"TRT {trt.__version__}", *report]
RESULTS_PATH.write_text("\n".join(lines) + "\n")
print(f"\nResults saved to {RESULTS_PATH}")
