embedl_deploy.tensorrt.patterns package#

Module contents:

Public re-exports of TensorRT pattern classes.

Users should import pattern classes from here:

from embedl_deploy.tensorrt.patterns import ConvBNPattern
class embedl_deploy.tensorrt.patterns.AdaptiveAvgPoolPattern[source]#

Bases: Pattern

Match AdaptiveAvgPool2d and wrap in a fused module.

Although there is nothing to fuse, wrapping the pool in a recognized module allows the Q/DQ insertion pass to place quantize / dequantize stubs around it.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT, QDQPoint.OUTPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.ConvBNActPattern[source]#

Bases: Pattern

Match Conv2d → [BatchNorm2d] → Activation and fuse.

Any activation included in ActivationLike is accepted. The BatchNorm2d is optional.
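For illustration, two chains this pattern would target, built from plain torch modules (the fused replacement module itself is internal to the package):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# Conv2d -> BatchNorm2d -> ReLU: matched with the optional BatchNorm2d present.
with_bn = symbolic_trace(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
)
# Conv2d -> SiLU: matched with the optional BatchNorm2d absent.
without_bn = symbolic_trace(nn.Sequential(nn.Conv2d(3, 8, 3), nn.SiLU()))

x = torch.randn(1, 3, 16, 16)
out_a, out_b = with_bn(x), without_bn(x)
```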

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.conv.Conv2d'>, Wildcard(check=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, quantifier='?', nodes=()), torch.nn.modules.activation.ReLU | torch.nn.modules.activation.ReLU6 | torch.nn.modules.activation.GELU | torch.nn.modules.activation.SiLU | torch.nn.modules.activation.Mish | torch.nn.modules.activation.Hardswish | torch.nn.modules.activation.Hardsigmoid | torch.nn.modules.activation.LeakyReLU | torch.nn.modules.activation.PReLU | torch.nn.modules.activation.ELU | torch.nn.modules.activation.Sigmoid | torch.nn.modules.activation.Tanh)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.ConvBNAddActPattern[source]#

Bases: Pattern

Match Conv2d → BatchNorm2d → add(·, residual) → Activation.

Captures the tail of ResNet-style bottleneck blocks where the convolution path is element-wise added to a skip connection before the final activation.
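A minimal residual tail of this shape, written with plain torch modules (the module names are arbitrary):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class BottleneckTail(nn.Module):
    """Conv path added element-wise to a skip connection before the activation."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The + traces to operator.add, the Fork operator in this pattern's tree.
        return self.act(self.bn(self.conv(x)) + x)

gm = symbolic_trace(BottleneckTail())
out = gm(torch.randn(1, 8, 14, 14))
```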

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT, QDQPoint.RESIDUAL_INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = Fork(inputs=((<class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>), ()), operator=<built-in function add>, output=(torch.nn.modules.activation.ReLU | torch.nn.modules.activation.ReLU6 | torch.nn.modules.activation.GELU | torch.nn.modules.activation.SiLU | torch.nn.modules.activation.Mish | torch.nn.modules.activation.Hardswish | torch.nn.modules.activation.Hardsigmoid | torch.nn.modules.activation.LeakyReLU | torch.nn.modules.activation.PReLU | torch.nn.modules.activation.ELU | torch.nn.modules.activation.Sigmoid | torch.nn.modules.activation.Tanh,))#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.ConvBNPattern[source]#

Bases: Pattern

Match Conv2d → [BatchNorm2d] (no activation) and fuse.

The BatchNorm2d is optional.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.conv.Conv2d'>, Wildcard(check=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, quantifier='?', nodes=()))#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.DecomposeMultiheadAttentionPattern[source]#

Bases: Pattern

Decompose nn.MultiheadAttention into explicit sub-modules.

Replaces each MultiheadAttention node with three sub-modules visible in the FX graph:

  1. MHAInProjection

  2. ScaledDotProductAttention

  3. nn.Linear (output projection, reused from the original MHA)

Only self-attention (_qkv_same_embed_dim=True, batch_first=True) without masks is supported. Unsupported configurations are skipped with a warning.
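A configuration that satisfies these constraints, using stock torch modules (the sizes are arbitrary; the decomposed sub-modules themselves are internal):

```python
import torch
import torch.nn as nn

# Self-attention with batch_first=True and a shared embed dim for q/k/v
# (kdim/vdim left at their defaults), called without masks.
mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 32)  # (batch, sequence, embed_dim)
out, _ = mha(x, x, x, need_weights=False)
assert mha._qkv_same_embed_dim  # the flag the decomposition checks
```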

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_supported_mha>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.FlattenLinearToConv1x1Pattern[source]#

Bases: Pattern

Replace Flatten (4D→2D) → Linear with Conv2d(1×1) → Flatten.

Many classification networks end with:

``AdaptiveAvgPool2d → flatten → [Dropout → ReLU →] Linear``

This conversion rewrites the tail into:

``AdaptiveAvgPool2d → [Dropout → ReLU →] Conv2d(1×1) → flatten``

Element-wise ops between flatten and Linear (activations, dropout) are absorbed by a Wildcard and moved in front of the Conv2d in the replacement.

The resulting Conv2d can then be matched by downstream fusion and Q/DQ patterns.

This is a structural conversion — it changes graph topology rather than collapsing a chain into a fused module. It must be applied before fusion patterns.
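Why the rewrite is numerically sound can be checked with plain torch, independent of this package: on a (N, C, 1, 1) feature map, a Linear applied to the flattened tensor equals a 1×1 Conv2d carrying the same weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(16, 4)
conv = nn.Conv2d(16, 4, kernel_size=1)
with torch.no_grad():
    # A 1x1 Conv2d weight is the Linear weight with two trailing singleton dims.
    conv.weight.copy_(linear.weight.view(4, 16, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, 16, 1, 1)
out_linear = linear(torch.flatten(x, 1))  # flatten -> Linear (original tail)
out_conv = torch.flatten(conv(x), 1)      # Conv2d(1x1) -> flatten (rewritten tail)
assert torch.allclose(out_linear, out_conv, atol=1e-6)
```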

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_flatten>, Wildcard(check=torch.nn.modules.dropout.Dropout | torch.nn.modules.dropout.Dropout2d | torch.nn.modules.activation.ReLU | torch.nn.modules.activation.ReLU6 | torch.nn.modules.activation.LeakyReLU | torch.nn.modules.activation.ELU | torch.nn.modules.activation.GELU | torch.nn.modules.activation.SiLU | torch.nn.modules.activation.Hardswish | torch.nn.modules.activation.Hardsigmoid, quantifier='*', nodes=()), <class 'torch.nn.modules.linear.Linear'>)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.LayerNormPattern[source]#

Bases: Pattern

Match LayerNorm and wrap in a fused module.

No Q/DQ stubs are placed around LayerNorm (empty qdq_points).

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset()#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.normalization.LayerNorm'>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.LinearActPattern[source]#

Bases: Pattern

Match Linear → Activation and fuse.

Any activation included in ActivationLike is accepted.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.linear.Linear'>, torch.nn.modules.activation.ReLU | torch.nn.modules.activation.ReLU6 | torch.nn.modules.activation.GELU | torch.nn.modules.activation.SiLU | torch.nn.modules.activation.Mish | torch.nn.modules.activation.Hardswish | torch.nn.modules.activation.Hardsigmoid | torch.nn.modules.activation.LeakyReLU | torch.nn.modules.activation.PReLU | torch.nn.modules.activation.ELU | torch.nn.modules.activation.Sigmoid | torch.nn.modules.activation.Tanh)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.LinearPattern[source]#

Bases: Pattern

Match a standalone Linear and wrap in a fused module.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'torch.nn.modules.linear.Linear'>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.MHAInProjectionPattern[source]#

Bases: Pattern

Match MHAInProjection and wrap in a fused module.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection'>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.RemoveAssertPattern[source]#

Bases: Pattern

Remove assertion subgraphs such as getattr → eq → _assert.

timm models often contain shape checks in the traced FX graph (for example, assert x.dim() == 4). These are not runtime compute ops and can be safely erased when building deployment graphs.

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (Wildcard(check=<function _is_assert_noise>, quantifier='+', nodes=()), <function _is_eq>, <function _is_assert>)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.RemoveDeadAssertPattern[source]#

Bases: Pattern

Remove dead assertion nodes left after assert removal.

RemoveAssertPattern erases the eq/_assert nodes but can leave the eq node's other getattr input dead. This cleanup pass removes those dead nodes.

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_dead_assert_noise>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.RemoveIdentityAdaptiveAvgPoolPattern[source]#

Bases: Pattern

Remove AdaptiveAvgPool2d where output size equals input size.

When output_size equals the (H, W) of the incoming feature map, the pooling operation is a mathematical identity and can be safely erased. This is common in ConvNeXt-style architectures.

Assumes shapes have already been propagated into node.meta['tensor_meta'] (e.g. via a prior ShapeProp pass). Nodes with missing shape metadata are skipped with a warning.
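Shape propagation can be run with torch.fx's stock ShapeProp pass. The model below is illustrative and happens to contain exactly such an identity pool: a 3×3 conv on a 16×16 input yields a 14×14 map, which the (14, 14) pool leaves unchanged.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d((14, 14)))
gm = symbolic_trace(model)

# Populate node.meta['tensor_meta'] by running shape propagation on an example input.
ShapeProp(gm).propagate(torch.randn(1, 3, 16, 16))

shapes = {
    node.name: tuple(node.meta["tensor_meta"].shape)
    for node in gm.graph.nodes
    if "tensor_meta" in node.meta
}
```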

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_identity_adaptive_avg_pool>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.RemoveIdentityPattern[source]#

Bases: Pattern

Remove nn.Identity modules from the graph.

nn.Identity is a no-op module that passes its input through unchanged. These operations can be safely removed from the graph without affecting model behavior. This simplifies the graph for downstream optimization and fusion patterns.

is_conversion: bool = True#

If True, this pattern is a structural conversion that must be applied before fusion matching.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_identity_passthrough>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.ScaledDotProductAttentionPattern[source]#

Bases: Pattern

Match ScaledDotProductAttention and wrap in a fused module.

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT, QDQPoint.KEY_INPUT, QDQPoint.VALUE_INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<class 'embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention'>,)#

The pattern topology to match, if using tree-based matching.

class embedl_deploy.tensorrt.patterns.StemConvBNActMaxPoolPattern[source]#

Bases: Pattern

Match Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d.

Captures the common classification-network stem. The convolution is constrained to in_channels == 3 and kernel_size == (7, 7), so only the actual stem is matched, not arbitrary Conv→Act→Pool chains. The BatchNorm2d is optional.
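A stem meeting these constraints, built from stock torch modules (the channel counts and strides are illustrative, ResNet-style):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# in_channels == 3 and kernel_size == (7, 7), as required by the constraint.
stem = symbolic_trace(nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
))
out = stem(torch.randn(1, 3, 224, 224))  # (1, 64, 56, 56)
```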

match(graph_module: GraphModule) → list[PatternMatch][source]#

Find all occurrences of this pattern in graph_module.

Returns a list of PatternMatch objects, each describing one non-overlapping occurrence — the matched nodes, the original nn.Module instances, etc.

qdq_points: frozenset[QDQPoint] = frozenset({QDQPoint.INPUT})#

Q/DQ insertion points declared by this pattern.

Subclasses override this to declare where QuantStub nodes should be placed. The default (empty) means no Q/DQ stubs.

replace(pattern_match: PatternMatch) → list[Node][source]#

Replace one matched occurrence in-place in a graph_module.

Swaps the matched nodes with the appropriate fused or quantized module. The graph is modified in place; callers must call graph_module.graph.lint() and graph_module.recompile() after all replacements are done.

Parameters:

pattern_match – The pattern match to replace.

Returns:

The replacement nodes inserted into the graph.

tree: Sequence[type[Module] | UnionType | Callable[[Node], bool] | Wildcard] | Fork = (<function _is_stem_conv>, Wildcard(check=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>, quantifier='?', nodes=()), torch.nn.modules.activation.ReLU | torch.nn.modules.activation.ReLU6 | torch.nn.modules.activation.GELU | torch.nn.modules.activation.SiLU | torch.nn.modules.activation.Mish | torch.nn.modules.activation.Hardswish | torch.nn.modules.activation.Hardsigmoid | torch.nn.modules.activation.LeakyReLU | torch.nn.modules.activation.PReLU | torch.nn.modules.activation.ELU | torch.nn.modules.activation.Sigmoid | torch.nn.modules.activation.Tanh, <class 'torch.nn.modules.pooling.MaxPool2d'>)#

The pattern topology to match, if using tree-based matching.