Optimization Pipeline#

Embedl Deploy prepares PyTorch models for hardware deployment through a three-stage pipeline. Each stage operates on a standard torch.fx.GraphModule; the conversion and fusion stages produce a numerically equivalent model, while quantization introduces the controlled rounding required for INT8 inference.

Pipeline stages#

┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│ Conversions  │ ──▶ │   Fusions    │ ──▶ │ Quantization │
│ (structural) │     │ (operator)   │     │   (INT8)     │
└──────────────┘     └──────────────┘     └──────────────┘
  1. Conversions — structural graph rewrites that normalize the computation graph for the target hardware. These are necessary because hardware compilers like TensorRT have specific operator requirements and fusion patterns. For example, replacing Flatten → Linear with Conv2d(1×1) → Flatten enables TensorRT to fuse the classifier head with preceding convolutions. Similarly, decomposing nn.MultiheadAttention into explicit sub-modules exposes the Q/K/V projections for proper QDQ placement. Conversions run iteratively until no new matches are found.

  2. Fusions — combine sequences of operators into single fused modules that map directly to hardware-accelerated kernels. For example, Conv2d → BatchNorm2d → ReLU becomes a single FusedConvBNReLU module. Fusions are matched and applied in a single pass, with longer patterns taking priority over shorter ones.

  3. Quantization — insert Quantize/DeQuantize (QDQ) stubs at pattern-declared positions, calibrate scale/zero-point from representative data, and optionally run Quantization-Aware Training (QAT). The result is an ONNX-exportable model with explicit QDQ nodes that hardware compilers (like TensorRT) use for INT8 inference.
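The Flatten → Linear rewrite mentioned in stage 1 can be checked in isolation. The snippet below is a standalone sketch using plain torch (not Embedl Deploy's pattern classes): a Linear layer applied to a flattened 1×1 feature map is numerically identical to a 1×1 Conv2d whose weights are the Linear weights reshaped.

```python
import torch
import torch.nn as nn

# Standalone check of the Flatten -> Linear  ==>  Conv2d(1x1) -> Flatten rewrite.
# The 1x1 conv reuses the Linear parameters, reshaped to (out, in, 1, 1).
linear = nn.Linear(2048, 1000)
conv = nn.Conv2d(2048, 1000, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.reshape(1000, 2048, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, 2048, 1, 1)     # e.g. the output of global average pooling
y_before = linear(x.flatten(1))    # Flatten -> Linear
y_after = conv(x).flatten(1)       # Conv2d(1x1) -> Flatten
assert torch.allclose(y_before, y_after, atol=1e-5)
```

Because the two forms agree to floating-point tolerance, the rewrite is safe to apply before fusion.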

One-shot API#

The simplest way to run the pipeline is transform():

import torch
from torchvision.models import resnet50

from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

model = resnet50(weights="DEFAULT").eval()
result = transform(model, patterns=TENSORRT_PATTERNS)

deployable = result.model  # fused GraphModule
print(f"Applied: {result.report['applied_count']}")
print(f"Skipped: {result.report['skipped_count']}")

transform() accepts both nn.Module and fx.GraphModule. When given an nn.Module, it traces internally with torch.fx.symbolic_trace. For models with dynamic control flow or custom tracing requirements, trace the model yourself and pass the GraphModule directly:

# Option 1: Let transform() handle tracing
result = transform(model, patterns=TENSORRT_PATTERNS)

# Option 2: Trace yourself for full control
graph_module = torch.fx.symbolic_trace(model)
result = transform(graph_module, patterns=TENSORRT_PATTERNS)

When transform() receives an nn.Module, it internally:

  1. Traces the model with torch.fx.symbolic_trace.

  2. Runs all conversion patterns (is_conversion=True) iteratively.

  3. Runs all fusion patterns in a single pass.
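"Iteratively" in step 2 means the conversion rules are re-applied until the graph stops changing, since one rewrite can expose a match for another. A minimal stand-in for that driver, operating on a plain list of op names rather than an fx graph and using a hypothetical Flatten/Linear rule, looks like:

```python
def run_until_fixpoint(ops, rules):
    """Apply rewrite rules repeatedly until none matches (the conversions stage)."""
    changed = True
    while changed:
        changed = False
        for pattern, replacement in rules:
            n = len(pattern)
            for i in range(len(ops) - n + 1):
                if ops[i:i + n] == pattern:
                    ops[i:i + n] = replacement  # rewrite in place
                    changed = True
                    break  # rescan from the top after any change
    return ops

# Hypothetical rule mirroring the Flatten -> Linear conversion described earlier.
rules = [(["flatten", "linear"], ["conv1x1", "flatten"])]
print(run_until_fixpoint(["avgpool", "flatten", "linear"], rules))
# -> ['avgpool', 'conv1x1', 'flatten']
```

The loop terminates once a full sweep over all rules produces no rewrite, which is the "no new matches" condition.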

Plan-based API#

For full control, use the two-step workflow:

from embedl_deploy import get_transformation_plan, apply_transformation_plan
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

graph_module = torch.fx.symbolic_trace(model)
plan = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)

# Inspect matches
for node_name, pats in plan.matches.items():
    for pat_name, match in pats.items():
        print(f"{node_name}: {pat_name} (apply={match.apply})")

# Disable specific matches
plan.matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply = False

# Apply
result = apply_transformation_plan(plan)

The plan is a TransformationPlan dataclass with:

  • model — deep copy of the traced graph (never modifies the original).

  • matches — nested dict: node_name → pattern_name → PatternMatch.

Each PatternMatch has an apply flag. Set it to False to skip that match.
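To disable every occurrence of a pattern rather than a single node's match, you can walk the nested matches dict. The helper below runs against stand-in objects that mirror the documented plan shape (the PatternMatch class and node names here are mocks for illustration, not embedl_deploy classes):

```python
from dataclasses import dataclass

# Mock of the documented plan shape: node_name -> pattern_name -> PatternMatch.
@dataclass
class PatternMatch:
    apply: bool = True

matches = {
    "maxpool": {"StemConvBNReLUMaxPoolPattern": PatternMatch()},
    "layer1_0_relu": {"ConvBNReLUPattern": PatternMatch()},
    "layer2_0_relu": {"ConvBNReLUPattern": PatternMatch()},
}

def disable_pattern(matches, pattern_name):
    """Set apply=False on every match of the named pattern."""
    for pats in matches.values():
        if pattern_name in pats:
            pats[pattern_name].apply = False

disable_pattern(matches, "ConvBNReLUPattern")
assert not matches["layer1_0_relu"]["ConvBNReLUPattern"].apply
assert matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply
```

The same loop applied to a real plan.matches would leave the rest of the plan untouched, so apply_transformation_plan(plan) simply skips the disabled matches.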

Pattern priority#

Patterns are matched in list order. When multiple patterns match overlapping nodes, the first match claims the nodes and later matches are marked apply=False. This is why pattern lists are ordered longest-first:

TENSORRT_FUSION_PATTERNS = [
    StemConvBNReLUMaxPoolPattern(),   # 4 nodes
    ConvBNAddReLUPattern(),           # 4 nodes (branching)
    ConvBNReLUPattern(),              # 3 nodes
    LinearReLUPattern(),              # 2 nodes
    ConvBNPattern(),                  # 2 nodes
    LinearPattern(),                  # 1 node
    LayerNormPattern(),               # 1 node
    AdaptiveAvgPoolPattern(),         # 1 node
    MHAInProjectionPattern(),         # 1 node
    ScaledDotProductAttentionPattern(), # 1 node
]

For ResNet50, StemConvBNReLUMaxPoolPattern claims the stem’s Conv → BN → ReLU → MaxPool nodes. The shorter ConvBNReLUPattern also matches the first three of those nodes, but they are already consumed — so it is automatically skipped for that subgraph.
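The claiming behaviour can be sketched in a few lines of plain Python. This is a simplified stand-in for the library's matcher, operating on a flat list of op names instead of an fx graph:

```python
def match_patterns(ops, patterns):
    """Greedy first-match-wins: patterns earlier in the list claim nodes first."""
    claimed = set()
    plan = []  # (start_index, pattern_name, apply_flag)
    for name, seq in patterns:
        n = len(seq)
        for i in range(len(ops) - n + 1):
            if tuple(ops[i:i + n]) == seq:
                apply = claimed.isdisjoint(range(i, i + n))
                plan.append((i, name, apply))
                if apply:
                    claimed.update(range(i, i + n))
    return plan

# Longest-first ordering, as in TENSORRT_FUSION_PATTERNS.
patterns = [
    ("StemConvBNReLUMaxPool", ("conv", "bn", "relu", "maxpool")),
    ("ConvBNReLU", ("conv", "bn", "relu")),
]
print(match_patterns(["conv", "bn", "relu", "maxpool"], patterns))
# -> [(0, 'StemConvBNReLUMaxPool', True), (0, 'ConvBNReLU', False)]
```

The 4-node stem pattern claims nodes 0 through 3 first, so the 3-node pattern still matches but is recorded with apply=False, exactly the skip behaviour described above.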

Pattern groups#

The TensorRT backend exposes pre-built pattern lists:

  • TENSORRT_CONVERSION_PATTERNS — structural conversions (run first).

  • TENSORRT_FUSION_PATTERNS — operator fusions.

  • TENSORRT_QUANTIZED_PATTERNS — quantization-focused rewrites.

  • TENSORRT_PATTERNS — union of all of the above.

from embedl_deploy.tensorrt import (
    TENSORRT_CONVERSION_PATTERNS,
    TENSORRT_FUSION_PATTERNS,
    TENSORRT_PATTERNS,
    TENSORRT_QUANTIZED_PATTERNS,
)

# Run only fusions (skip conversions)
result = transform(model, patterns=TENSORRT_FUSION_PATTERNS)

# Run only conversions
result = transform(model, patterns=TENSORRT_CONVERSION_PATTERNS)

Verifying numerical equivalence#

Conversions and fusions must not change model outputs. Always verify:

with torch.no_grad():
    y_original = model(example_input)
    y_deployed = result.model(example_input)

max_diff = (y_original - y_deployed).abs().max().item()
assert max_diff < 1e-5, f"Numerical mismatch: {max_diff}"

ONNX export and compilation#

After transformation, export and compile as usual:

torch.onnx.export(
    result.model.cpu().eval(),
    torch.randn(1, 3, 224, 224),
    "model_fused.onnx",
    opset_version=20,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)

Then compile with TensorRT:

trtexec \
  --onnx=model_fused.onnx \
  --fp16 \
  --exportLayerInfo=layer_info.json \
  --profilingVerbosity=detailed