Operator Fusions#
Fusions combine sequences of operators into single fused modules that map
directly to hardware-accelerated kernels. After fusion, the model graph
contains modules like FusedConvBNReLU instead of separate Conv2d,
BatchNorm2d, and ReLU layers.
Fused modules:

- Are numerically equivalent to the original operator sequence (no weight folding at this stage — that’s handled by the hardware compiler).
- Declare QDQ insertion points so quantization stubs are placed correctly.
- Export cleanly to ONNX for downstream compilation.
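The weight folding deferred to the hardware compiler is the standard inference-time BatchNorm identity: the per-channel affine `γ·(z − mean)/√(var + ε) + β` collapses into one scale and one bias that can be absorbed into the preceding convolution. A minimal NumPy check of that identity (arbitrary shapes, illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
c = 8  # channels
gamma, beta = rng.standard_normal(c), rng.standard_normal(c)
mean, var = rng.standard_normal(c), rng.random(c) + 0.1
eps = 1e-5

# Inference-mode BatchNorm applied to per-channel activations z
# (e.g. convolution outputs):
z = rng.standard_normal((4, c))
bn_out = gamma * (z - mean) / np.sqrt(var + eps) + beta

# Folded form: one scale and one bias per channel, absorbable into the
# convolution's weights and bias.
scale = gamma / np.sqrt(var + eps)
bias = beta - mean * scale
folded_out = scale * z + bias

assert np.allclose(bn_out, folded_out)
```

Because a convolution is linear in its weights, scaling each output channel by `scale` and shifting by `bias` reproduces the BN exactly, which is why the fused module can stay numerically equivalent while deferring the folding itself.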
Convolution fusions#
ConvBNPattern#
Matches: Conv2d → [BatchNorm2d]
Produces: FusedConvBN
The most basic convolution fusion. The BatchNorm2d is optional — a bare
Conv2d is also matched (useful for ensuring QDQ stub placement even on
convolutions without batch normalization).
QDQ points: INPUT
ConvBNReLUPattern#
Matches: Conv2d → [BatchNorm2d] → Activation
Produces: FusedConvBNReLU
Fuses convolution, optional batch normalization, and an activation function.
The pattern is named after the most common case (ReLU), but it matches any of
these activations: ReLU, ReLU6, LeakyReLU, ELU, GELU, SiLU,
Hardswish, Hardsigmoid.
QDQ points: INPUT
StemConvBNReLUMaxPoolPattern#
Matches: Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d
Produces: FusedConvBNReLUMaxPool
Captures the classification network stem found in ResNet and similar
architectures. The convolution is constrained to in_channels=3 and kernel_size=(7,7) so that only the actual stem is matched.
QDQ points: INPUT
ConvBNAddReLUPattern#
Matches: Conv2d → BatchNorm2d → add(·, residual) → Activation
Produces: FusedConvBNAddReLU
Captures the tail of ResNet-style bottleneck blocks where the convolution
path merges with a skip connection before the final activation. This is a
branching pattern — it matches a Fork topology with two inputs feeding
into an operator.add node.
QDQ points: INPUT, RESIDUAL_INPUT
The RESIDUAL_INPUT QDQ point ensures the skip connection is also quantized,
which is critical for TensorRT to fuse the entire residual block into a single
INT8 kernel.
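In PyTorch terms, the matched topology is the tail of a torchvision-style Bottleneck block. A minimal sketch of a module this pattern would match (plain PyTorch; the class and names here are hypothetical, for illustration only):

```python
import torch
import torch.nn as nn

class BottleneckTail(nn.Module):
    """Illustrative Conv2d -> BatchNorm2d -> add -> ReLU topology.

    Two inputs feed the add node: the main convolution path and the
    skip connection (the RESIDUAL_INPUT of the fused module).
    """
    def __init__(self, channels: int = 8):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # Fork topology: conv path and residual merge before the activation.
        return self.relu(self.bn(self.conv(x)) + residual)
```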
Linear fusions#
LinearReLUPattern#
Matches: Linear → Activation
Produces: FusedLinearReLU
QDQ points: INPUT
LinearPattern#
Matches: standalone Linear
Produces: FusedLinear
QDQ points: INPUT
LayerNormPattern#
Matches: LayerNorm
Produces: FusedLayerNorm
QDQ points: none (empty qdq_points)
LayerNorm is not quantized — it is memory-bound rather than compute-bound and runs efficiently in FP16/FP32. Placing QDQ stubs around LayerNorm would hurt accuracy without improving latency.
Attention fusions#
These patterns match the sub-modules produced by the
DecomposeMultiheadAttentionPattern conversion.
MHAInProjectionPattern#
Matches: MHAInProjection
Produces: FusedMHAInProjection
QDQ points: INPUT
ScaledDotProductAttentionPattern#
Matches: ScaledDotProductAttention
Produces: FusedScaledDotProductAttention
QDQ points: INPUT, KEY_INPUT, VALUE_INPUT
The three QDQ points ensure that Q, K, and V tensors are independently quantized, matching TensorRT’s expected input format for fused attention kernels.
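For reference, the computation the fused module implements — softmax(QKᵀ/√d)·V — in a minimal NumPy sketch (single-head 2-D tensors for brevity; illustrative, not the library's implementation):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q: (seq_q, d), k: (seq_k, d), v: (seq_k, d_v) -> (seq_q, d_v)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])       # (seq_q, seq_k)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v
```

A quick sanity check: with q all zeros, the attention weights are uniform, so the output is the mean of v over the sequence axis.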
Pooling fusions#
AdaptiveAvgPoolPattern#
Matches: AdaptiveAvgPool2d
Produces: FusedAdaptiveAvgPool2d
QDQ points: INPUT, OUTPUT
Note
In smart-quantization workflows, the AdaptiveAvgPoolPattern is
intentionally omitted from the pattern list to skip QDQ placement around
global average pooling. This is because global average pooling is a cheap,
memory-bound reduction — quantizing it adds TensorRT reformatting overhead
without meaningful compute savings. See Custom Patterns for details.
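Omitting a pattern amounts to passing a filtered pattern list to transform. A sketch with stand-in classes (the real names come from embedl_deploy.tensorrt; the placeholder class definitions here only make the snippet self-contained):

```python
# Stand-ins for the real pattern classes documented on this page.
class ConvBNReLUPattern: ...
class AdaptiveAvgPoolPattern: ...

TENSORRT_FUSION_PATTERNS = [ConvBNReLUPattern, AdaptiveAvgPoolPattern]

# Skip QDQ placement around global average pooling by dropping its pattern.
smart_patterns = [
    p for p in TENSORRT_FUSION_PATTERNS if p is not AdaptiveAvgPoolPattern
]
```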
Fusion summary by architecture#
ResNet50#
| Fused module | Count | Description |
|---|---|---|
| FusedConvBNReLUMaxPool | 1 | Stem: Conv(7×7) + BN + ReLU + MaxPool |
| FusedConvBNAddReLU | 16 | Bottleneck residual blocks |
| FusedConvBNReLU | 16 | Main-path Conv + BN + ReLU |
| FusedConvBN | 17 | Conv + BN without activation |
| FusedAdaptiveAvgPool2d | 1 | Global average pool |
| Total | 51 | |
Conversions applied: FlattenLinearToConv1x1Pattern converts the
Flatten → Linear classifier into Conv2d(1×1) → Flatten.
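The two forms are numerically identical: a Linear layer on flattened 1×1 feature maps computes the same matmul as a 1×1 convolution. A NumPy check (shapes chosen arbitrarily; illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
n, c_in, c_out = 4, 16, 10
x = rng.standard_normal((n, c_in, 1, 1))  # globally pooled feature maps
w = rng.standard_normal((c_out, c_in))    # shared classifier weights
b = rng.standard_normal(c_out)

# Flatten -> Linear
linear_out = x.reshape(n, c_in) @ w.T + b

# Conv2d(1x1) -> Flatten: a 1x1 conv is a per-pixel matmul over channels
conv_out = np.einsum('nchw,oc->nohw', x, w) + b[:, None, None]
conv_out = conv_out.reshape(n, c_out)

assert np.allclose(linear_out, conv_out)
```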
ConvNeXt (Tiny/Base/Large)#
| Fused module | Count | Description |
|---|---|---|
| FusedConvBN | 36/72/108 | Depthwise + pointwise convolutions |
| FusedConvBNReLU | 27/54/81 | Conv + BN + GELU chains |
| FusedLayerNorm | 27/54/81 | LayerNorm layers (unquantized) |
| FusedLinear | 27/54/81 | Standalone linear layers |
Counts increase with model depth (Tiny: 9 stages × 3, Base: 9 × 6, Large: 9 × 9).
Conversions applied:
- RemoveIdentityAdaptiveAvgPoolPattern removes identity pooling ops.
- FlattenLinearToConv1x1Pattern converts the classifier head.
ConvNeXt uses depthwise separable convolutions extensively. The default pattern set quantizes all convolutions equally, but mixed-precision (see Custom Patterns) skips depthwise convolutions for better latency.
Vision Transformer (ViT-B/16)#
| Fused module | Count | Description |
|---|---|---|
| FusedConvBN | 1 | Patch embedding Conv2d |
| FusedMHAInProjection | 12 | Q/K/V projections |
| FusedScaledDotProductAttention | 12 | Attention computation |
| FusedLinear | 36+ | Out-proj + MLP linear layers |
| FusedLinearReLU | 12 | MLP hidden → GELU chains |
| FusedLayerNorm | 25 | Pre/post-norm layers (unquantized) |
Conversions applied: DecomposeMultiheadAttentionPattern decomposes all 12
attention layers into explicit sub-modules.
Running fusions only#
```python
from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_FUSION_PATTERNS

result = transform(model, patterns=TENSORRT_FUSION_PATTERNS)
```
To inspect what was fused:
```python
from collections import Counter

# Count fused modules by class name; type(m).__name__ is always a str,
# so the 'Fused' prefix check is sufficient on its own.
fused_counts = Counter(
    type(m).__name__
    for m in result.model.modules()
    if type(m).__name__.startswith('Fused')
)
for name, count in sorted(fused_counts.items()):
    print(f"  {name}: {count}")
```