# Copyright (C) 2026 Embedl AB

"""
ResNet50 deployment with TensorRT Patterns
==========================================

This quickstart shows how to prepare a PyTorch ResNet50 model for
deployment with ``embedl_deploy``.

The core idea is simple:

1. A ``transform`` applies a list of graph rewrite patterns.
2. A ``plan`` lets you inspect and edit every match before applying it.

In the current public release, the packaged backend is **TensorRT**.
The core API is backend-agnostic, and additional backends may be added
over time.

For TensorRT, the pattern library is split into three buckets:

- **Conversions**: structural rewrites run first to normalize graphs.
- **Fusions**: combine layer sequences into fused modules.
- **Quantized patterns**: quantization-focused rewrites (currently empty
  in this release, but included in the API).

.. note::

   This tutorial uses random weights for illustration.  In practice you
   would start from a pre-trained checkpoint.
"""

# sphinx_gallery_start_ignore
# pylint: disable=wrong-import-position,wrong-import-order
# sphinx_gallery_end_ignore

# %%
# Setup
# -----
#
# We start by loading a standard ``torchvision`` model in eval mode.

import torch
from torchvision.models import resnet50

# Randomly initialised ResNet50, switched to eval mode so that layers such
# as BatchNorm behave as they would at inference time.
model = resnet50(weights=None)
model.eval()
# A single 3x224x224 input, reused for tracing and equivalence checks.
example_input = torch.randn(1, 3, 224, 224)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model: {type(model).__name__}")
print(f"Parameters: {n_params:,}")

# %%
# One-shot transformation with ``transform()``
# ----------------------------------------------
#
# The simplest way to prepare a model is to call
# :func:`~embedl_deploy.transform` with a pattern list:

from embedl_deploy import transform
from embedl_deploy.tensorrt import (
    TENSORRT_CONVERSION_PATTERNS,
    TENSORRT_FUSION_PATTERNS,
    TENSORRT_PATTERNS,
    TENSORRT_QUANTIZED_PATTERNS,
)

print("\nTensorRT pattern groups:")
# Labels are left-padded to a fixed 13-character column so counts line up.
for group_label, group_patterns in (
    ("conversions:", TENSORRT_CONVERSION_PATTERNS),
    ("fusions:", TENSORRT_FUSION_PATTERNS),
    ("quantized:", TENSORRT_QUANTIZED_PATTERNS),
    ("total:", TENSORRT_PATTERNS),
):
    print(f"  {group_label:<13}{len(group_patterns)}")

if not TENSORRT_QUANTIZED_PATTERNS:
    print(
        "Quantized patterns are exposed in the API but empty in this release."
    )
else:
    print("Quantized patterns in this build:")
    for quant_pattern in TENSORRT_QUANTIZED_PATTERNS:
        print(f"  - {type(quant_pattern).__name__}")

# One-shot rewrite: ``transform`` returns a result object whose ``.model``
# is the rewritten network.
transform_result = transform(model, patterns=TENSORRT_PATTERNS)
deployed = transform_result.model

# %%
# Inspect the fused modules
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The transformed model contains fused ``nn.Module`` subclasses instead
# of the original separate ``Conv``, ``BatchNorm``, and ``ReLU`` layers.

from collections import Counter

from embedl_deploy.tensorrt.modules import (
    FusedAdaptiveAvgPool2d,
    FusedConvBN,
    FusedConvBNAct,
    FusedConvBNActMaxPool,
    FusedConvBNAddAct,
)

# Fused module classes counted throughout this tutorial; used as an
# ``isinstance`` filter over ``Module.modules()``.
FUSED_MODULE_TYPES = (
    FusedConvBN,
    FusedConvBNAddAct,
    FusedConvBNAct,
    FusedConvBNActMaxPool,
    FusedAdaptiveAvgPool2d,
)

# Tally fused submodules by class name.
fused_counts = Counter()
for submodule in deployed.modules():
    if isinstance(submodule, FUSED_MODULE_TYPES):
        fused_counts[type(submodule).__name__] += 1

print("Fused modules in the transformed model:\n")
for type_name in sorted(fused_counts):
    print(f"  {type_name:<25s} {fused_counts[type_name]}")
print(f"  {'TOTAL':<25s} {sum(fused_counts.values())}")

# In ResNet50, common fusions include:
# - Stem block: ``Conv`` + ``BatchNorm`` + ``ReLU`` + ``MaxPool``
# - Main path: ``Conv`` + ``BatchNorm`` or ``Conv`` + ``BatchNorm`` + ``ReLU``
# - Residual blocks: ``Conv`` + ``BatchNorm`` + ``Add`` + ``ReLU``
# - Tail rewrite: ``AdaptiveAvgPool`` handling before classifier export

