Graph Conversions#

Conversions are structural graph rewrites that normalize the computation graph before fusion patterns are applied. They change the graph topology — replacing one set of operators with a functionally equivalent set that downstream patterns can match.

Conversion patterns have is_conversion = True and are applied iteratively by transform() until no new matches are found.
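The internals of transform() are not shown here, but the fixed-point idea can be sketched with a toy rewrite loop (the rewrite functions and string "graphs" below are hypothetical, for illustration only):

```python
# Toy sketch of iterate-until-stable rewriting, assuming each rewrite is a
# function that returns the (possibly unchanged) graph. Not the real transform().
def apply_until_stable(graph, rewrites):
    changed = True
    while changed:
        changed = False
        for rewrite in rewrites:
            new_graph = rewrite(graph)
            if new_graph != graph:
                graph, changed = new_graph, True
    return graph

# Stand-in "graph": a token sequence normalized the way conversions normalize a graph.
collapse = lambda g: g.replace("flatten linear", "conv1x1 flatten")
print(apply_until_stable("pool flatten linear", [collapse]))
# -> pool conv1x1 flatten
```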

Built-in TensorRT conversions#

FlattenLinearToConv1x1Pattern#

Matches: Flatten (4D→2D) → [Dropout/Activation]* → Linear

Replaces with: [Dropout/Activation]* → Conv2d(1×1) → Flatten

Many classification networks end with:

AdaptiveAvgPool2d → Flatten → Linear(1000)

The Linear layer cannot be fused with preceding operations in TensorRT. After this conversion it becomes:

AdaptiveAvgPool2d → Conv2d(1×1, out=1000) → Flatten

The Conv2d(1×1) can then be matched by downstream fusion patterns (ConvBNPattern, etc.) and benefits from INT8 quantization.

Weight conversion: The Linear weight matrix of shape (out_features, in_features) is reshaped to (out_features, in_features, 1, 1) for the equivalent Conv2d.
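The equivalence behind this reshape can be checked directly in PyTorch (a standalone sketch with small made-up dimensions, not the pattern's actual implementation):

```python
import torch

# A Linear layer and the Conv2d(1x1) built from its reshaped weights.
lin = torch.nn.Linear(8, 4)
conv = torch.nn.Conv2d(8, 4, kernel_size=1)
# (out_features, in_features) -> (out_features, in_features, 1, 1)
conv.weight.data = lin.weight.data.reshape(4, 8, 1, 1)
conv.bias.data = lin.bias.data

x = torch.randn(2, 8, 1, 1)       # a pooled 1x1 feature map
y_linear = lin(x.flatten(1))      # original:  Flatten -> Linear
y_conv = conv(x).flatten(1)       # converted: Conv2d(1x1) -> Flatten
assert torch.allclose(y_linear, y_conv, atol=1e-6)
```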

Element-wise ops: Activations or dropout layers between Flatten and Linear are absorbed by a Wildcard and moved before the Conv2d in the replacement graph.

Affected architectures: ResNet, ConvNeXt, EfficientNet, MobileNet, and most classification backbones with a Flatten → Linear classifier head.

RemoveIdentityAdaptiveAvgPoolPattern#

Matches: AdaptiveAvgPool2d where output_size == input spatial dims

Replaces with: nothing (erases the node)

In ConvNeXt-style architectures, AdaptiveAvgPool2d(output_size=(7, 7)) is applied to a 7×7 feature map — a mathematical identity. Removing it simplifies the graph and prevents it from blocking fusion of surrounding operators.
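The identity case this pattern removes is easy to verify: when output_size equals the input spatial size, each output cell averages exactly one input cell.

```python
import torch

# Adaptive pooling a 7x7 feature map to 7x7 is a no-op.
pool = torch.nn.AdaptiveAvgPool2d(output_size=(7, 7))
x = torch.randn(1, 4, 7, 7)
assert torch.equal(pool(x), x)  # bit-exact: each output averages a single element
```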

Note

Requires shape metadata (via torch.fx.passes.shape_prop.ShapeProp) to determine that input and output spatial dimensions match. Nodes without shape metadata are skipped with a warning.

DecomposeMultiheadAttentionPattern#

Matches: nn.MultiheadAttention (self-attention, batch_first=True)

Replaces with: three explicit modules:

  1. MHAInProjection — the combined Q/K/V linear projection

  2. ScaledDotProductAttention — the attention computation

  3. nn.Linear — the output projection

PyTorch’s nn.MultiheadAttention is a monolithic module. TensorRT cannot fuse or quantize its internal operations. By decomposing it into visible sub-modules in the FX graph, each component can be independently fused and quantized.

Restrictions: Only self-attention (_qkv_same_embed_dim=True, batch_first=True) without masks or add_zero_attn is supported. Unsupported configurations are skipped with a warning.
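The three stages can be reproduced numerically against a stock nn.MultiheadAttention (a minimal sketch using plain PyTorch ops, not the library's MHAInProjection or ScaledDotProductAttention modules):

```python
import torch
import torch.nn.functional as F

embed_dim, num_heads, B, S = 8, 2, 2, 5
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()
x = torch.randn(B, S, embed_dim)

with torch.no_grad():
    # 1. In-projection: one fused linear producing Q, K, V.
    qkv = F.linear(x, mha.in_proj_weight, mha.in_proj_bias)
    q, k, v = qkv.chunk(3, dim=-1)
    head_dim = embed_dim // num_heads

    def split_heads(t):  # (B, S, E) -> (B, heads, S, head_dim)
        return t.view(B, S, num_heads, head_dim).transpose(1, 2)

    # 2. Scaled dot-product attention per head.
    attn = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
    attn = attn.transpose(1, 2).reshape(B, S, embed_dim)
    # 3. Output projection.
    out = mha.out_proj(attn)

    ref, _ = mha(x, x, x, need_weights=False)

assert torch.allclose(out, ref, atol=1e-5)
```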

Affected architectures: Vision Transformer (ViT), DeiT, and any model using nn.MultiheadAttention.

When conversions matter#

ResNet50#

The FlattenLinearToConv1x1Pattern converts the final classifier:

avgpool (AdaptiveAvgPool2d) → flatten → fc (Linear)

becomes:

avgpool (AdaptiveAvgPool2d) → fc (Conv2d 1×1) → flatten

The Conv2d is then picked up by ConvBNPattern during the fusion stage.

ConvNeXt#

ConvNeXt uses AdaptiveAvgPool2d at several points where the spatial dimensions do not change. The RemoveIdentityAdaptiveAvgPoolPattern cleans these up, and FlattenLinearToConv1x1Pattern converts the head.

For ConvNeXt models, shape propagation is required before running conversions:

import torch
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_PATTERNS

graph_module = torch.fx.symbolic_trace(model)
ShapeProp(graph_module).propagate(torch.randn(1, 3, 224, 224))

result = transform(graph_module, patterns=TENSORRT_PATTERNS)

Vision Transformer (ViT)#

DecomposeMultiheadAttentionPattern expands each nn.MultiheadAttention into three sub-modules. After decomposition, the fusion pass applies MHAInProjectionPattern and ScaledDotProductAttentionPattern to wrap them in quantization-aware fused modules. The output projection nn.Linear is matched by LinearPattern.

Running conversions only#

from embedl_deploy import transform
from embedl_deploy.tensorrt import TENSORRT_CONVERSION_PATTERNS

# Apply only structural conversions
result = transform(model, patterns=TENSORRT_CONVERSION_PATTERNS)
converted_model = result.model

This is useful for debugging: you can inspect the graph after conversions to verify the structural rewrites before running fusions.
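Since the transformed model is a torch.fx GraphModule, printing its graph shows the rewritten topology. A self-contained sketch on a toy classifier head (the same print applies to converted_model above):

```python
import torch

# Toy head with the Flatten -> Linear structure the conversion targets.
class Head(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(torch.flatten(x, 1))

gm = torch.fx.symbolic_trace(Head())
print(gm.graph)  # lists the flatten and fc nodes before any conversion
```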