# Optimization Pipeline
Embedl Deploy prepares PyTorch models for hardware deployment through a
three-stage pipeline. Each stage operates on a standard torch.fx.GraphModule
and produces a numerically equivalent model.
## Pipeline stages
```text
┌──────────────┐     ┌──────────────┐     ┌──────────────┐
│ Conversions  │ ──▶ │   Fusions    │ ──▶ │ Quantization │
│ (structural) │     │  (operator)  │     │    (INT8)    │
└──────────────┘     └──────────────┘     └──────────────┘
```
- **Conversions** — structural graph rewrites that normalize the computation graph for the target hardware. These are necessary because hardware compilers like TensorRT have specific operator requirements and fusion patterns. For example, replacing `Flatten → Linear` with `Conv2d(1×1) → Flatten` enables TensorRT to fuse the classifier head with preceding convolutions. Similarly, decomposing `nn.MultiheadAttention` into explicit sub-modules exposes the Q/K/V projections for proper QDQ placement. Conversions run iteratively until no new matches are found.
- **Fusions** — combine sequences of operators into single fused modules that map directly to hardware-accelerated kernels. For example, `Conv2d → BatchNorm2d → ReLU` becomes a single `FusedConvBNReLU` module. Fusions are matched and applied in a single pass, with longer patterns taking priority over shorter ones.
- **Quantization** — insert Quantize/DeQuantize (QDQ) stubs at pattern-declared positions, calibrate scale/zero-point from representative data, and optionally run Quantization-Aware Training (QAT). The result is an ONNX-exportable model with explicit QDQ nodes that hardware compilers (like TensorRT) use for INT8 inference.
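To see what the fusion stage operates on, it helps to look at a raw `torch.fx` trace. The sketch below (plain `torch.fx`, independent of Embedl Deploy) traces a minimal Conv → BN → ReLU block; the resulting graph is a flat sequence of `call_module` nodes, which is the representation that pattern matching runs over:

```python
import torch
import torch.nn as nn

# A minimal block: the Conv2d -> BatchNorm2d -> ReLU sequence that a
# ConvBNReLU-style fusion pattern would match.
class Stem(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

gm = torch.fx.symbolic_trace(Stem())
# The traced graph is a linear chain of call_module nodes.
ops = [(n.op, n.target) for n in gm.graph.nodes if n.op == "call_module"]
print(ops)  # [('call_module', 'conv'), ('call_module', 'bn'), ('call_module', 'relu')]
```

A fusion pattern for this sequence would replace the three nodes with a single `call_module` pointing at a fused module, leaving the rest of the graph untouched.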
## One-shot API
The simplest way to run the pipeline is `transform()`:
```python
import torch
from torchvision.models import resnet50

from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

model = resnet50(weights="DEFAULT").eval()
result = transform(model, patterns=TENSORRT_PATTERNS)

deployable = result.model  # fused GraphModule
print(f"Applied: {result.report['applied_count']}")
print(f"Skipped: {result.report['skipped_count']}")
```
`transform()` accepts both `nn.Module` and `fx.GraphModule`. When given an
`nn.Module`, it traces internally with `torch.fx.symbolic_trace`. For models
with dynamic control flow or custom tracing requirements, trace the model
yourself and pass the `GraphModule` directly:
```python
# Option 1: let transform() handle tracing
result = transform(model, patterns=TENSORRT_PATTERNS)

# Option 2: trace yourself for full control
graph_module = torch.fx.symbolic_trace(model)
result = transform(graph_module, patterns=TENSORRT_PATTERNS)
```
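One common custom-tracing need is a module with data-dependent control flow, which `symbolic_trace` cannot trace through. A subclassed `torch.fx.Tracer` can treat such a module as an opaque leaf so it appears as a single `call_module` node. A sketch (`DynamicBlock` and `LeafTracer` are illustrative names, not part of Embedl Deploy):

```python
import torch
import torch.nn as nn

class DynamicBlock(nn.Module):
    # Hypothetical module with data-dependent control flow that
    # symbolic_trace would choke on.
    def forward(self, x):
        return x * 2 if x.sum() > 0 else x

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.block = DynamicBlock()

    def forward(self, x):
        return self.block(self.linear(x))

class LeafTracer(torch.fx.Tracer):
    # Treat DynamicBlock as a leaf: it becomes one call_module node
    # instead of being traced through.
    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, DynamicBlock) or super().is_leaf_module(m, module_qualified_name)

tracer = LeafTracer()
graph = tracer.trace(Net())
gm = torch.fx.GraphModule(tracer.root, graph)
targets = [n.target for n in gm.graph.nodes if n.op == "call_module"]
print(targets)  # ['linear', 'block']
```

The resulting `GraphModule` can then be passed to `transform()` as in Option 2 above.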
When `transform()` receives an `nn.Module`, it internally:

1. Traces the model with `torch.fx.symbolic_trace`.
2. Runs all conversion patterns (`is_conversion=True`) iteratively.
3. Runs all fusion patterns in a single pass.
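The steps above can be sketched as a small driver loop. This is pure-Python pseudocode over a hypothetical pattern interface (the real pipeline rewrites `fx.GraphModule` nodes, not strings); it only illustrates the control flow: conversions repeat until a fixpoint, fusions run once:

```python
def run_pipeline(ops, conversion_patterns, fusion_patterns):
    # Stage 1: apply conversions iteratively until no pattern changes anything.
    changed = True
    while changed:
        changed = False
        for pattern in conversion_patterns:
            new_ops = pattern(ops)
            if new_ops != ops:
                ops, changed = new_ops, True
    # Stage 2: apply fusions in a single pass.
    for pattern in fusion_patterns:
        ops = pattern(ops)
    return ops

# Toy rewrites in the spirit of the real patterns.
def flatten_linear_to_conv(ops):
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + 2] == ["flatten", "linear"]:
            out += ["conv1x1", "flatten"]  # Flatten -> Linear becomes Conv2d(1x1) -> Flatten
            i += 2
        else:
            out.append(ops[i]); i += 1
    return out

def fuse_conv_bn_relu(ops):
    out, i = [], 0
    while i < len(ops):
        if ops[i:i + 3] == ["conv", "bn", "relu"]:
            out.append("fused_conv_bn_relu")  # collapse into one fused module
            i += 3
        else:
            out.append(ops[i]); i += 1
    return out

ops = ["conv", "bn", "relu", "flatten", "linear"]
result = run_pipeline(ops, [flatten_linear_to_conv], [fuse_conv_bn_relu])
print(result)  # ['fused_conv_bn_relu', 'conv1x1', 'flatten']
```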
## Plan-based API
For full control, use the two-step workflow:
```python
from embedl_deploy import get_transformation_plan, apply_transformation_plan
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

graph_module = torch.fx.symbolic_trace(model)
plan = get_transformation_plan(graph_module, patterns=TENSORRT_PATTERNS)

# Inspect matches
for node_name, pats in plan.matches.items():
    for pat_name, match in pats.items():
        print(f"{node_name}: {pat_name} (apply={match.apply})")

# Disable specific matches
plan.matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply = False

# Apply
result = apply_transformation_plan(plan)
```
The plan is a `TransformationPlan` dataclass with:

- `model` — a deep copy of the traced graph (the original is never modified).
- `matches` — a nested dict: `node_name → pattern_name → PatternMatch`.

Each `PatternMatch` has an `apply` flag. Set it to `False` to skip that match.
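A minimal stand-in for this structure can make the inspect/disable workflow concrete. The dataclasses below are illustrative only (the real `TransformationPlan` holds an `fx.GraphModule` and richer match metadata):

```python
from dataclasses import dataclass, field

@dataclass
class PatternMatch:
    nodes: list          # graph nodes claimed by this match
    apply: bool = True   # flip to False to skip the match

@dataclass
class TransformationPlan:
    model: object
    # node_name -> pattern_name -> PatternMatch
    matches: dict = field(default_factory=dict)

plan = TransformationPlan(model=None)
plan.matches["maxpool"] = {
    "StemConvBNReLUMaxPoolPattern": PatternMatch(nodes=["conv1", "bn1", "relu", "maxpool"]),
}

# Disabling a match is just flipping its flag before applying the plan.
plan.matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply = False
print(plan.matches["maxpool"]["StemConvBNReLUMaxPoolPattern"].apply)  # False
```

The same nested-dict shape also makes it easy to disable a pattern everywhere, e.g. by looping over `plan.matches.values()` and flipping every match with a given pattern name.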
## Pattern priority
Patterns are matched in list order. When multiple patterns match overlapping
nodes, the first match claims the nodes and later matches are marked
`apply=False`. This is why pattern lists are ordered longest-first:
```python
TENSORRT_FUSION_PATTERNS = [
    StemConvBNReLUMaxPoolPattern(),      # 4 nodes
    ConvBNAddReLUPattern(),              # 4 nodes (branching)
    ConvBNReLUPattern(),                 # 3 nodes
    LinearReLUPattern(),                 # 2 nodes
    ConvBNPattern(),                     # 2 nodes
    LinearPattern(),                     # 1 node
    LayerNormPattern(),                  # 1 node
    AdaptiveAvgPoolPattern(),            # 1 node
    MHAInProjectionPattern(),            # 1 node
    ScaledDotProductAttentionPattern(),  # 1 node
]
```
For ResNet50, `StemConvBNReLUMaxPoolPattern` claims the stem's
`Conv → BN → ReLU → MaxPool` nodes. The shorter `ConvBNReLUPattern` also
matches the first three of those nodes, but they are already consumed — so it
is automatically skipped for that subgraph.
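The claiming behavior can be modeled in a few lines: patterns are tried in priority order, the first match consumes its nodes, and any later overlapping match is recorded with `apply=False`. This is a toy sketch over op-name sequences, not the real graph matcher:

```python
def plan_matches(ops, patterns):
    # patterns: list of (name, op_sequence) in priority order (longest first).
    claimed = set()
    matches = []
    for name, seq in patterns:
        for start in range(len(ops) - len(seq) + 1):
            if ops[start:start + len(seq)] == seq:
                span = range(start, start + len(seq))
                apply = claimed.isdisjoint(span)  # overlapping nodes already consumed?
                if apply:
                    claimed.update(span)
                matches.append((name, start, apply))
    return matches

# The ResNet50 stem case from above.
ops = ["conv", "bn", "relu", "maxpool"]
patterns = [
    ("StemConvBNReLUMaxPoolPattern", ["conv", "bn", "relu", "maxpool"]),
    ("ConvBNReLUPattern", ["conv", "bn", "relu"]),
]
for name, start, apply in plan_matches(ops, patterns):
    print(name, apply)
# StemConvBNReLUMaxPoolPattern True
# ConvBNReLUPattern False
```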
## Pattern groups
The TensorRT backend exposes pre-built pattern lists:
| List | Contents |
|---|---|
| `TENSORRT_CONVERSION_PATTERNS` | Structural conversions (run first) |
| `TENSORRT_FUSION_PATTERNS` | Operator fusions |
| `TENSORRT_QUANTIZED_PATTERNS` | Quantization-focused rewrites |
| `TENSORRT_PATTERNS` | Union of all the above |
```python
from embedl_deploy.tensorrt import (
    TENSORRT_CONVERSION_PATTERNS,
    TENSORRT_FUSION_PATTERNS,
    TENSORRT_PATTERNS,
    TENSORRT_QUANTIZED_PATTERNS,
)

# Run only fusions (skip conversions)
result = transform(model, patterns=TENSORRT_FUSION_PATTERNS)

# Run only conversions
result = transform(model, patterns=TENSORRT_CONVERSION_PATTERNS)
```
## Verifying numerical equivalence
Conversions and fusions must not change model outputs. Always verify:
```python
with torch.no_grad():
    y_original = model(example_input)
    y_deployed = result.model(example_input)

max_diff = (y_original - y_deployed).abs().max().item()
assert max_diff < 1e-5, f"Numerical mismatch: {max_diff}"
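For outputs with a large dynamic range, a combined relative/absolute tolerance can be more robust than a raw max-abs threshold; `torch.testing.assert_close` implements exactly that check. A small illustration on synthetic tensors:

```python
import torch

# assert_close passes when |actual - expected| <= atol + rtol * |expected|,
# so large-magnitude outputs are judged by relative error, small ones by absolute.
expected = torch.tensor([1.0, 100.0])
actual = expected + torch.tensor([1e-7, 1e-5])  # tiny drift at both scales
torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)
print("outputs match within tolerance")
```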
## ONNX export and compilation
After transformation, export and compile as usual:
```python
torch.onnx.export(
    result.model.cpu().eval(),
    torch.randn(1, 3, 224, 224),
    "model_fused.onnx",
    opset_version=20,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
```
Then compile with TensorRT:
```shell
trtexec \
  --onnx=model_fused.onnx \
  --fp16 \
  --exportLayerInfo=layer_info.json \
  --profilingVerbosity=detailed
```