# %%
# Verify numerical equivalence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The fused model must produce numerically equivalent outputs: no weight
# folding has happened yet (that is left to the TensorRT compiler), so we
# check that the maximum absolute difference stays within a small tolerance.

# Run both models on the same input without building autograd graphs.
with torch.no_grad():
    y_original = model(example_input)
    y_deployed = deployed(example_input)

# Check shapes first: the subtraction below broadcasts, so a shape change
# introduced by a rewrite (e.g. a dropped batch dimension) would otherwise
# slip through the max-difference check unnoticed.
assert y_deployed.shape == y_original.shape, (
    f"Shape mismatch: {tuple(y_deployed.shape)} "
    f"!= {tuple(y_original.shape)}"
)

max_diff = (y_original - y_deployed).abs().max().item()
print(f"Max output difference: {max_diff:.2e}")
# A NaN difference also fails here, since ``nan < 1e-5`` is False.
assert max_diff < 1e-5, f"Numerical mismatch: {max_diff}"
print("✓ Fused model is numerically equivalent to the original.")

# %%
# Plan-based workflow
# --------------------
#
# For full control, use the two-step workflow:
# :func:`~embedl_deploy.get_transformation_plan` to discover matches,
# then :func:`~embedl_deploy.apply_transformation_plan` to apply them.
#
# The plan is **editable** — you can toggle ``match.apply = False``
# to skip specific matches before applying.

from embedl_deploy import apply_transformation_plan, get_transformation_plan

# Trace the eager model into an FX GraphModule; the plan is computed on
# this traced graph.
graph_module = torch.fx.symbolic_trace(model)
plan = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)

# ``plan.matches`` is a nested mapping: node name -> {pattern name -> match}.
match_total = sum(map(len, plan.matches.values()))
print(f"\nDiscovered {match_total} matches:\n")
for node_name, matches_by_pattern in plan.matches.items():
    for pattern_name, pattern_match in matches_by_pattern.items():
        print(f"  {node_name}: {pattern_name} (apply={pattern_match.apply})")

# %%
# Apply the plan without changes:

result = apply_transformation_plan(plan)

# The report records how many planned matches were applied vs. skipped.
report = result.report
print(f"Applied: {report['applied_count']}")
print(f"Skipped: {report['skipped_count']}")

# %%
# .. note::
#
#    A non-zero ``skipped_count`` is expected and intentional.
#
#    All patterns are matched against the graph independently first.
#    The plan then iterates through the results **in pattern-list order**,
#    building a set of *consumed* nodes.  When a match's nodes overlap with
#    nodes already claimed by an earlier match, it is marked
#    ``apply=False`` and counted as skipped.
#
#    For ResNet50 with ``TENSORRT_PATTERNS`` this means:
#    ``StemConvBNActMaxPoolPattern`` (listed first) claims the
#    ``Conv→BN→ReLU→MaxPool`` stem nodes.  The shorter
#    ``ConvBNActPattern``, ``ConvBNPattern``, etc. also match sub-chains
#    within that same stem, so they are skipped because their nodes were
#    already consumed.
#
#    This is how pattern priority works: supply the **longest / most
#    specific** patterns first so they take precedence over shorter,
#    more general ones when sub-graphs overlap.

# %%
# Edit the plan before applying
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Disable a specific match by setting ``apply = False``:

# Compute a second plan for the same traced graph, to be edited by hand.
plan2 = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)

# Disable fusion of the stem. In this graph the stem match is keyed by the
# ``maxpool`` node.
stem_match = plan2.matches["maxpool"]["StemConvBNActMaxPoolPattern"]
stem_match.apply = False

result2 = apply_transformation_plan(plan2)

fused2_counts = Counter(
    type(module).__name__
    for module in result2.model.modules()
    if isinstance(module, FUSED_MODULE_TYPES)
)
total_with_stem = sum(fused_counts.values())
total_without_stem = sum(fused2_counts.values())
print(
    f"Fused modules (stem skipped): {total_without_stem} "
    f"(was {total_with_stem})"
)
print(f"Skipped: {result2.report['skipped_count']}")

# %%
# Toggle only conversion or fusion patterns
# ------------------------------------------
#
# You can scope the transformation by choosing pattern groups directly.

