# Custom Patterns

The built-in `TENSORRT_PATTERNS` list is a good starting point, but real-world deployment often benefits from a **custom pattern list** that skips quantization on operators where INT8 hurts more than it helps. This is the "mixed-precision" strategy: selectively place QDQ stubs based on the compute characteristics of each operator and its behavior on specific hardware.

## Why customize?

Two operators commonly benefit from being left in FP16:

### Depthwise convolutions

Depthwise convolutions (`groups == in_channels`) are **memory-bound**, not compute-bound. Quantizing them to INT8 forces TensorRT to insert data reformatting layers (INT8 → FP16 → INT8) around the convolution, and this reformatting overhead typically exceeds the compute savings from INT8. This effect is especially pronounced on ConvNeXt, which uses depthwise 7×7 convolutions throughout.

### Global average pooling

`AdaptiveAvgPool2d` is an element-wise operation with no matrix multiplication. INT8 quantization adds QDQ stubs on both sides but provides negligible compute benefit while risking accuracy loss.
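The memory-bound claim can be made concrete with a back-of-the-envelope arithmetic-intensity estimate (FLOPs per byte moved). The sketch below compares a depthwise 7×7 convolution against a dense 3×3 convolution at a ConvNeXt-like tensor shape; the shapes, the FP16 byte counts, and the no-caching assumption are illustrative, not measurements:

```python
def conv_arithmetic_intensity(c_in, c_out, k, h, w, groups, dtype_bytes=2):
    """Rough FLOPs-per-byte estimate for a KxK conv (FP16 by default)."""
    # Each output element accumulates over (c_in / groups) * k * k inputs;
    # count a multiply-add as 2 FLOPs.
    flops = 2 * c_out * (c_in // groups) * k * k * h * w
    # Bytes moved: read input and weights, write output (ignores caching).
    weights = c_out * (c_in // groups) * k * k
    bytes_moved = (c_in * h * w + c_out * h * w + weights) * dtype_bytes
    return flops / bytes_moved

# Depthwise 7x7 at a ConvNeXt-like stage (96 channels, 56x56 feature map)
dw = conv_arithmetic_intensity(96, 96, 7, 56, 56, groups=96)
# Dense 3x3 with the same channel count, for comparison
dense = conv_arithmetic_intensity(96, 96, 3, 56, 56, groups=1)

print(f"depthwise 7x7: {dw:.0f} FLOPs/byte, dense 3x3: {dense:.0f} FLOPs/byte")
```

At this shape the depthwise kernel does roughly an order of magnitude fewer FLOPs per byte than the dense one, so its runtime is dominated by data movement; cheaper INT8 arithmetic cannot pay for the extra reformatting traffic.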
## Writing a custom pattern

A custom pattern is a subclass of `Pattern` with:

- `tree` — the node topology to match
- `qdq_points` — where QDQ stubs should be placed (empty = no quantization)
- `match()` — how to find occurrences
- `replace()` — how to rewrite the graph

Here is a complete example that matches depthwise convolutions and **skips quantization** by declaring empty `qdq_points`:

```python
import torch.nn as nn
from torch import fx

from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.pattern import (
    Pattern,
    PatternMatch,
    Wildcard,
    get_module,
)
from embedl_deploy._internal.core.replace import replace_tree
from embedl_deploy._internal.tensorrt.modules.conv import FusedConvBN


def _is_depthwise_conv(node: fx.Node) -> bool:
    """Return True for a depthwise Conv2d (groups == in_channels > 1)."""
    module = get_module(node)
    return (
        isinstance(module, nn.Conv2d)
        and module.groups > 1
        and module.groups == module.in_channels
    )


class DepthwiseConvBNPattern(Pattern):
    """Match depthwise Conv2d → [BatchNorm2d] without quantization."""

    tree = (_is_depthwise_conv, Wildcard((nn.BatchNorm2d,)))
    qdq_points = frozenset()  # no QDQ stubs

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        tree_match = pattern_match.tree_match
        conv = get_module(tree_match.get_node(0))
        wild_nodes = tree_match.get_node(1).nodes
        bn = get_module(wild_nodes[0]) if wild_nodes else None
        return replace_tree(
            pattern_match,
            [FusedConvBN(conv, bn, bn_foldable=bn is not None)],
        )
```

Key points:

- The `_is_depthwise_conv` predicate uses a callable check instead of a module type, allowing fine-grained matching.
- `qdq_points = frozenset()` means `insert_qdq` will not place any quantization stubs around matched depthwise convolutions.
- The replacement uses the existing `FusedConvBN` module, which is mathematically equivalent to the original `Conv2d → BatchNorm2d` sequence — BatchNorm can be folded into Conv2d without changing the output.

## Building a custom pattern list

A custom pattern list is assembled by combining built-in patterns with your custom ones. **Order matters** — longer/more-specific patterns first:

```python
# NOTE: These imports use internal APIs. Public exports may be added
# in a future release.
from embedl_deploy._internal.tensorrt.patterns.conversions import (
    DecomposeMultiheadAttentionPattern,
    FlattenLinearToConv1x1Pattern,
    RemoveIdentityAdaptiveAvgPoolPattern,
)
from embedl_deploy._internal.tensorrt.patterns.fusions.conv import (
    ConvBNAddReLUPattern,
    ConvBNPattern,
    ConvBNReLUPattern,
    StemConvBNReLUMaxPoolPattern,
)
from embedl_deploy._internal.tensorrt.patterns.fusions.linear import (
    LayerNormPattern,
    LinearPattern,
    LinearReLUPattern,
)
from embedl_deploy._internal.tensorrt.patterns.fusions.attention import (
    MHAInProjectionPattern,
    ScaledDotProductAttentionPattern,
)

SMART_PATTERNS = [
    # -- Conversions (applied first, iteratively) --
    DecomposeMultiheadAttentionPattern(),
    FlattenLinearToConv1x1Pattern(),
    RemoveIdentityAdaptiveAvgPoolPattern(),
    # -- Fusions (longest first) --
    StemConvBNReLUMaxPoolPattern(),
    ConvBNAddReLUPattern(),
    ConvBNReLUPattern(),
    LinearReLUPattern(),
    DepthwiseConvBNPattern(),  # custom: no QDQ on depthwise
    ConvBNPattern(),
    LinearPattern(),
    LayerNormPattern(),
    MHAInProjectionPattern(),
    ScaledDotProductAttentionPattern(),
    # NOTE: AdaptiveAvgPoolPattern() intentionally omitted
    # → no QDQ stubs around GlobalAveragePooling
]
```

Two deliberate omissions:

1. **No `AdaptiveAvgPoolPattern`** — global average pooling stays in FP16.
2. **`DepthwiseConvBNPattern` with empty `qdq_points`** — depthwise convolutions are fused but not quantized.
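The "lossless fusion" behind `FusedConvBN` rests on a simple identity: BatchNorm applied to a convolution output is an affine map, so it can be absorbed into the convolution's weight and bias. A minimal scalar sketch of the folding (independent of any library; all names illustrative):

```python
import math

def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm parameters into a (scalar) conv weight and bias."""
    # BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta, with y = w*x + b
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta

# Conv followed by BN, applied separately
w, b = 0.5, 0.1
gamma, beta, mean, var = 1.2, -0.3, 0.05, 0.8
x = 2.0
y_conv = w * x + b
y_bn = gamma * (y_conv - mean) / math.sqrt(var + 1e-5) + beta

# The folded conv computes the same value in a single step
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
y_folded = w_f * x + b_f

assert abs(y_bn - y_folded) < 1e-12
```

The same algebra applies per output channel for real `Conv2d → BatchNorm2d` pairs, which is why the fusion changes nothing numerically (up to floating-point rounding).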
## Using the custom pattern list

```python
import torch
import torch.nn as nn
from torch.fx.passes.shape_prop import ShapeProp

# NOTE: These imports use internal APIs. Public exports may be added
# in a future release.
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy import transform
from embedl_deploy.quantize import (
    QuantConfig,
    TensorQuantConfig,
    insert_qdq,
    calibrate,
)

# Define your model (e.g., from torchvision)
# my_model = torchvision.models.convnext_tiny(weights="DEFAULT")

# Trace and propagate shapes (needed for conversion patterns)
model = my_model.cpu().eval()
gm = symbolic_trace(model)
ShapeProp(gm).propagate(torch.randn(1, 3, 224, 224))

# Transform with custom patterns
result = transform(gm, patterns=SMART_PATTERNS)
fused_model = result.model
matches = result.matches

# Verify lossless fusion (compare both models on the same input)
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    y_orig = model(x)
    y_fused = fused_model(x)
max_diff = (y_orig - y_fused).abs().max().item()
assert max_diff < 1e-4

# Insert QDQ and calibrate
quantized = insert_qdq(
    fused_model,
    matches,
    config=QuantConfig(
        activation=TensorQuantConfig(n_bits=8, symmetric=True),
        weight=TensorQuantConfig(n_bits=8, symmetric=True, per_channel=True),
        skip_weight_quant_for=(nn.LayerNorm,),
    ),
)

def forward_loop(model):
    # calibration_batches: your own iterable of preprocessed input tensors
    for batch in calibration_batches[:32]:
        model(batch)

calibrate(quantized, forward_loop)
```

## Impact on benchmarks

The mixed-precision strategy makes a significant difference on models with depthwise convolutions. On ConvNeXt Large (NVIDIA RTX 4090, TensorRT 10.9):

| Variant | Latency | Speedup |
|---|---|---|
| Baseline FP16 | 2.19 ms | 1.00x |
| Blanket INT8 (ModelOpt) | 2.15 ms | 1.02x |
| **Smart INT8 (Embedl Deploy)** | **1.70 ms** | **1.29x** |

Blanket quantization barely improves over FP16 because TensorRT reformatting overhead around depthwise convolutions offsets the INT8 compute gains. Smart quantization avoids this by leaving depthwise convolutions in FP16.
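`calibrate` drives the forward loop so the inserted quantization stubs can observe activation ranges. Its internals are not shown here, but a minimal max-calibration sketch (all names illustrative, not the library's API) conveys the idea behind symmetric INT8 scaling:

```python
class MaxCalibrator:
    """Tracks the max absolute value seen; derives a symmetric INT8 scale."""

    def __init__(self):
        self.amax = 0.0

    def observe(self, values):
        # Update the running max over every calibration batch
        self.amax = max(self.amax, max(abs(v) for v in values))

    def scale(self):
        # Symmetric INT8 maps [-amax, amax] onto integer levels [-127, 127]
        return self.amax / 127.0

    def quantize(self, v):
        # Round to the nearest INT8 level, clamp, then dequantize,
        # mimicking what a QDQ stub produces at inference time
        s = self.scale()
        q = max(-127, min(127, round(v / s)))
        return q * s

cal = MaxCalibrator()
for batch in ([0.2, -1.5, 0.9], [3.1, -0.4]):  # stand-in calibration data
    cal.observe(batch)

print(cal.scale())        # amax = 3.1 → scale ≈ 0.0244
print(cal.quantize(1.0))  # snaps to level 41 → ≈ 1.0008
```

This is why the calibration data matters: a single outlier activation stretches `amax` and coarsens the quantization grid for every other value in that tensor.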
See {doc}`benchmarks` for full results across architectures.