# Operator Fusions

Fusions combine sequences of operators into single fused modules that map directly to hardware-accelerated kernels. After fusion, the model graph contains modules like `FusedConvBNReLU` instead of separate `Conv2d`, `BatchNorm2d`, and `ReLU` layers.

Fused modules:

- Are numerically equivalent to the original operator sequence (no weight folding at this stage — that's handled by the hardware compiler).
- Declare QDQ insertion points so quantization stubs are placed correctly.
- Export cleanly to ONNX for downstream compilation.

## Convolution fusions

### ConvBNPattern

**Matches:** `Conv2d → [BatchNorm2d]`

**Produces:** `FusedConvBN`

The most basic convolution fusion. The `BatchNorm2d` is optional — a bare `Conv2d` is also matched (useful for ensuring QDQ stub placement even on convolutions without batch normalization).

**QDQ points:** `INPUT`

### ConvBNReLUPattern

**Matches:** `Conv2d → [BatchNorm2d] → Activation`

**Produces:** `FusedConvBNReLU`

Fuses convolution, optional batch normalization, and an activation function. The pattern is named after the most common case (ReLU), but it matches any of these activations: `ReLU`, `ReLU6`, `LeakyReLU`, `ELU`, `GELU`, `SiLU`, `Hardswish`, `Hardsigmoid`.

**QDQ points:** `INPUT`

### StemConvBNReLUMaxPoolPattern

**Matches:** `Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d`

**Produces:** `FusedConvBNReLUMaxPool`

Captures the classification network stem found in ResNet and similar architectures. The convolution is constrained to `in_channels=3, kernel_size=(7,7)` so only the actual stem is matched.

**QDQ points:** `INPUT`

### ConvBNAddReLUPattern

**Matches:** `Conv2d → BatchNorm2d → add(·, residual) → Activation`

**Produces:** `FusedConvBNAddReLU`

Captures the tail of ResNet-style bottleneck blocks where the convolution path merges with a skip connection before the final activation. This is a **branching** pattern — it matches a `Fork` topology with two inputs feeding into an `operator.add` node.
**QDQ points:** `INPUT`, `RESIDUAL_INPUT`

The `RESIDUAL_INPUT` QDQ point ensures the skip connection is also quantized, which is critical for TensorRT to fuse the entire residual block into a single INT8 kernel.

## Linear fusions

### LinearReLUPattern

**Matches:** `Linear → Activation`

**Produces:** `FusedLinearReLU`

**QDQ points:** `INPUT`

### LinearPattern

**Matches:** standalone `Linear`

**Produces:** `FusedLinear`

**QDQ points:** `INPUT`

### LayerNormPattern

**Matches:** `LayerNorm`

**Produces:** `FusedLayerNorm`

**QDQ points:** none (empty `qdq_points`)

LayerNorm is **not quantized** — it operates element-wise and runs efficiently in FP16/FP32. Placing QDQ stubs around LayerNorm would hurt accuracy without improving latency.

## Attention fusions

These patterns match the sub-modules produced by the `DecomposeMultiheadAttentionPattern` conversion.

### MHAInProjectionPattern

**Matches:** `MHAInProjection`

**Produces:** `FusedMHAInProjection`

**QDQ points:** `INPUT`

### ScaledDotProductAttentionPattern

**Matches:** `ScaledDotProductAttention`

**Produces:** `FusedScaledDotProductAttention`

**QDQ points:** `INPUT`, `KEY_INPUT`, `VALUE_INPUT`

The three QDQ points ensure that Q, K, and V tensors are independently quantized, matching TensorRT's expected input format for fused attention kernels.

## Pooling fusions

### AdaptiveAvgPoolPattern

**Matches:** `AdaptiveAvgPool2d`

**Produces:** `FusedAdaptiveAvgPool2d`

**QDQ points:** `INPUT`, `OUTPUT`

:::{note}
In smart-quantization workflows, the `AdaptiveAvgPoolPattern` is intentionally omitted from the pattern list to skip QDQ placement around global average pooling. This is because `GlobalAvgPool` is element-wise and memory-bound — quantizing it adds TensorRT reformatting overhead without meaningful compute savings. See {doc}`custom_patterns` for details.
:::

## Fusion summary by architecture

### ResNet50

| Fused module | Count | Description |
|---|---|---|
| `FusedConvBNReLUMaxPool` | 1 | Stem: Conv(7×7) + BN + ReLU + MaxPool |
| `FusedConvBNAddReLU` | 16 | Bottleneck residual blocks |
| `FusedConvBNReLU` | 16 | Main-path Conv + BN + ReLU |
| `FusedConvBN` | 17 | Conv + BN without activation |
| `FusedAdaptiveAvgPool2d` | 1 | Global average pool |
| **Total** | **51** | |

Conversions applied: `FlattenLinearToConv1x1Pattern` converts the `Flatten → Linear` classifier into `Conv2d(1×1) → Flatten`.

### ConvNeXt (Tiny/Base/Large)

| Fused module | Count | Description |
|---|---|---|
| `FusedConvBN` | 36/72/108 | Depthwise + pointwise convolutions |
| `FusedConvBNReLU` | 27/54/81 | Conv + BN + GELU chains |
| `FusedLayerNorm` | 27/54/81 | LayerNorm layers (unquantized) |
| `FusedLinear` | 27/54/81 | Standalone linear layers |

Counts increase with model depth (Tiny: 9 stages × 3, Base: 9 × 6, Large: 9 × 9).

Conversions applied:

- `RemoveIdentityAdaptiveAvgPoolPattern` removes identity pooling ops.
- `FlattenLinearToConv1x1Pattern` converts the classifier head.

ConvNeXt uses depthwise separable convolutions extensively. The default pattern set quantizes all convolutions equally, but mixed-precision (see {doc}`custom_patterns`) skips depthwise convolutions for better latency.

### Vision Transformer (ViT-B/16)

| Fused module | Count | Description |
|---|---|---|
| `FusedConvBN` | 1 | Patch embedding Conv2d |
| `FusedMHAInProjection` | 12 | Q/K/V projections |
| `FusedScaledDotProductAttention` | 12 | Attention computation |
| `FusedLinear` | 36+ | Out-proj + MLP linear layers |
| `FusedLinearReLU` | 12 | MLP hidden → GELU chains |
| `FusedLayerNorm` | 25 | Pre/post-norm layers (unquantized) |

Conversions applied: `DecomposeMultiheadAttentionPattern` decomposes all 12 attention layers into explicit sub-modules.
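The bottleneck residual tail that `FusedConvBNAddReLU` captures (`Conv2d → BatchNorm2d → add(·, residual) → Activation`, with two inputs feeding the add node) can be sketched in plain PyTorch. `ResidualTail` below is an illustrative stand-in, not a library class:

```python
import torch
import torch.nn as nn

class ResidualTail(nn.Module):
    """Illustrative stand-in for the branching topology ConvBNAddReLUPattern matches."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        # Two tensors merge at the add node: the conv path and the skip connection.
        return self.relu(self.bn(self.conv(x)) + residual)

block = ResidualTail(64).eval()
x = torch.randn(1, 64, 56, 56)
out = block(x, x)  # the skip connection reuses the block input
print(out.shape)   # torch.Size([1, 64, 56, 56])
```

Because the skip connection enters the pattern as a second input, it gets its own `RESIDUAL_INPUT` QDQ point rather than sharing the main path's `INPUT`.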
## Running fusions only

```python
from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_FUSION_PATTERNS

result = transform(model, patterns=TENSORRT_FUSION_PATTERNS)
```

To inspect what was fused:

```python
from collections import Counter

# Module class names are always strings, so filtering on the 'Fused' prefix
# is enough to tally the fused modules in the transformed graph.
fused_counts = Counter(
    type(m).__name__
    for m in result.model.modules()
    if type(m).__name__.startswith('Fused')
)
for name, count in sorted(fused_counts.items()):
    print(f"  {name}: {count}")
```
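Because fusion at this stage only groups the original submodules (no weight folding), a fused module's output matches running the operators one by one. A minimal sketch of that equivalence claim, using a plain-PyTorch `nn.Sequential` as a stand-in for a fused module:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(16)
relu = nn.ReLU()

# Stand-in for a fused Conv+BN+ReLU module: the same submodules grouped
# under one parent, with their weights untouched.
fused = nn.Sequential(conv, bn, relu).eval()

x = torch.randn(1, 3, 32, 32)
separate = relu(bn(conv(x)))  # the original operator sequence
assert torch.allclose(fused(x), separate)
```

Weight folding (e.g. merging BN statistics into the convolution kernel) happens later, inside the hardware compiler, so the fused graph stays bit-exact with the original model.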