# Graph Conversions
Conversions are structural graph rewrites that normalize the computation graph before fusion patterns are applied. They change the graph topology — replacing one set of operators with a functionally equivalent set that downstream patterns can match.
Conversion patterns have `is_conversion = True` and are applied iteratively by `transform()` until no new matches are found.
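The iterate-until-fixpoint behavior can be sketched with a toy rewriter. The list-based graph and function names below are purely illustrative, not the library's API:

```python
# Illustrative fixpoint loop (not the library's implementation): conversion
# patterns are reapplied until a full sweep produces no new matches.
def apply_conversions(graph, conversions):
    changed = True
    while changed:
        changed = False
        for convert in conversions:
            if convert(graph):
                changed = True  # one rewrite may expose another match
    return graph

# Toy graph as a list of op names; a "conversion" rewrites it in place
# and reports whether it matched.
def flatten_linear_to_conv(graph):
    for i in range(len(graph) - 1):
        if graph[i:i + 2] == ["flatten", "linear"]:
            graph[i:i + 2] = ["conv1x1", "flatten"]
            return True
    return False

g = ["avgpool", "flatten", "linear"]
apply_conversions(g, [flatten_linear_to_conv])
# g is now ["avgpool", "conv1x1", "flatten"]
```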
## Built-in TensorRT conversions
### FlattenLinearToConv1x1Pattern

**Matches:** `Flatten` (4D→2D) → [`Dropout`/activation]* → `Linear`

**Replaces with:** [`Dropout`/activation]* → `Conv2d(1×1)` → `Flatten`
Many classification networks end with:

```
AdaptiveAvgPool2d → Flatten → Linear(1000)
```
The Linear layer cannot be fused with preceding operations in TensorRT.
After this conversion it becomes:

```
AdaptiveAvgPool2d → Conv2d(1×1, out=1000) → Flatten
```
The Conv2d(1×1) can then be matched by downstream fusion patterns
(ConvBNPattern, etc.) and benefits from INT8 quantization.
**Weight conversion:** The `Linear` weight matrix of shape `(out_features, in_features)` is reshaped to `(out_features, in_features, 1, 1)` for the equivalent `Conv2d`.
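The equivalence can be checked by hand. The sketch below uses plain Python lists instead of torch tensors to stay dependency-free; each 1×1 convolution kernel is just one scalar of the original weight matrix:

```python
# Illustration of the Linear -> Conv2d(1x1) weight mapping.
# A Linear layer computes y[o] = sum_i W[o][i] * x[i] with W of shape (out, in).
# A Conv2d with 1x1 kernels applied to an (in, 1, 1) feature map computes the
# same dot products; its weight is W reshaped to (out, in, 1, 1).

out_features, in_features = 3, 4
W = [[(o + 1) * (i + 1) for i in range(in_features)] for o in range(out_features)]
x = [2.0, 1.0, 0.5, 0.25]

# Linear: one dot product per output feature
linear_out = [sum(W[o][i] * x[i] for i in range(in_features))
              for o in range(out_features)]

# Conv2d(1x1) weight: shape (out, in, 1, 1) -- each scalar kernel is W[o][i]
conv_weight = [[[[W[o][i]]] for i in range(in_features)]
               for o in range(out_features)]

# Conv2d(1x1) on an (in, 1, 1) map: accumulate over input channels
conv_out = [sum(conv_weight[o][i][0][0] * x[i] for i in range(in_features))
            for o in range(out_features)]

assert conv_out == linear_out
```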
**Element-wise ops:** Activation or dropout layers between `Flatten` and `Linear` are absorbed by a `Wildcard` and moved before the `Conv2d` in the replacement graph.
**Affected architectures:** ResNet, ConvNeXt, EfficientNet, MobileNet, and most classification backbones with a `Flatten` → `Linear` classifier head.
### RemoveIdentityAdaptiveAvgPoolPattern

**Matches:** `AdaptiveAvgPool2d` where `output_size` equals the input spatial dimensions

**Replaces with:** nothing (the node is erased)
In ConvNeXt-style architectures, `AdaptiveAvgPool2d(output_size=(7, 7))` is applied to a 7×7 feature map — a mathematical identity. Removing it simplifies the graph and prevents it from blocking fusion of surrounding operators.
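The identity condition itself is simple. A hypothetical version of the check (the real pattern reads these shapes from ShapeProp metadata rather than taking them as arguments):

```python
def is_identity_adaptive_pool(output_size, input_shape):
    """Hypothetical check: AdaptiveAvgPool2d is a no-op when the requested
    output spatial size equals the input's (H, W)."""
    _, _, h, w = input_shape  # NCHW layout
    oh, ow = output_size
    return (oh, ow) == (h, w)

# ConvNeXt-style case: pooling a 7x7 map to 7x7 is an identity
assert is_identity_adaptive_pool((7, 7), (1, 768, 7, 7))
# Global average pooling to 1x1 is a real reduction, not an identity
assert not is_identity_adaptive_pool((1, 1), (1, 768, 7, 7))
```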
> **Note:** Requires shape metadata (via `torch.fx.passes.shape_prop.ShapeProp`) to determine that the input and output spatial dimensions match. Nodes without shape metadata are skipped with a warning.
### DecomposeMultiheadAttentionPattern

**Matches:** `nn.MultiheadAttention` (self-attention, `batch_first=True`)

**Replaces with:** three explicit modules:

- `MHAInProjection` — the combined Q/K/V linear projection
- `ScaledDotProductAttention` — the attention computation
- `nn.Linear` — the output projection
PyTorch's `nn.MultiheadAttention` is a monolithic module. TensorRT cannot fuse or quantize its internal operations. By decomposing it into visible sub-modules in the FX graph, each component can be independently fused and quantized.
**Restrictions:** Only self-attention (`_qkv_same_embed_dim=True`, `batch_first=True`) without masks or `add_zero_attn` is supported. Unsupported configurations are skipped with a warning.
**Affected architectures:** Vision Transformer (ViT), DeiT, and any model using `nn.MultiheadAttention`.
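For reference, `nn.MultiheadAttention` with `_qkv_same_embed_dim=True` stores its Q/K/V weights stacked in a single `in_proj_weight` of shape `(3*embed_dim, embed_dim)`; a decomposition along these lines slices it into the three projections. The plain-Python sketch below illustrates the split and is not the library's code:

```python
embed_dim = 4

# Stand-in for nn.MultiheadAttention.in_proj_weight: shape (3*embed_dim, embed_dim),
# with Q, K and V weights stacked along the first dimension.
in_proj_weight = [[r * 10 + c for c in range(embed_dim)]
                  for r in range(3 * embed_dim)]

# The combined projection splits into the three per-head weight blocks
w_q = in_proj_weight[:embed_dim]
w_k = in_proj_weight[embed_dim:2 * embed_dim]
w_v = in_proj_weight[2 * embed_dim:]

assert len(w_q) == len(w_k) == len(w_v) == embed_dim
```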
## When conversions matter
### ResNet50

The `FlattenLinearToConv1x1Pattern` converts the final classifier:

```
avgpool (AdaptiveAvgPool2d) → flatten → fc (Linear)
```

becomes:

```
avgpool (AdaptiveAvgPool2d) → fc (Conv2d 1×1) → flatten
```
The `Conv2d` is then picked up by `ConvBNPattern` during the fusion stage.
### ConvNeXt

ConvNeXt uses `AdaptiveAvgPool2d` at several points where the spatial dimensions do not change. The `RemoveIdentityAdaptiveAvgPoolPattern` cleans these up, and `FlattenLinearToConv1x1Pattern` converts the head.
For ConvNeXt models, shape propagation is required before running conversions:
```python
import torch
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy import transform
# TENSORRT_PATTERNS is assumed to live alongside TENSORRT_CONVERSION_PATTERNS
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

graph_module = torch.fx.symbolic_trace(model)
ShapeProp(graph_module).propagate(torch.randn(1, 3, 224, 224))
result = transform(graph_module, patterns=TENSORRT_PATTERNS)
```
### Vision Transformer (ViT)
`DecomposeMultiheadAttentionPattern` expands each `nn.MultiheadAttention` into three sub-modules. After decomposition, the fusion pass applies `MHAInProjectionPattern` and `ScaledDotProductAttentionPattern` to wrap them in quantization-aware fused modules. The output projection `nn.Linear` is matched by `LinearPattern`.
## Running conversions only
```python
from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_CONVERSION_PATTERNS

# Apply only structural conversions
result = transform(model, patterns=TENSORRT_CONVERSION_PATTERNS)
converted_model = result.model
```
This is useful for debugging: you can inspect the graph after conversions to verify the structural rewrites before running fusions.