# Copyright (C) 2026 Embedl AB
# mypy: disable-error-code="no-untyped-call,no-any-return"
"""
Model deployment and tracking with Embedl Hub
=============================================

In this tutorial, we demonstrate an end-to-end workflow for deploying vision
models using :mod:`embedl_deploy` for TensorRT and tracking the artifacts with
Embedl Hub.

The model is loaded from the torchvision library by name (``MODEL_NAME``,
ResNet-50 by default); the same workflow is intended for a range of
architectures such as ConvNeXt Base, ResNet-50, and ViT-B/16.  We apply
post-training static quantization (PTQ) using the built-in TensorRT pattern set.
Embedl Deploy automatically handles special cases such as depthwise convolutions,
which are memory-bound and left in FP16 to avoid quantization overhead.

The pipeline after transformation and quantization for deploying a model
consists of the following steps:

1. Export the PyTorch model to ONNX and simplify with ``onnxsim``.
2. Build a TensorRT engine (FP16 for baseline, ``--best`` for QDQ
   models).
3. Run TensorRT inference on the ImageNette validation set to measure
   Top-1/Top-5 accuracy.
4. Measure latency with the TensorRT Python API.



.. note::

   This tutorial requires an NVIDIA GPU with TensorRT 10.x installed.
   You will need to create an account at hub.embedl.com and create an API key
   for logging (in the profile section). You can run ``pip install embedl-hub``
   to install the hub client and then set up the API keys with:
   ``embedl-hub auth --api-key <YOUR_API_KEY>``
"""

from __future__ import annotations

# sphinx_gallery_start_ignore
# pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports,useless-suppression
# sphinx_gallery_end_ignore
# %%
# Constants
# ---------
import json
import platform
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from urllib.parse import urljoin

import numpy as np  # type: ignore[import-not-found]
import tensorrt as trt

# ConvNeXt-Base can exceed default recursion limit during deepcopy.
sys.setrecursionlimit(5000)

# Per-channel ImageNet normalization statistics, applied by both the
# training and the validation transforms.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Where the ImageNette archive is fetched from and extracted to.
IMAGENETTE_URL = (
    "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
)
DATA_DIR = Path("artifacts/data")
IMAGENETTE_DIR = DATA_DIR / "imagenette2-320"

# ImageNet-1k class id for each ImageNette label: entry ``i`` is the target
# that ImageNette folder label ``i`` is remapped to in ``make_loaders``.
IMAGENETTE_TO_IMAGENET = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]

# DataLoader settings.  NOTE(review): the ONNX export uses a fixed dummy
# batch of 1 without dynamic axes, so the engine presumably only accepts
# batch size 1 - confirm before raising BATCH_SIZE.
BATCH_SIZE = 1
NUM_WORKERS = 8
# Number of training-loader batches collected for PTQ calibration.
CALIBRATION_BATCHES = 32

# WARNING-level TensorRT logger used by ``_parse_onnx`` and
# ``_prepare_trt_context``.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Experiment configuration.
# NOTE(review): API_KEY_ENV, HUB_PROJECT and UPLOAD_LARGE_ARTIFACTS are not
# referenced in this part of the file; presumably they drive Hub
# authentication, the Hub project name and whether ``log_compile_run``
# uploads the ONNX/engine files.
API_KEY_ENV = "EMBEDL_HUB_API_KEY"
MODEL_NAME = "resnet50"  # name passed to torchvision.models.get_model
HUB_PROJECT = "Deploy Demo"
CROP_SIZE = 224  # square input resolution fed to the model / engine
RESIZE_SIZE = 256  # shorter-side resize before the validation center crop
WARMUP_MS = 1000  # latency warm-up, in milliseconds
BENCHMARK_DURATION = 30  # latency measurement window, in seconds
BENCHMARK_DIR = Path(f"artifacts/ptq/{MODEL_NAME}_benchmark")
UPLOAD_LARGE_ARTIFACTS = True


@dataclass(frozen=True)
class ExportResult:
    """Files produced by ONNX export.

    Created by :func:`export_and_simplify`.
    """

    # Path of the raw ``torch.onnx.export`` output.
    raw_path: Path
    # ONNX file to deploy: the onnxsim-simplified model (``<stem>_sim.onnx``)
    # when simplification succeeded, otherwise the same path as ``raw_path``.
    model_path: Path
    # Whether onnxsim reported a successful simplification.
    simplified: bool


@dataclass(frozen=True)
class QuantizationResult:
    """Embedl Deploy quantization output and summary metrics.

    Created by :func:`quantize_embedl`.
    """

    # Fused, QDQ-instrumented and calibrated model returned by ``quantize``.
    model: nn.Module
    # Max absolute output difference between the pretrained model and the
    # fused (not yet quantized) model on one random input.
    fusion_max_diff: float
    # Number of ``QuantStub`` modules (activation quantizers) in ``model``.
    activation_quantizers: int
    # Number of ``WeightFakeQuantize`` modules in ``model``.
    weight_quantizers: int


@dataclass(frozen=True)
class CompileResult:
    """Files and measurements produced by one compiled candidate.

    Created by :func:`compile_variant`.
    """

    # Candidate name; also the stem of the ONNX and engine file names.
    tag: str
    # ONNX files produced for this candidate.
    export: ExportResult
    # Serialized TensorRT engine (``<dest>/<tag>.engine``).
    engine_path: Path
    # TensorRT timing cache in the same directory (``<dest>/timing.cache``).
    timing_cache_path: Path
    # Wall-clock seconds per phase: ``export_seconds``,
    # ``engine_build_seconds`` and, for quantized candidates,
    # ``quantization_seconds``.
    phase_seconds: dict[str, float]
    # Fusion / quantizer-count / calibration summary; empty for candidates
    # compiled without calibration data.
    quantization_metrics: dict[str, float]


@dataclass(frozen=True)
class ProfileResult:
    """Accuracy and latency measurements for one compiled candidate.

    Created by :func:`profile_variant`.
    """

    # Candidate name this profile belongs to.
    tag: str
    # ``top1`` / ``top5`` validation accuracy, in percent.
    accuracy: dict[str, float]
    # ``mean_latency_ms``, ``median_latency_ms``, ``p99_latency_ms`` and
    # ``throughput_qps`` as returned by ``measure_latency``.
    latency: dict[str, float]
    # Wall-clock ``evaluation_seconds`` and ``latency_benchmark_seconds``.
    phase_seconds: dict[str, float]


@dataclass(frozen=True)
class SetupResult:
    """Data and model objects produced by the custom setup run.

    Created by :func:`prepare_demo_setup`.
    """

    # ImageNette loaders from ``make_loaders`` (train shuffles/augments).
    train_loader: _ImageLoader
    val_loader: _ImageLoader
    # Image batches taken from the start of ``train_loader`` for PTQ.
    calibration_batches: list[torch.Tensor]
    # torchvision ``MODEL_NAME`` with default weights, in eval mode.
    pretrained_model: nn.Module
    # Total and ``requires_grad`` parameter counts of ``pretrained_model``.
    parameter_count: int
    trainable_parameters: int
    # Dataset sizes (number of images) of the two ImageFolder splits.
    train_images: int
    val_images: int
    # True when ImageNette was downloaded during this setup call.
    downloaded: bool
    # Wall-clock seconds spent in ``torchvision.models.get_model``.
    model_import_seconds: float
    # Wall-clock seconds from the start of setup until the model was loaded.
    setup_seconds: float
    # JSON file holding the ImageNette -> ImageNet class mapping.
    mapping_path: Path


def _stringify_param(value: object) -> str:
    """Render *value* as a deterministic string for Hub parameter logging.

    Dicts, lists and tuples are JSON-encoded with sorted keys so that equal
    containers always produce the same text; every other value goes through
    ``str()``.
    """
    is_container = isinstance(value, (dict, list, tuple))
    if not is_container:
        return str(value)
    return json.dumps(value, sort_keys=True)


def _params(items: dict[str, object]) -> list[tuple[str, str]]:
    """Turn a mapping into ``(name, value)`` pairs of Hub parameter strings.

    Insertion order of *items* is preserved; each value is rendered with
    :func:`_stringify_param`.
    """
    pairs: list[tuple[str, str]] = []
    for name, raw in items.items():
        pairs.append((name, _stringify_param(raw)))
    return pairs


def _metrics(items: dict[str, float]) -> list[tuple[str, float, int | None]]:
    """Turn a mapping into Hub metric triples ``(name, value, step)``.

    Values are coerced to ``float``; the step is always ``None``. Insertion
    order of *items* is preserved.
    """
    rows: list[tuple[str, float, int | None]] = []
    for name, number in items.items():
        rows.append((name, float(number), None))
    return rows


def _file_size_mb(path: Path) -> float:
    """Return the size of *path* in megabytes (10**6 bytes), 0.0 if missing."""
    if not path.exists():
        return 0.0
    return path.stat().st_size / 1_000_000


def _write_json(path: Path, payload: object) -> Path:
    """Write *payload* to *path* as pretty-printed, key-sorted JSON.

    Missing parent directories are created first. The file is written as
    UTF-8 with a trailing newline.

    :param path:
        Destination file.
    :param payload:
        JSON-serializable object.
    :returns:
        *path*, for convenient chaining.
    """
    # Callers write into output directories (e.g. BENCHMARK_DIR) that may
    # not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, sort_keys=True) + "\n"
    path.write_text(text, encoding="utf-8")
    return path


def _project_url(client: Client, project_id: str) -> str:
    """Resolve ``projects/<project_id>`` against the client's base URL.

    NOTE(review): ``urljoin`` replaces the last path segment of a base URL
    that lacks a trailing slash - confirm ``base_url`` is shaped accordingly.
    """
    base = client.api_config.base_url
    relative = f"projects/{project_id}"
    return urljoin(base, relative)


def _log_demo_tags(client: Client, *extra_tags: tuple[str, str]) -> None:
    """Log the shared deploy-demo tags, then *extra_tags*, via the client.

    Each ``(name, value)`` pair is sent with one ``client.log_tag`` call;
    the common demo tags come first, followed by *extra_tags* in order.
    """
    base_tags = (
        ("demo", "deploy"),
        ("workflow", "torchvision-tensorrt-ptq"),
        ("dataset", "imagenette"),
        ("runtime", "tensorrt"),
        ("product_demo", "true"),
    )
    for tag_name, tag_value in (*base_tags, *extra_tags):
        client.log_tag(tag_name, tag_value)


# %%
# Pattern selection
# -----------------
#
# ``TENSORRT_PATTERNS`` is the default pattern set shipped with Embedl
# Deploy.  It includes structural conversions (e.g. decomposing
# ``MultiheadAttention``), operator fusions (Conv-BN-ReLU, Linear-ReLU,
# LayerNorm, etc.), and **automatic handling of depthwise convolutions**.
# Depthwise convolutions are memory-bound, so quantizing them adds
# TensorRT reformatting overhead that exceeds the compute gain from
# INT8 — Embedl Deploy detects these and keeps them in FP16
# automatically.

# %%
# Dataset helpers
# ---------------
#
# We use `ImageNette <https://github.com/fastai/imagenette>`_ — a
# 10-class subset of ImageNet — for fast evaluation.
import tarfile
import urllib.request

import torch
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader

from embedl_deploy.tensorrt import TENSORRT_PATTERNS
from embedl_hub.tracking import Client


def download_imagenette() -> bool:
    """Download and extract ImageNette if not already present.

    The archive is fetched from ``IMAGENETTE_URL`` into ``DATA_DIR``,
    extracted there, and deleted after a successful extraction. Presence is
    detected solely by the existence of ``IMAGENETTE_DIR``.

    NOTE(review): an interrupted extraction leaves a partial
    ``IMAGENETTE_DIR`` that the existence check will treat as complete.

    :returns:
        ``True`` if this call downloaded the dataset, ``False`` if it was
        already on disk.
    """
    if IMAGENETTE_DIR.exists():
        print(f"ImageNette already present at {IMAGENETTE_DIR}")
        return False
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    tgz_path = DATA_DIR / "imagenette2-320.tgz"
    print(f"Downloading ImageNette to {tgz_path} ...")
    urllib.request.urlretrieve(IMAGENETTE_URL, str(tgz_path))
    print("Extracting ...")
    with tarfile.open(tgz_path) as tar:
        # The archive comes from the network: use the "data" extraction
        # filter (PEP 706) where available so members cannot escape
        # DATA_DIR via absolute paths, ".." components or links. Python
        # 3.12+ warns when no filter is given and 3.14 changes the default.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(DATA_DIR, filter="data")
        else:
            tar.extractall(DATA_DIR)  # noqa: S202 - Python without PEP 706
    tgz_path.unlink()
    print("Done.")
    return True


def _val_transform(crop_size: int = 224, resize_size: int = 256) -> T.Compose:
    """Evaluation preprocessing without random augmentation.

    Resize to *resize_size*, center-crop to *crop_size*, convert to a
    tensor, and normalize with the ImageNet mean/std.
    """
    geometry = [T.Resize(resize_size), T.CenterCrop(crop_size)]
    tensorize = [T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]
    return T.Compose(geometry + tensorize)


def _train_transform(crop_size: int = 224) -> T.Compose:
    """Augmenting training preprocessing.

    Random resized crop to *crop_size* and a random horizontal flip,
    followed by tensor conversion and ImageNet mean/std normalization.
    """
    augment = [T.RandomResizedCrop(crop_size), T.RandomHorizontalFlip()]
    to_model_input = [T.ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]
    return T.Compose(augment + to_model_input)


# Type alias for the loaders returned by ``make_loaders``: each iteration
# yields an ``(images, targets)`` pair of tensors.
_ImageLoader = DataLoader[tuple[torch.Tensor, torch.Tensor]]


def make_loaders(
    crop_size: int = 224, resize_size: int = 256
) -> tuple[_ImageLoader, _ImageLoader]:
    """Build the ImageNette train and validation data loaders.

    Targets are remapped from ImageNette folder labels to ImageNet-1k class
    ids through ``IMAGENETTE_TO_IMAGENET``, so ImageNet classifiers can be
    scored against them directly.

    :param crop_size:
        Output resolution of both transforms.
    :param resize_size:
        Resize applied before the validation center crop.
    :returns:
        ``(train_loader, val_loader)``. The training loader shuffles and
        drops the last incomplete batch; the validation loader does neither.
    """

    def to_imagenet(label: int) -> int:
        return IMAGENETTE_TO_IMAGENET[label]

    root = IMAGENETTE_DIR
    train_ds = torchvision.datasets.ImageFolder(
        str(root / "train"),
        transform=_train_transform(crop_size),
        target_transform=to_imagenet,
    )
    val_ds = torchvision.datasets.ImageFolder(
        str(root / "val"),
        transform=_val_transform(crop_size, resize_size),
        target_transform=to_imagenet,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        shuffle=False,
    )
    print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    return train_loader, val_loader


# %%
# ONNX export
# ------------
#
# Export the model to ONNX and simplify with ``onnxsim``.

import onnx
import onnxsim


def export_and_simplify(
    model: nn.Module, onnx_path: Path, crop_size: int = 224
) -> ExportResult:
    """Export to ONNX + simplify with onnxsim.

    The model is moved to CPU and switched to eval mode in place (the
    caller's module is affected). It is then exported with one random
    ``(1, 3, crop_size, crop_size)`` input, at opset 20, using the
    non-dynamo exporter.

    Simplification is best-effort: if onnxsim reports a failed check or
    raises, the raw export is used as the deployment model instead.

    :param model:
        Model to export.
    :param onnx_path:
        Destination of the raw export; the simplified model is written next
        to it as ``<stem>_sim.onnx``.
    :param crop_size:
        Spatial size of the dummy input.
    :returns:
        Raw and deployment ONNX paths and whether simplification succeeded.
    """
    model = model.cpu().eval()
    x = torch.randn(1, 3, crop_size, crop_size)
    torch.onnx.export(
        model,
        (x,),
        str(onnx_path),
        opset_version=20,
        input_names=["input"],
        output_names=["output"],
        dynamo=False,
    )
    onnx_model = onnx.load(str(onnx_path))
    simp_path = onnx_path.with_name(onnx_path.stem + "_sim.onnx")
    try:
        simplified, ok = onnxsim.simplify(onnx_model)
    except Exception as err:  # noqa: BLE001 - simplification is optional
        # onnxsim may raise instead of returning ok=False; treat both the
        # same way and fall back to the raw export below.
        print(f"  onnxsim raised {type(err).__name__}: {err}")
        simplified, ok = None, False
    if ok:
        onnx.save(simplified, str(simp_path))
        print(f"  Simplified ONNX: {simp_path}")
    else:
        print("  onnxsim failed; using raw export.")
        simp_path = onnx_path
    return ExportResult(
        raw_path=onnx_path, model_path=simp_path, simplified=bool(ok)
    )


# %%
# TensorRT engine build and inference
# ------------------------------------
#
# Build a TensorRT engine from an ONNX file and run inference.

import statistics
import time


def _parse_onnx_with_parser(
    onnx_path: Path,
) -> tuple[trt.Builder, trt.INetworkDefinition, trt.OnnxParser]:
    """Parse an ONNX model into a TensorRT network, also returning the parser.

    The TensorRT developer guide notes that the ONNX parser owns the weight
    memory referenced by the network it populated, so callers must keep the
    returned parser alive until the engine has been built.

    :raises RuntimeError:
        If parsing fails; each parser error is printed first.
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    if not parser.parse(onnx_path.read_bytes()):
        for err_idx in range(parser.num_errors):
            print(f"  ONNX parse error: {parser.get_error(err_idx)}")
        raise RuntimeError("Failed to parse ONNX model")
    return builder, network, parser


def _parse_onnx(onnx_path: Path) -> tuple[trt.Builder, trt.INetworkDefinition]:
    """Parse an ONNX model into a TensorRT network.

    Kept for backward compatibility. The parser is released when this
    function returns, so the returned network may reference freed weight
    memory; prefer :func:`_parse_onnx_with_parser` when the network will be
    built into an engine.
    """
    builder, network, _parser = _parse_onnx_with_parser(onnx_path)
    return builder, network


def _load_timing_cache(config: trt.IBuilderConfig, cache_path: Path) -> None:
    """Load a timing cache from disk into a builder config.

    A missing file yields an empty cache. ``ignore_mismatch=True`` is
    passed so that a cache recorded under different conditions does not
    abort the build.
    """
    data = cache_path.read_bytes() if cache_path.exists() else b""
    cache = config.create_timing_cache(data)
    config.set_timing_cache(cache, ignore_mismatch=True)


def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    fp16: bool = True,
    int8: bool = False,
    best: bool = False,
) -> Path:
    """Build a TensorRT engine from an ONNX file.

    Precision flags are cumulative: ``fp16`` enables FP16 kernels, ``int8``
    enables INT8 kernels, and ``best`` enables both. With all three false
    the engine is built without reduced-precision flags (FP32).

    A timing cache named ``timing.cache`` next to *engine_path* is loaded
    before the build and rewritten afterwards.

    :param onnx_path:
        ONNX model to build from.
    :param engine_path:
        Destination of the serialized engine.
    :returns:
        *engine_path*.
    :raises RuntimeError:
        If parsing or the engine build fails.
    """
    # Keep the parser referenced until the build has finished: it owns the
    # weight memory of ``network``.
    builder, network, parser = _parse_onnx_with_parser(onnx_path)

    config = builder.create_builder_config()
    config.builder_optimization_level = 5

    cache_path = engine_path.parent / "timing.cache"
    _load_timing_cache(config, cache_path)

    # Independent checks, so combining ``int8`` with ``fp16`` enables both.
    if best or fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    if best or int8:
        config.set_flag(trt.BuilderFlag.INT8)

    print("  Building TensorRT engine ...")
    serialized = builder.build_serialized_network(network, config)
    del parser  # the network's weights are no longer needed
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")

    cache_path.write_bytes(
        bytes(memoryview(config.get_timing_cache().serialize()))
    )
    # Copy the serialized engine into Python bytes once and reuse it.
    engine_bytes = bytes(serialized)
    engine_path.write_bytes(engine_bytes)
    print(f"  Engine: {engine_path} ({len(engine_bytes) / 1e6:.1f} MB)")
    return engine_path


def _prepare_trt_context(
    engine_path: Path,
) -> tuple[trt.IExecutionContext, torch.cuda.Stream]:
    """Load engine, allocate GPU buffers, return ready context.

    The engine's first tensor is bound to a float32 input buffer and its
    second tensor to a float32 output buffer. These are assumed to be the
    input and the output, in that order; TODO confirm for multi-I/O
    engines.

    The context only stores raw device pointers, so the runtime, engine and
    both buffers are attached to the returned stream (``trt_keepalive``).
    They stay allocated for as long as the caller holds the stream.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :returns:
        The execution context and a fresh CUDA stream to run it on.
    :raises RuntimeError:
        If the engine cannot be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
    if engine is None:
        raise RuntimeError(
            f"Failed to deserialize TensorRT engine: {engine_path}"
        )
    context = engine.create_execution_context()

    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    input_shape = tuple(engine.get_tensor_shape(input_name))
    output_shape = tuple(engine.get_tensor_shape(output_name))

    inp = torch.empty(input_shape, dtype=torch.float32, device="cuda")
    out = torch.empty(output_shape, dtype=torch.float32, device="cuda")
    context.set_tensor_address(input_name, inp.data_ptr())
    context.set_tensor_address(output_name, out.data_ptr())

    stream = torch.cuda.Stream()
    # Without a live Python reference, ``inp``/``out`` would go back to
    # PyTorch's caching allocator on return and could be handed to other
    # tensors while TensorRT still reads/writes those addresses.
    stream.trt_keepalive = (runtime, engine, inp, out)  # type: ignore[attr-defined]
    return context, stream


def measure_latency(
    engine_path: Path,
    warmup_ms: int = 1000,
    duration: int = 30,
) -> dict[str, float]:
    """Measure inference latency using the TensorRT Python API.

    Each timed iteration is bracketed by CUDA events on the context's stream
    and synchronized before the next one starts. The timings are therefore
    per-inference device latencies, not pipelined throughput.

    :param engine_path:
        Path to a serialized TensorRT engine.
    :param warmup_ms:
        Warm-up time in milliseconds before recording.
    :param duration:
        Benchmarking duration in seconds. At least two iterations are always
        recorded, even if the duration elapses first (or is zero).
    :returns:
        Dict with mean/median/p99 latency (ms) and throughput (qps).
        ``throughput_qps`` equals ``1000 / mean_latency_ms``.
    """
    context, stream = _prepare_trt_context(engine_path)
    # statistics.quantiles needs at least two points before Python 3.13,
    # and statistics.mean fails on an empty list.
    min_samples = 2

    # Warm up.
    print("  Warming up ...")
    warmup_until = time.perf_counter() + warmup_ms / 1000
    while time.perf_counter() < warmup_until:
        context.execute_async_v3(stream.cuda_stream)
    stream.synchronize()

    # Timed iterations.
    print(f"  Benchmarking ({duration}s) ...")
    timings: list[float] = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    deadline = time.perf_counter() + duration
    while time.perf_counter() < deadline or len(timings) < min_samples:
        start.record(stream)
        context.execute_async_v3(stream.cuda_stream)
        end.record(stream)
        end.synchronize()
        timings.append(start.elapsed_time(end))

    mean_ms = statistics.mean(timings)
    median_ms = statistics.median(timings)
    # The last of the 99 cut points is the 99th percentile.
    p99_ms = statistics.quantiles(timings, n=100)[-1]
    throughput = len(timings) / (sum(timings) / 1000)

    print(
        f"  Latency: mean={mean_ms:.3f} ms, "
        f"median={median_ms:.3f} ms, "
        f"p99={p99_ms:.3f} ms, throughput={throughput:.1f} qps"
    )
    return {
        "mean_latency_ms": mean_ms,
        "median_latency_ms": median_ms,
        "p99_latency_ms": p99_ms,
        "throughput_qps": throughput,
    }


class TRTInferencer:
    """TensorRT engine wrapper for batch inference.

    The engine's first tensor is treated as the input and its second tensor
    as the output (assumed order; TODO confirm for multi-I/O engines). Each
    call runs synchronously on a private CUDA stream.
    """

    def __init__(self, engine_path: Path):
        """Initialize the TensorRT inferencer.

        :param engine_path:
            Path to a serialized TensorRT engine.
        :raises RuntimeError:
            If TensorRT cannot deserialize the engine.
        """
        engine_data = engine_path.read_bytes()
        # Reuse the module-wide TRT_LOGGER, like the other TensorRT helpers,
        # instead of creating a second logger (TensorRT warns when different
        # loggers are passed to its builders/runtimes). The runtime is kept
        # on the instance so it lives as long as the engine.
        self.runtime = trt.Runtime(TRT_LOGGER)
        self.engine = self.runtime.deserialize_cuda_engine(engine_data)
        if self.engine is None:
            raise RuntimeError(
                f"Failed to deserialize TensorRT engine: {engine_path}"
            )
        self.context = self.engine.create_execution_context()
        self.stream = torch.cuda.Stream()

    def infer(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run inference on a single batch.

        :param input_tensor:
            NCHW image batch on CUDA. It is converted to dense float32
            first; this is a no-op when it already is.

        :returns:
            Model output tensor (float32, on CUDA).
        """
        input_name = self.engine.get_tensor_name(0)
        output_name = self.engine.get_tensor_name(1)
        # TensorRT reads raw memory at data_ptr(), so the buffer must be a
        # contiguous float32 tensor. The local reference keeps any converted
        # copy alive until the stream has been synchronized.
        input_tensor = input_tensor.to(dtype=torch.float32).contiguous()
        self.context.set_input_shape(input_name, tuple(input_tensor.shape))
        output_shape = tuple(self.context.get_tensor_shape(output_name))
        output_tensor = torch.empty(
            output_shape, dtype=torch.float32, device="cuda"
        )
        self.context.set_tensor_address(input_name, input_tensor.data_ptr())
        self.context.set_tensor_address(output_name, output_tensor.data_ptr())
        # Any pending work on the caller's stream (e.g. an asynchronous
        # host-to-device copy of the input) must finish before TensorRT
        # reads the input on the private stream.
        self.stream.wait_stream(torch.cuda.current_stream())
        self.context.execute_async_v3(self.stream.cuda_stream)
        self.stream.synchronize()
        return output_tensor


def evaluate_trt(
    engine_path: Path, val_loader: _ImageLoader
) -> dict[str, float]:
    """Compute Top-1/Top-5 accuracy (in percent) of a TensorRT engine.

    :param engine_path:
        Serialized TensorRT engine to evaluate.
    :param val_loader:
        Loader yielding ``(images, targets)`` batches.
    :returns:
        ``{"top1": ..., "top5": ...}`` as percentages of all samples seen.
    """
    inferencer = TRTInferencer(engine_path)
    correct_top1 = 0.0
    correct_top5 = 0.0
    seen = 0.0

    for images, labels in val_loader:
        logits = inferencer.infer(images.cuda()).cpu().numpy()
        # Rank classes on the host with NumPy instead of torch.topk on the
        # GPU; avoids torch CUDA kernels missing for SM87 (CC 8.7).
        labels_np = labels.numpy()
        ranked = np.argsort(logits, axis=1)[:, ::-1][:, :5]
        hits = ranked == labels_np[:, None]
        correct_top1 += hits[:, 0].sum()
        correct_top5 += hits.any(axis=1).sum()
        seen += len(labels_np)

    return {
        "top1": correct_top1 / seen * 100,
        "top5": correct_top5 / seen * 100,
    }


# %%
# Quantization with Embedl Deploy
# ---------------------------------
#
# Fuse layers with ``TENSORRT_PATTERNS``, verify that fusion is lossless,
# insert QDQ stubs, and calibrate on training data.

from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy import transform
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy.quantize import (
    QuantConfig,
    QuantStub,
    TensorQuantConfig,
    WeightFakeQuantize,
    quantize,
)


def quantize_embedl(
    pretrained_model: nn.Module,
    calib_batches: list[torch.Tensor],
    crop_size: int = 224,
    fusion_atol: float = 1e-4,
) -> QuantizationResult:
    """Fuse + quantize + calibrate using Embedl Deploy.

    The model is traced, shape-propagated with a random input, and fused
    with ``TENSORRT_PATTERNS``. The fused model's output is checked against
    the original on one random input. It is then quantized (symmetric 8-bit
    activations, symmetric per-channel 8-bit weights) and calibrated by
    running every batch in *calib_batches* through it.

    :param pretrained_model:
        Model to quantize; moved to CPU and set to eval mode in place.
    :param calib_batches:
        Input batches used for calibration.
    :param crop_size:
        Spatial size of the random tracing / check inputs.
    :param fusion_atol:
        Maximum allowed absolute output difference introduced by fusion.
    :returns:
        The quantized model plus fusion and quantizer-count metrics.
    :raises RuntimeError:
        If fusion changes the output by ``fusion_atol`` or more (or the
        difference is NaN).
    """
    print("\n=== Embedl Deploy PTQ ===")

    gm = symbolic_trace(pretrained_model.cpu().eval())
    ShapeProp(gm).propagate(torch.randn(1, 3, crop_size, crop_size))

    fused_model = transform(gm, patterns=TENSORRT_PATTERNS).model

    # Verify lossless fusion. An explicit check instead of ``assert`` keeps
    # the guard active under ``python -O``.
    with torch.no_grad():
        x = torch.randn(1, 3, crop_size, crop_size)
        diff = pretrained_model.cpu()(x) - fused_model(x)
        max_diff = float(diff.abs().max())
    # Written as ``not <`` so that a NaN difference is rejected too.
    if not max_diff < fusion_atol:  # pylint: disable=unneeded-not
        raise RuntimeError(f"Fusion diverged: {max_diff:.2e}")
    print(f"  Fusion check passed (max diff = {max_diff:.2e})")

    print(f"  Calibrating ({len(calib_batches)} batches) ...")

    def forward_loop(model: nn.Module) -> None:
        with torch.no_grad():
            for batch in calib_batches:
                model(batch)

    quantized = quantize(
        fused_model,
        (torch.randn(1, 3, crop_size, crop_size),),
        config=QuantConfig(
            activation=TensorQuantConfig(n_bits=8, symmetric=True),
            weight=TensorQuantConfig(
                n_bits=8,
                symmetric=True,
                per_channel=True,
            ),
        ),
        forward_loop=forward_loop,
    )

    n_act = sum(1 for m in quantized.modules() if isinstance(m, QuantStub))
    n_wt = sum(
        1 for m in quantized.modules() if isinstance(m, WeightFakeQuantize)
    )
    print(f"  QuantStubs: {n_act}, WeightFakeQuantize: {n_wt}")
    print("  Calibration complete.")
    return QuantizationResult(
        model=quantized,
        fusion_max_diff=max_diff,
        activation_quantizers=n_act,
        weight_quantizers=n_wt,
    )


# %%
# Benchmark runner
# ----------------
#
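# Helpers that run the pipeline for a single model variant:
# ``compile_variant`` (optional PTQ -> ONNX export -> engine build) and
# ``profile_variant`` (accuracy evaluation -> latency benchmark), plus the
# functions that log each stage to Embedl Hub.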


def compile_variant(
    tag: str,
    model: nn.Module,
    *,
    dest: Path,
    calib_data: list[torch.Tensor] | None = None,
    best: bool = False,
    fp16: bool = True,
) -> CompileResult:
    """Export and build one TensorRT deployment candidate.

    :param tag:
        Candidate name; used for the ``<tag>.onnx`` / ``<tag>.engine`` file
        names.
    :param model:
        PyTorch model to deploy.
    :param dest:
        Output directory for the ONNX, engine and timing-cache files;
        created if missing.
    :param calib_data:
        Calibration batches. When given, *model* is quantized with
        :func:`quantize_embedl` before export; when ``None`` it is exported
        as-is.
    :param best:
        Build with FP16 and INT8 enabled.
    :param fp16:
        Build with FP16 enabled.
    :returns:
        Paths, per-phase wall-clock timings and quantization metrics.
    """
    print(f"\n{'=' * 60}\n{tag}\n{'=' * 60}")
    # Every artifact below is written into ``dest``; make sure it exists so
    # the ONNX export does not fail on a missing directory.
    dest.mkdir(parents=True, exist_ok=True)
    phase_seconds: dict[str, float] = {}
    # Same location build_trt_engine uses (engine_path.parent/timing.cache).
    timing_cache_path = dest / "timing.cache"
    model_to_compile = model
    quantization_result: QuantizationResult | None = None

    if calib_data is not None:
        quant_started = time.perf_counter()
        quantization_result = quantize_embedl(model, calib_data, CROP_SIZE)
        model_to_compile = quantization_result.model
        phase_seconds["quantization_seconds"] = (
            time.perf_counter() - quant_started
        )

    export_started = time.perf_counter()
    export = export_and_simplify(
        model_to_compile, dest / f"{tag}.onnx", CROP_SIZE
    )
    phase_seconds["export_seconds"] = time.perf_counter() - export_started

    engine_path = dest / f"{tag}.engine"
    build_started = time.perf_counter()
    build_trt_engine(
        export.model_path,
        engine_path,
        fp16=fp16,
        best=best,
    )
    phase_seconds["engine_build_seconds"] = time.perf_counter() - build_started

    quant_metrics: dict[str, float] = {}
    if quantization_result is not None:
        quant_metrics = {
            "fusion_max_diff": quantization_result.fusion_max_diff,
            "activation_quantizers": float(
                quantization_result.activation_quantizers
            ),
            "weight_quantizers": float(quantization_result.weight_quantizers),
            "calibration_batches": float(len(calib_data or [])),
        }

    print(f"  Engine: {engine_path} ({_file_size_mb(engine_path):.1f} MB)")
    return CompileResult(
        tag=tag,
        export=export,
        engine_path=engine_path,
        timing_cache_path=timing_cache_path,
        phase_seconds=phase_seconds,
        quantization_metrics=quant_metrics,
    )


def log_compile_run(  # noqa: PLR0913
    client: Client,
    *,
    result: CompileResult,
    tag: str,
    parameter_count: int,
    trainable_parameters: int,
    best: bool,
    fp16: bool,
    candidate_tag: str,
    precision_tag: str,
    upload_large_artifacts: bool,
) -> None:
    """Log compile parameters, metrics, and artifacts.

    :param result:
        Output of :func:`compile_variant` for this candidate.
    :param best:
        The ``best`` flag the engine was built with.
    :param fp16:
        The ``fp16`` flag the engine was built with.
    :param candidate_tag:
        Value of the ``candidate`` tag.
    :param precision_tag:
        Value of the ``precision`` tag.
    :param upload_large_artifacts:
        When true, also upload the raw and deployment ONNX files, the engine
        and (if present) the timing cache.
    """
    # Mirror build_trt_engine: ``best`` enables FP16+INT8, ``fp16`` enables
    # FP16, and with neither flag TensorRT builds an FP32 engine.
    if best:
        precision = "FP16+INT8"
    elif fp16:
        precision = "FP16"
    else:
        precision = "FP32"

    _log_demo_tags(
        client,
        ("stage", "compile"),
        ("candidate", candidate_tag),
        ("model", MODEL_NAME),
        ("target", "h200"),
        ("precision", precision_tag),
    )
    client.log_batch(
        params=_params(
            {
                "variant": tag,
                "model_name": MODEL_NAME,
                "parameter_count": parameter_count,
                "trainable_parameters": trainable_parameters,
                "crop_size": CROP_SIZE,
                "precision": precision,
                "tensorrt_best": best,
                "tensorrt_fp16": fp16 or best,
                # Must match the values hard-coded in export_and_simplify
                # and build_trt_engine.
                "onnx_opset": 20,
                "builder_optimization_level": 5,
                "raw_onnx_path": result.export.raw_path,
                "deployment_onnx_path": result.export.model_path,
                "engine_path": result.engine_path,
                "timing_cache_path": result.timing_cache_path,
            }
        ),
        metrics=_metrics(
            {
                **result.phase_seconds,
                **result.quantization_metrics,
                "raw_onnx_size_mb": _file_size_mb(result.export.raw_path),
                "deployment_onnx_size_mb": _file_size_mb(
                    result.export.model_path
                ),
                "engine_size_mb": _file_size_mb(result.engine_path),
                "timing_cache_size_mb": _file_size_mb(
                    result.timing_cache_path
                ),
                "onnx_simplified": (1.0 if result.export.simplified else 0.0),
            }
        ),
    )

    if upload_large_artifacts:
        client.log_artifact(result.export.raw_path, name="raw_onnx")
        if result.export.model_path != result.export.raw_path:
            client.log_artifact(
                result.export.model_path, name="deployment_onnx"
            )
        # NOTE(review): unlike the other artifact names, the engine is
        # uploaded as "path" - confirm whether Embedl Hub expects this name.
        client.log_artifact(result.engine_path, name="path")
        if result.timing_cache_path.exists():
            client.log_artifact(result.timing_cache_path, name="timing_cache")


def profile_variant(
    tag: str,
    compile_result: CompileResult,
    *,
    val_loader: _ImageLoader,
) -> ProfileResult:
    """Measure accuracy and latency of one compiled TensorRT candidate.

    :param tag:
        Label stored on the returned :class:`ProfileResult`.
    :param compile_result:
        Compile output whose ``engine_path`` is profiled.
    :param val_loader:
        Validation data for the Top-1/Top-5 evaluation.
    :returns:
        Accuracy, latency (``WARMUP_MS`` warm-up, ``BENCHMARK_DURATION``
        seconds) and per-phase wall-clock timings.
    """
    engine = compile_result.engine_path
    timings: dict[str, float] = {}

    t0 = time.perf_counter()
    accuracy = evaluate_trt(engine, val_loader)
    timings["evaluation_seconds"] = time.perf_counter() - t0

    t0 = time.perf_counter()
    latency = measure_latency(
        engine,
        warmup_ms=WARMUP_MS,
        duration=BENCHMARK_DURATION,
    )
    timings["latency_benchmark_seconds"] = time.perf_counter() - t0

    top1, top5 = accuracy["top1"], accuracy["top5"]
    print(f"  Top-1: {top1:.2f}%  Top-5: {top5:.2f}%")
    mean_ms, qps = latency["mean_latency_ms"], latency["throughput_qps"]
    print(f"  Latency: {mean_ms:.3f} ms  Throughput: {qps:.1f} qps")
    return ProfileResult(
        tag=tag,
        accuracy=accuracy,
        latency=latency,
        phase_seconds=timings,
    )


def log_profile_run(
    client: Client,
    *,
    result: ProfileResult,
    compile_run_id: str,
    compile_result: CompileResult,
    candidate_tag: str,
    precision_tag: str,
) -> None:
    """Record one profile stage on Embedl Hub.

    Logs the shared demo tags plus stage/candidate/model/target/precision
    tags. The parameters record the profiling setup, including
    *compile_run_id* and the profiled engine path. Every accuracy, latency
    and phase-timing value is logged as a metric.
    """
    stage_tags = (
        ("stage", "profile"),
        ("candidate", candidate_tag),
        ("model", MODEL_NAME),
        ("target", "h200"),
        ("precision", precision_tag),
    )
    _log_demo_tags(client, *stage_tags)

    param_values: dict[str, object] = {
        "model_name": MODEL_NAME,
        "compile_run_id": compile_run_id,
        "engine_path": compile_result.engine_path,
        "warmup_ms": WARMUP_MS,
        "benchmark_duration_seconds": BENCHMARK_DURATION,
    }
    metric_values: dict[str, float] = {
        **result.accuracy,
        **result.latency,
        **result.phase_seconds,
    }
    client.log_batch(
        params=_params(param_values), metrics=_metrics(metric_values)
    )


def create_root_run(
    client: Client,
    *,
    project_name: str,
) -> str:
    """Open the top-level ``graph`` run for this demo and return its id.

    The run is tagged as the ``root`` stage. It records the project, model,
    dataset, preprocessing and loader settings, plus the software and
    hardware environment (Python, platform, torch, torchvision, TensorRT,
    CUDA device 0), as parameters. The run's context manager is exited
    before the id reaches the caller.
    """
    title = f"Deploy Demo {MODEL_NAME} TensorRT PTQ workflow"
    with client.start_run("graph", name=title) as root_run:
        _log_demo_tags(client, ("stage", "root"))
        run_info: dict[str, object] = {
            "project": project_name,
            "model_name": MODEL_NAME,
            "dataset": "ImageNette",
            "source_library": "torchvision",
            "deployment_runtime": "TensorRT",
            "optimization": "Embedl Deploy PTQ",
            "hub_base_url": client.api_config.base_url,
            "crop_size": CROP_SIZE,
            "resize_size": RESIZE_SIZE,
            "batch_size": BATCH_SIZE,
            "num_workers": NUM_WORKERS,
            "benchmark_dir": BENCHMARK_DIR,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "python": platform.python_version(),
            "platform": platform.platform(),
            "torch": torch.__version__,
            "torchvision": torchvision.__version__,
            "tensorrt": trt.__version__,
            "cuda_device": torch.cuda.get_device_name(0),
        }
        client.log_batch(params=_params(run_info))
        return root_run.id


def prepare_demo_setup() -> SetupResult:
    """Prepare data, calibration batches, and the pretrained model.

    Downloads ImageNette if needed and builds the loaders. It then collects
    up to ``CALIBRATION_BATCHES`` batches from the training loader and
    loads the torchvision model ``MODEL_NAME`` with its default weights.
    Finally it writes the ImageNette -> ImageNet class mapping to
    ``BENCHMARK_DIR``, creating the directory if needed.

    :returns:
        Loaders, calibration data, the model, and setup counts/timings.
    """
    started = time.perf_counter()
    downloaded = download_imagenette()
    train_dl, val_dl = make_loaders(CROP_SIZE, RESIZE_SIZE)

    train_images = len(train_dl.dataset)  # type: ignore[arg-type]
    val_images = len(val_dl.dataset)  # type: ignore[arg-type]

    # zip() exhausts the range first, so exactly CALIBRATION_BATCHES batches
    # are pulled from the loader (no extra batch is fetched and discarded).
    # NOTE(review): these batches come from the training loader and
    # therefore went through the random crop/flip training transform rather
    # than the evaluation transform - confirm this is intended for PTQ.
    calib_data: list[torch.Tensor] = [
        imgs for _, (imgs, _targets) in zip(range(CALIBRATION_BATCHES), train_dl)
    ]
    print(f"Collected {len(calib_data)} calibration batches.")

    model_started = time.perf_counter()
    pretrained = torchvision.models.get_model(
        MODEL_NAME, weights="DEFAULT"
    ).eval()
    model_import_seconds = time.perf_counter() - model_started
    parameter_count = sum(p.numel() for p in pretrained.parameters())
    trainable_parameters = sum(
        p.numel() for p in pretrained.parameters() if p.requires_grad
    )
    print(f"Loaded {MODEL_NAME} ({parameter_count:,} params)")

    elapsed = time.perf_counter() - started
    # Ensure the output directory exists before writing the mapping file.
    BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)
    mapping_path = _write_json(
        BENCHMARK_DIR / "imagenette_class_mapping.json",
        {
            "imagenette_url": IMAGENETTE_URL,
            "imagenette_dir": str(IMAGENETTE_DIR),
            "imagenette_to_imagenet": IMAGENETTE_TO_IMAGENET,
            "train_images": train_images,
            "val_images": val_images,
        },
    )

    return SetupResult(
        train_loader=train_dl,
        val_loader=val_dl,
        calibration_batches=calib_data,
        pretrained_model=pretrained,
        parameter_count=parameter_count,
        trainable_parameters=trainable_parameters,
        train_images=train_images,
        val_images=val_images,
        downloaded=downloaded,
        model_import_seconds=model_import_seconds,
        setup_seconds=elapsed,
        mapping_path=mapping_path,
    )


def log_setup_run(
    client: Client,
    *,
    setup: SetupResult,
) -> None:
    """Record the setup stage (data and model preparation) on Embedl Hub.

    :param client:
        Hub tracking client; tags, parameters, metrics and the
        ``class_mapping`` artifact are logged through it.
    :param setup:
        Result of :func:`prepare_demo_setup`.
    """
    _log_demo_tags(
        client,
        ("stage", "setup"),
        ("model", MODEL_NAME),
        ("custom_type", "demo_setup"),
    )
    num_calib = len(setup.calibration_batches)
    configuration: dict[str, object] = {
        "model_name": MODEL_NAME,
        "source": "torchvision",
        "weights": "DEFAULT",
        "dataset": "ImageNette",
        "dataset_url": IMAGENETTE_URL,
        "dataset_dir": IMAGENETTE_DIR,
        "downloaded_this_run": setup.downloaded,
        "crop_size": CROP_SIZE,
        "resize_size": RESIZE_SIZE,
        "requested_calibration_batches": CALIBRATION_BATCHES,
        "normalization_mean": IMAGENET_MEAN,
        "normalization_std": IMAGENET_STD,
    }
    measurements: dict[str, float] = {
        "train_images": setup.train_images,
        "val_images": setup.val_images,
        "classes": len(IMAGENETTE_TO_IMAGENET),
        "calibration_batches": num_calib,
        "calibration_images": num_calib * BATCH_SIZE,
        "parameter_count": setup.parameter_count,
        "trainable_parameters": setup.trainable_parameters,
        "model_import_seconds": setup.model_import_seconds,
        "setup_seconds": setup.setup_seconds,
    }
    client.log_batch(
        params=_params(configuration), metrics=_metrics(measurements)
    )
    client.log_artifact(setup.mapping_path, name="class_mapping")


def log_comparison_run(  # noqa: PLR0913
    client: Client,
    *,
    parent_run_id: str,
    baseline_compile_run_id: str,
    baseline_profile_run_id: str,
    embedl_compile_run_id: str,
    embedl_profile_run_id: str,
    baseline_compile: CompileResult,
    embedl_compile: CompileResult,
    baseline_profile: ProfileResult,
    embedl_profile: ProfileResult,
) -> None:
    """Log the final product-demo comparison and summary artifacts.

    Computes the headline baseline-vs-Embedl numbers (speedup, top-1 and
    top-5 drop, engine size reduction), prints a summary table, writes the
    same table to ``BENCHMARK_DIR / f"{MODEL_NAME}_results.txt"`` plus a
    JSON summary, then opens an ``"eval"`` run under *parent_run_id* and
    logs the run ids, metrics and both files to it.

    Args:
        client: Hub client used to open and populate the comparison run.
        parent_run_id: Id of the run to nest the comparison run under.
        baseline_compile_run_id: Hub id of the baseline compile run.
        baseline_profile_run_id: Hub id of the baseline profile run.
        embedl_compile_run_id: Hub id of the Embedl compile run.
        embedl_profile_run_id: Hub id of the Embedl profile run.
        baseline_compile: Baseline compile result (provides ``engine_path``).
        embedl_compile: Embedl compile result (provides ``engine_path``).
        baseline_profile: Baseline ``accuracy`` / ``latency`` measurements.
        embedl_profile: Embedl ``accuracy`` / ``latency`` measurements.
    """
    baseline_acc = baseline_profile.accuracy
    baseline_lat = baseline_profile.latency
    embedl_acc = embedl_profile.accuracy
    embedl_lat = embedl_profile.latency

    # Denominators are floored at 1e-6 so a zero measurement cannot raise
    # ZeroDivisionError.
    speedup_e = baseline_lat["mean_latency_ms"] / max(
        embedl_lat["mean_latency_ms"], 1e-6
    )
    drop_e = baseline_acc["top1"] - embedl_acc["top1"]
    top5_drop_e = baseline_acc["top5"] - embedl_acc["top5"]
    baseline_engine_size = _file_size_mb(baseline_compile.engine_path)
    embedl_engine_size = _file_size_mb(embedl_compile.engine_path)
    engine_size_reduction_pct = (
        (baseline_engine_size - embedl_engine_size)
        / max(baseline_engine_size, 1e-6)
        * 100
    )

    # Header widths mirror the row format below so the columns line up:
    # "{:6.2f}%" is 7 chars, "{:11.3f}" is 11, "{:10.1f} qps" is 14.
    header = (
        f"{'Variant':<25s} {'Top-1':>7s} {'Top-5':>7s} "
        f"{'Latency(ms)':>11s} {'Throughput':>14s}"
    )
    rows = [
        ("Baseline (FP16)", baseline_acc, baseline_lat),
        ("Embedl Deploy (best)", embedl_acc, embedl_lat),
    ]

    # Render the report once; stdout and the results file share these
    # lines so the two copies cannot drift apart.
    report = [
        "=" * 80,
        f"BENCHMARK SUMMARY - {MODEL_NAME} PTQ on ImageNette",
        "=" * 80,
        header,
        "-" * len(header),
    ]
    report.extend(
        f"{label:<25s} {row_acc['top1']:6.2f}% "
        f"{row_acc['top5']:6.2f}% "
        f"{row_lat['mean_latency_ms']:11.3f} "
        f"{row_lat['throughput_qps']:10.1f} qps"
        for label, row_acc, row_lat in rows
    )
    report += [
        "",
        f"  Embedl Deploy - Top-1 drop: {drop_e:+.2f}pp, speedup: {speedup_e:.2f}x",
        "=" * 80,
    ]
    print("\n" + "\n".join(report))

    results_path = BENCHMARK_DIR / f"{MODEL_NAME}_results.txt"
    # The file copy is prefixed with the TensorRT version used.
    results_path.write_text(
        "\n".join([f"TRT {trt.__version__}", *report]) + "\n",
        encoding="utf-8",
    )

    summary_path = _write_json(
        BENCHMARK_DIR / f"{MODEL_NAME}_hub_summary.json",
        {
            "model_name": MODEL_NAME,
            "baseline_compile_run_id": baseline_compile_run_id,
            "baseline_profile_run_id": baseline_profile_run_id,
            "embedl_compile_run_id": embedl_compile_run_id,
            "embedl_profile_run_id": embedl_profile_run_id,
            "baseline": {
                "accuracy": baseline_acc,
                "latency": baseline_lat,
                "engine_path": str(baseline_compile.engine_path),
                "engine_size_mb": baseline_engine_size,
            },
            "embedl": {
                "accuracy": embedl_acc,
                "latency": embedl_lat,
                "engine_path": str(embedl_compile.engine_path),
                "engine_size_mb": embedl_engine_size,
            },
            "comparison": {
                "speedup_x": speedup_e,
                "top1_drop_pp": drop_e,
                "top5_drop_pp": top5_drop_e,
                "engine_size_reduction_pct": engine_size_reduction_pct,
                "mean_latency_ms_delta": (
                    baseline_lat["mean_latency_ms"]
                    - embedl_lat["mean_latency_ms"]
                ),
            },
        },
    )

    with client.start_run(
        "eval",
        name="Compare deployment candidates",
        parent_run_id=parent_run_id,
    ):
        _log_demo_tags(
            client,
            ("stage", "eval"),
            ("candidate", "comparison"),
            ("model", MODEL_NAME),
        )
        client.log_batch(
            params=_params(
                {
                    "model_name": MODEL_NAME,
                    "baseline_compile_run_id": baseline_compile_run_id,
                    "baseline_profile_run_id": baseline_profile_run_id,
                    "embedl_compile_run_id": embedl_compile_run_id,
                    "embedl_profile_run_id": embedl_profile_run_id,
                }
            ),
            metrics=_metrics(
                {
                    "baseline_top1": baseline_acc["top1"],
                    "baseline_top5": baseline_acc["top5"],
                    "baseline_mean_latency_ms": baseline_lat[
                        "mean_latency_ms"
                    ],
                    "baseline_throughput_qps": baseline_lat["throughput_qps"],
                    "embedl_top1": embedl_acc["top1"],
                    "embedl_top5": embedl_acc["top5"],
                    "embedl_mean_latency_ms": embedl_lat["mean_latency_ms"],
                    "embedl_throughput_qps": embedl_lat["throughput_qps"],
                    "speedup_x": speedup_e,
                    "top1_drop_pp": drop_e,
                    "top5_drop_pp": top5_drop_e,
                    "baseline_engine_size_mb": baseline_engine_size,
                    "embedl_engine_size_mb": embedl_engine_size,
                    "engine_size_reduction_pct": engine_size_reduction_pct,
                }
            ),
        )
        client.log_artifact(results_path, name="benchmark_results_text")
        client.log_artifact(summary_path, name="benchmark_summary_json")
    print(f"\nResults saved to {results_path}")


# %%
# Environment check and Hub client
# --------------------------------
#
# Make sure the local benchmark directory exists and verify we have a
# CUDA GPU — the TensorRT compile + INT8 calibration steps below require
# one. Then create the :class:`Client`, which is the single entry point
# for talking to the Embedl Hub (it needs the API key set up with
# ``embedl-hub auth``, see the note at the top), and select (or create)
# the project that will hold this experiment's runs.

BENCHMARK_DIR.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Raise explicitly instead of using ``assert``: assertions are stripped
# when Python runs with ``-O``, which would silently skip this GPU check.
if device.type != "cuda":
    raise RuntimeError("This benchmark requires a CUDA GPU.")

client = Client()
project = client.set_project(HUB_PROJECT)

print(f"Model: {MODEL_NAME}  crop={CROP_SIZE}  resize={RESIZE_SIZE}")
print(f"Output: {BENCHMARK_DIR}")
print(f"Hub project: {_project_url(client, project.id)}")

# %%
# Root run and dataset setup
# --------------------------
#
# Every run we log is nested under one *root* run so they show up
# grouped together in the Hub UI. ``create_root_run`` creates that
# parent. Its first child run, ``demo_setup``, downloads and
# preprocesses ImageNette and builds the calibration / validation
# loaders that every subsequent run will share; the compile runs below
# are in turn nested under this setup run.

# Parent run for this experiment; the setup and comparison runs hang off it.
root_run_id = create_root_run(
    client,
    project_name=project.name,
)
# Setup run: prepare data and model inside the run, then log their metadata
# to it while it is still open.
with client.start_run(
    "demo_setup",
    name="Prepare ImageNette",
    parent_run_id=root_run_id,
) as setup_run:
    setup = prepare_demo_setup()
    log_setup_run(
        client,
        setup=setup,
    )
# Both compile runs below use this id as their parent.
setup_run_id = setup_run.id

# %%
# Baseline: compile the unmodified model in FP16
# ----------------------------------------------
#
# The first variant is the reference point: the pretrained model
# exported and compiled with TensorRT in plain FP16 — no Embedl
# transforms, no INT8. ``compile_variant`` exports the model to ONNX,
# runs ``trtexec``, and returns paths to the engine + metadata. The
# ``log_compile_run`` call attaches engine size, build time, and the
# precision tag to the Hub run.

# Baseline compile run, nested under the setup run. ``compile_variant`` is
# called without ``calib_data`` or ``best`` here, unlike the Embedl variant,
# so this is the unquantized reference build.
with client.start_run(
    "compile",
    name="Compile baseline FP16",
    parent_run_id=setup_run_id,
) as baseline_compile_run:
    baseline_compile = compile_variant(
        "baseline_fp16",
        setup.pretrained_model,
        dest=BENCHMARK_DIR,
    )
    log_compile_run(
        client,
        result=baseline_compile,
        tag="baseline_fp16",
        parameter_count=setup.parameter_count,
        trainable_parameters=setup.trainable_parameters,
        best=False,
        fp16=True,
        # Same candidate/precision tags are reused by the profile run.
        candidate_tag="baseline-fp16",
        precision_tag="fp16",
        upload_large_artifacts=UPLOAD_LARGE_ARTIFACTS,
    )
# Read after the ``with`` exits; the profile and comparison runs refer back
# to this compile run by id.
baseline_compile_run_id = baseline_compile_run.id

# %%
# Profile the baseline engine
# ---------------------------
#
# With the FP16 engine built, the profile run measures accuracy on the
# ImageNette validation set and end-to-end latency / throughput.
# Nesting this run under the compile run keeps the relationship
# explicit on the Hub: each profile points back at the exact engine it
# measured.

# Profile run nested under the baseline compile run. ``profile_variant``
# receives the compile result and the validation loader built during setup.
with client.start_run(
    "profile",
    name="Profile baseline FP16",
    parent_run_id=baseline_compile_run_id,
) as baseline_profile_run:
    baseline_profile = profile_variant(
        "baseline_fp16",
        baseline_compile,
        val_loader=setup.val_loader,
    )
    log_profile_run(
        client,
        result=baseline_profile,
        # Link these measurements to the compile run/engine they came from.
        compile_run_id=baseline_compile_run_id,
        compile_result=baseline_compile,
        candidate_tag="baseline-fp16",
        precision_tag="fp16",
    )
# Passed to the comparison run at the end.
baseline_profile_run_id = baseline_profile_run.id

# %%
# Embedl variant: PTQ compile in mixed INT8/FP16
# ----------------------------------------------
#
# This is the variant the benchmark is actually evaluating. We hand
# the same pretrained model to ``compile_variant`` with
# ``calib_data=setup.calibration_batches`` and ``best=True``, which
# applies the Embedl TensorRT patterns, runs INT8 post-training
# quantization using the calibration batches, then compiles. The
# resulting engine runs in mixed INT8/FP16 precision — INT8 wherever
# the patterns placed QDQ pairs, FP16 elsewhere.

# Embedl compile run, a sibling of the baseline compile run under the setup
# run. Unlike the baseline, ``compile_variant`` gets the calibration batches
# from setup and ``best=True``.
with client.start_run(
    "compile",
    name="Compile Embedl PTQ",
    parent_run_id=setup_run_id,
) as embedl_compile_run:
    embedl_compile = compile_variant(
        "embedl_ptq",
        setup.pretrained_model,
        dest=BENCHMARK_DIR,
        calib_data=setup.calibration_batches,
        best=True,
    )
    log_compile_run(
        client,
        result=embedl_compile,
        tag="embedl_ptq",
        parameter_count=setup.parameter_count,
        trainable_parameters=setup.trainable_parameters,
        best=True,
        fp16=True,
        # Same candidate/precision tags are reused by the profile run.
        candidate_tag="embedl-ptq",
        precision_tag="mixed-int8-fp16",
        upload_large_artifacts=UPLOAD_LARGE_ARTIFACTS,
    )
# Read after the ``with`` exits; the profile and comparison runs refer back
# to this compile run by id.
embedl_compile_run_id = embedl_compile_run.id

# %%
# Profile the Embedl engine
# -------------------------
#
# Same accuracy + latency measurements as the baseline profile, but on
# the INT8/FP16 engine. Running both profiles with identical loaders
# and the same ``profile_variant`` helper is what makes the side-by-side
# numbers in the next step a fair comparison.

# Profile run nested under the Embedl compile run; same helper and same
# validation loader as the baseline profile.
with client.start_run(
    "profile",
    name="Profile Embedl PTQ",
    parent_run_id=embedl_compile_run_id,
) as embedl_profile_run:
    embedl_profile = profile_variant(
        "embedl_ptq",
        embedl_compile,
        val_loader=setup.val_loader,
    )
    log_profile_run(
        client,
        result=embedl_profile,
        # Link these measurements to the compile run/engine they came from.
        compile_run_id=embedl_compile_run_id,
        compile_result=embedl_compile,
        candidate_tag="embedl-ptq",
        precision_tag="mixed-int8-fp16",
    )
# Passed to the comparison run at the end.
embedl_profile_run_id = embedl_profile_run.id

# %%
# Comparison run
# --------------
#
# Finally, log a comparison run as a direct child of the root run.
# This computes the headline numbers — speedup, top-1 / top-5 drop,
# engine size reduction — and uploads them along with the rendered
# benchmark report. Putting it under the root run (not under either
# variant) is what gives the Hub a single place to surface the
# baseline-vs-Embedl summary for this experiment.

# The comparison run is parented on the root run (not on either variant)
# and receives the ids and results of all four compile/profile runs.
log_comparison_run(
    client,
    parent_run_id=root_run_id,
    baseline_compile_run_id=baseline_compile_run_id,
    baseline_profile_run_id=baseline_profile_run_id,
    embedl_compile_run_id=embedl_compile_run_id,
    embedl_profile_run_id=embedl_profile_run_id,
    baseline_compile=baseline_compile,
    embedl_compile=embedl_compile,
    baseline_profile=baseline_profile,
    embedl_profile=embedl_profile,
)