def _count_fused_modules(module):
    """Return how many submodules of *module* are fused module instances."""
    # ``isinstance`` yields bools; summing them counts the matches.
    return sum(
        isinstance(sub, FUSED_MODULE_TYPES) for sub in module.modules()
    )


# Run each pattern group on its own against the original model.
only_fusions = transform(model, patterns=TENSORRT_FUSION_PATTERNS).model
only_conversions = transform(
    model, patterns=TENSORRT_CONVERSION_PATTERNS
).model

only_fusions_count = _count_fused_modules(only_fusions)
only_conversions_count = _count_fused_modules(only_conversions)

print("\nPattern group experiments:")
print(f"  only fusions -> fused module count: {only_fusions_count}")
print(f"  only conversions -> fused module count: {only_conversions_count}")

# %%
# Selective transformation
# -------------------------
#
# Because every transformation is a ``Pattern`` object, you have full
# control.  For example, to fuse only Conv→BN→ReLU chains:
#
# .. code-block:: python
#
#    from embedl_deploy.tensorrt import ConvBNActPattern
#
#    selective = transform(model, patterns=[ConvBNActPattern()]).model
#
# Or use the plan to cherry-pick:
#
# .. code-block:: python
#
#    graph_module = torch.fx.symbolic_trace(model)
#    plan = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)
#    # Disable everything except ConvBNAct
#    for pats in plan.matches.values():
#        for pat_name, match in pats.items():
#            if pat_name != "ConvBNActPattern":
#                match.apply = False
#    result = apply_transformation_plan(plan)

# %%
# Quantization patterns: inspect and toggle
# ------------------------------------------
#
# Quantization rewrites follow the exact same plan/edit/apply workflow.

# Plan quantized rewrites on the model produced by
# ``apply_transformation_plan(plan)`` earlier in this tutorial.
quant_plan = get_transformation_plan(
    result.model, patterns=TENSORRT_QUANTIZED_PATTERNS
)
quant_match_count = sum(map(len, quant_plan.matches.values()))
print(f"\nQuantization-plan matches: {quant_match_count}")

if not quant_match_count:
    print(
        "No quantization matches in this release. "
        "When quantized patterns are added, you can toggle them the same way."
    )
else:
    # Example edit: switch off every quantization match, then apply.
    every_quant_match = [
        quant_match
        for matches_by_pattern in quant_plan.matches.values()
        for quant_match in matches_by_pattern.values()
    ]
    for quant_match in every_quant_match:
        quant_match.apply = False

    quant_result = apply_transformation_plan(quant_plan)
    print(
        "Applied quantized patterns after disabling all matches: "
        f"{quant_result.report['applied_count']}"
    )

# %%
# Visualising transforms
# ----------------------
#
# The images below show the layer mapping before and after TensorRT
# compilation. `Embedl Visualizer <https://hub.embedl.com/visualizer>`_ renders
# PyTorch graphs, ONNX models, and hardware-compiled artifacts (e.g., TensorRT
# engines) side-by-side for comparison and debugging. It is available online
# for public use on `Embedl Hub <https://hub.embedl.com>`_ and locally for
# enterprise solutions.
#
# .. raw:: html
#
#    <div style="display: flex; gap: 1rem; align-items: flex-start;">
#      <figure style="flex: 1; text-align: center; margin: 0;">
#        <figcaption><strong>Design in PyTorch</strong></figcaption>
#        <img src="https://6631582.fs1.hubspotusercontent-na1.net/hubfs/6631582/deploy_mapping.png"
#             alt="Layer mapping — PyTorch to ONNX"
#             style="width: 100%;">
#      </figure>
#      <figure style="flex: 1; text-align: center; margin: 0;">
#        <figcaption><strong>Deploy on edge</strong></figcaption>
#        <img
#          src="https://6631582.fs1.hubspotusercontent-na1.net/hubfs/6631582/deploy_mapping_trt.png"
#          alt="Layer mapping — ONNX to TensorRT"
#          style="width: 100%;">
#      </figure>
#    </div>
#

# %%
# Next steps
# ----------
#
# After these graph rewrites, the model is ready for:
#
# - **ONNX export** — ``torch.onnx.export(deployed, example_input, "resnet50_fused.onnx")``
# - **Quantization** — enable quantized patterns as they become available
# - **TensorRT compilation** — compile the ONNX model to a TensorRT engine
#
# .. code-block:: bash
#
#    /usr/src/tensorrt/bin/trtexec --onnx=resnet50_fused.onnx \
#        --exportLayerInfo=layer_info.json \
#        --profilingVerbosity=detailed \
#        --exportProfile=profile.json
