Source code for embedl_deploy._internal.tensorrt.patterns.conversions

# Copyright (C) 2026 Embedl AB

"""Conversion ``Pattern`` subclasses for TensorRT.

Conversions are *structural* rewrites that change the graph topology so that
downstream fusion patterns can match.  They run **before** fusion patterns.

Example: ``Flatten → Linear`` is rewritten to ``Conv2d(1×1) → Flatten``.
"""

import logging
import operator
from typing import TypeAlias

import torch
from torch import fx, nn

from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.pattern import (
    Pattern,
    PatternMatch,
    Tree,
    Wildcard,
    get_module,
    resolve_module,
)
from embedl_deploy._internal.core.replace import (
    ReplacementFn,
    get_auto_name,
    get_replaced_nodes,
    replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.attention import (
    MHAInProjection,
    ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
    SwinAttention,
    SwinSpatialState,
    SwinWindowPartition,
    SwinWindowReverse,
)

try:
    from torchvision.models.swin_transformer import (
        shifted_window_attention,
    )
except ImportError:  # pragma: no cover
    shifted_window_attention = None  # type: ignore[assignment]

_LOG = logging.getLogger(__name__)


def _is_assert_noise(node: fx.Node) -> bool:
    """Return ``True`` for helper nodes typically used in assertions.

    These nodes are graph noise in exported inference graphs and usually feed
    directly into ``eq`` / ``_assert`` checks.
    """
    if node.op == "call_function" and node.target is getattr:
        return True
    if node.op == "call_method" and node.target in {"dim", "size"}:
        return True
    return False


def _is_eq(node: fx.Node) -> bool:
    """Return ``True`` when `node` is an equality comparison."""
    return node.op == "call_function" and node.target is operator.eq


def _is_assert(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a torch assertion op."""
    target = getattr(torch, "_assert", None)
    return node.op == "call_function" and node.target is target


_OPTIONAL_ASSERT_NOISE = Wildcard(_is_assert_noise, quantifier="+")


def _is_dead_assert_noise(node: fx.Node) -> bool:
    """Return ``True`` for assertion-noise nodes with no users."""
    return _is_assert_noise(node) and len(node.users) == 0



class RemoveAssertPattern(Pattern):
    """Remove assertion subgraphs such as ``getattr → eq → _assert``.

    timm models often contain shape checks in the traced FX graph (for
    example ``assert x.dim() == 4``).  They are not runtime compute ops and
    can be safely erased for deployment graphs.
    """

    is_conversion = True
    tree: Tree = (_OPTIONAL_ASSERT_NOISE, _is_eq, _is_assert)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])


class RemoveDeadAssertPattern(Pattern):
    """Remove dead assertion nodes left after assert removal.

    ``RemoveAssertPattern`` can erase ``eq``/``_assert`` while leaving the
    other ``getattr`` input dead.  This cleanup pass removes those dead
    nodes.
    """

    is_conversion = True
    tree: Tree = (_is_dead_assert_noise,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])
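

# --- Illustrative example (not part of the original module) -----------------
# A minimal sketch of how the two assert-removal patterns could be driven by
# hand on a toy traced graph.  The toy module, the manual match/replace loop
# and the no-argument ``Pattern()`` construction are assumptions made for
# illustration; the real pipeline that orders and applies conversion patterns
# lives elsewhere.
def _example_remove_asserts() -> None:
    class _WithCheck(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Traces to ``dim → eq → _assert`` nodes (torch._assert is
            # symbolically traceable, unlike a bare ``assert`` statement).
            torch._assert(x.dim() == 4, "expected an NCHW tensor")
            return x * 2

    gm = fx.symbolic_trace(_WithCheck())
    for pattern in (RemoveAssertPattern(), RemoveDeadAssertPattern()):
        for pattern_match in pattern.match(gm):
            pattern.replace(pattern_match)
    gm.graph.lint()
    gm.recompile()
    # The assertion subgraph should now be gone from gm.graph.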


def _has_input_shape(node: fx.Node) -> bool:
    """Return ``True`` when the first input to `node` has shape metadata."""
    try:
        if not isinstance(arg := node.args[0], fx.Node):
            return False
        if not hasattr(arg.meta["tensor_meta"], "shape"):
            return False
    except (AttributeError, IndexError, KeyError):
        return False
    return True


def _is_flatten(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a flatten call with shape metadata."""
    if node.op == "call_function":
        is_flat = node.target is torch.flatten
    elif node.op == "call_method":
        is_flat = node.target == "flatten"
    else:
        is_flat = isinstance(get_module(node), nn.Flatten)
    if not is_flat:
        return False
    if not _has_input_shape(node):
        msg = f"Skipping {node.target}: missing shape metadata on input node"
        _LOG.warning("%s", msg)
        return False
    return True


ElementWiseLike: TypeAlias = (
    nn.Dropout
    | nn.Dropout2d
    | nn.ReLU
    | nn.ReLU6
    | nn.LeakyReLU
    | nn.ELU
    | nn.GELU
    | nn.SiLU
    | nn.Hardswish
    | nn.Hardsigmoid
)

#: Optional element-wise ops that can appear between a ``flatten`` and
#: a ``Linear`` (activations, dropout).
_OPTIONAL_EW = Wildcard(ElementWiseLike)


def _get_reshape_insert(
    flatten_node: fx.Node,
    linear: nn.Linear,
) -> ReplacementFn:
    """Return a replacement function for reshaping conv inputs."""

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        assert isinstance(flatten_node.args[0], fx.Node)
        old_shape = flatten_node.args[0].meta["tensor_meta"].shape
        new_shape = torch.Size([-1, linear.in_features, 1, 1])
        if old_shape[1:] == new_shape[1:]:
            return []
        replaced = get_replaced_nodes(graph_module)
        resolved = replaced.get(flatten_node, flatten_node)
        with graph_module.graph.inserting_before(resolved):
            node = graph_module.graph.call_method(
                "reshape", (prev_args[0], *new_shape)
            )
        return [node]

    return _insert


def _linear_to_conv1x1(linear: nn.Linear) -> nn.Conv2d:
    """Convert a ``Linear`` layer to an equivalent ``Conv2d(1×1)``."""
    conv = nn.Conv2d(
        linear.in_features,
        linear.out_features,
        kernel_size=1,
        bias=linear.bias is not None,
    )
    conv.weight.data = linear.weight.data.clone().reshape(
        linear.out_features,
        linear.in_features,
        1,
        1,
    )
    if linear.bias is not None:
        assert conv.bias is not None
        conv.bias.data = linear.bias.data.clone()
    return conv
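

# --- Illustrative example (not part of the original module) -----------------
# Numerical sanity check of the idea behind ``_linear_to_conv1x1``: on a
# globally pooled feature map the 1×1 convolution reproduces the ``Linear``
# output.  The sizes below are arbitrary illustration values.
def _example_linear_conv1x1_equivalence() -> None:
    torch.manual_seed(0)
    linear = nn.Linear(512, 10)
    conv = _linear_to_conv1x1(linear)
    pooled = torch.randn(8, 512, 1, 1)  # e.g. an AdaptiveAvgPool2d((1, 1)) output
    out_linear = linear(pooled.flatten(1))  # (8, 10)
    out_conv = conv(pooled).flatten(1)      # (8, 10)
    torch.testing.assert_close(out_conv, out_linear)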


class FlattenLinearToConv1x1Pattern(Pattern):
    """Replace ``Flatten (4D→2D) → Linear`` with ``Conv2d(1×1) → Flatten``.

    Many classification networks end with::

        ``AdaptiveAvgPool2d → flatten → [Dropout → ReLU →] Linear``

    This conversion rewrites the tail into::

        ``AdaptiveAvgPool2d → [Dropout → ReLU →] Conv2d(1×1) → flatten``

    Element-wise ops between flatten and Linear (activations, dropout) are
    absorbed by a :class:`~embedl_deploy._internal.core.pattern.Wildcard`
    and moved in front of the ``Conv2d`` in the replacement.  The resulting
    ``Conv2d`` can then be matched by downstream fusion and Q/DQ patterns.

    This is a *structural* conversion — it changes graph topology rather
    than collapsing a chain into a fused module.  It must be applied
    **before** fusion patterns.
    """

    is_conversion = True
    tree: Tree = (_is_flatten, _OPTIONAL_EW, nn.Linear)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        flatten_node = tree_match.get_node(0)
        wild_nodes = tree_match.get_node(1, is_wildcard=True).nodes
        linear = resolve_module(tree_match.get_node(2), nn.Linear)
        reshape_insert = _get_reshape_insert(flatten_node, linear)
        conv = _linear_to_conv1x1(linear)
        return replace_tree(
            pattern_match,
            [*wild_nodes, reshape_insert, conv, flatten_node],
        )
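

# --- Illustrative example (not part of the original module) -----------------
# Sketch of applying ``FlattenLinearToConv1x1Pattern`` to a typical classifier
# tail.  ``_is_flatten`` requires ``tensor_meta`` on the flatten input, so
# shapes are propagated first with ``ShapeProp``.  The toy model, the manual
# driver loop and the no-argument ``Pattern()`` construction are assumptions
# made for illustration.
def _example_flatten_linear_to_conv1x1() -> None:
    from torch.fx.passes.shape_prop import ShapeProp

    tail = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Dropout(0.0),
        nn.Linear(256, 10),
    )
    gm = fx.symbolic_trace(tail)
    example = torch.randn(1, 256, 7, 7)
    ShapeProp(gm).propagate(example)  # fills node.meta["tensor_meta"]

    pattern = FlattenLinearToConv1x1Pattern()
    for pattern_match in pattern.match(gm):
        pattern.replace(pattern_match)
    gm.recompile()
    # The rewritten tail is expected to produce the same logits as before.
    torch.testing.assert_close(gm(example), tail(example))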


def _is_identity_adaptive_avg_pool(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is an identity ``AdaptiveAvgPool2d``.

    The pool is identity when its ``output_size`` matches the spatial
    dimensions of the input tensor.  Requires shape metadata
    (``tensor_meta``) on the input node.  ``None`` in ``output_size`` means
    "keep the input dimension", which is always identity for that axis.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.AdaptiveAvgPool2d):
        return False
    if not _has_input_shape(node):
        msg = f"Skipping {node.target}: missing shape metadata on input node"
        _LOG.warning("%s", msg)
        return False
    assert isinstance(node.args[0], fx.Node)
    meta = node.args[0].meta["tensor_meta"]
    if len(meta.shape) != 4:
        return False
    in_h: int = meta.shape[2]
    in_w: int = meta.shape[3]
    if isinstance(mod.output_size, tuple):
        out_h, out_w = mod.output_size
    else:
        out_h = out_w = mod.output_size
    out_h = out_h or in_h
    out_w = out_w or in_w
    return (in_h == out_h) and (in_w == out_w)
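

# --- Illustrative example (not part of the original module) -----------------
# Why an ``AdaptiveAvgPool2d`` can be a mathematical identity: when
# ``output_size`` equals the incoming spatial size, every output element
# averages exactly one input element.  Shapes below are illustrative only.
def _example_identity_adaptive_avg_pool() -> None:
    pool = nn.AdaptiveAvgPool2d((7, 7))
    x = torch.randn(1, 96, 7, 7)
    torch.testing.assert_close(pool(x), x)  # identity: safe to erase
    y = pool(torch.randn(1, 96, 14, 14))
    assert y.shape[-2:] == (7, 7)  # genuine pooling when sizes differ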


class RemoveIdentityAdaptiveAvgPoolPattern(Pattern):
    """Remove ``AdaptiveAvgPool2d`` where output size equals input size.

    When ``output_size == (H, W)`` of the incoming feature map the pooling
    operation is a mathematical identity and can be safely erased.  This is
    common in ConvNeXt-style architectures.

    Assumes shapes have already been propagated into
    ``node.meta['tensor_meta']`` (e.g. via a prior ``ShapeProp`` pass).
    Nodes with missing shape metadata are skipped with a warning.
    """

    is_conversion = True
    tree: Tree = (_is_identity_adaptive_avg_pool,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])


def _is_identity_passthrough(node: fx.Node) -> bool:
    """Return ``True`` when `node` is an identity operation that passes
    through its input.
    """
    mod = get_module(node)
    return isinstance(mod, nn.Identity)


class RemoveIdentityPattern(Pattern):
    """Remove ``nn.Identity`` modules from the graph.

    ``nn.Identity`` is a no-op module that passes its input through
    unchanged.  These operations can be safely removed from the graph
    without affecting model behavior.  This simplifies the graph for
    downstream optimization and fusion patterns.
    """

    is_conversion = True
    tree: Tree = (_is_identity_passthrough,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        return replace_tree(pattern_match, [])


def _is_supported_mha(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is a supported ``MultiheadAttention``.

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks or ``add_zero_attn`` is supported.  Unsupported
    configurations are skipped with a warning.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.MultiheadAttention):
        return False
    warnings: list[str] = []
    # pylint: disable-next=protected-access
    if not mod._qkv_same_embed_dim:
        warnings.append(
            f"Skipping MHA at {node.target}: cross-attention "
            f"(_qkv_same_embed_dim=False) not supported"
        )
    if not mod.batch_first:
        warnings.append(
            f"Skipping MHA at {node.target}: batch_first=False not supported"
        )
    if mod.add_zero_attn:
        warnings.append(
            f"Skipping MHA at {node.target}: add_zero_attn=True not supported"
        )
    if len(node.args) > 3:
        warnings.append(
            f"Skipping MHA at {node.target}: extra positional args "
            f"(possible mask arguments)"
        )
    for mask_key in ("attn_mask", "key_padding_mask"):
        val = node.kwargs.get(mask_key)
        if val is not None:
            warnings.append(
                f"Skipping MHA at {node.target}: {mask_key} is not None"
            )
    if warnings:
        for msg in warnings:
            _LOG.warning("%s", msg)
        return False
    return True


def _decompose_mha(
    mha_node: fx.Node,
) -> tuple[MHAInProjection, ScaledDotProductAttention, nn.Linear]:
    """Split `mha_node` into in-proj, attention and out-proj modules."""
    mha: nn.MultiheadAttention = resolve_module(
        mha_node, nn.MultiheadAttention
    )
    embed_dim = mha.embed_dim
    num_heads = mha.num_heads
    head_dim = embed_dim // num_heads
    linear = nn.Linear(
        embed_dim,
        3 * embed_dim,
        bias=mha.in_proj_bias is not None,
    )
    linear.weight = mha.in_proj_weight
    if mha.in_proj_bias is not None:
        linear.bias = mha.in_proj_bias
    in_proj = MHAInProjection(linear, num_heads, head_dim)
    attention = ScaledDotProductAttention(
        num_heads,
        head_dim,
        dropout=mha.dropout,
    )
    return in_proj, attention, mha.out_proj


def _unwrap_mha_getitems(
    mha_node: fx.Node,
) -> None:
    """Replace ``getitem`` users of `mha_node` with direct references.

    ``MultiheadAttention`` returns ``(output, weights)``.  ``getitem(mha, 0)``
    users are rewired to reference `mha_node` directly and ``getitem(mha, 1)``
    users are replaced with a ``None`` constant.  After this call `mha_node`
    behaves as a single-output node, allowing
    :func:`~embedl_deploy._internal.core.replace.replace_tree` to handle it
    normally.
    """
    graph_module = mha_node.graph.owning_module
    assert isinstance(graph_module, fx.GraphModule)
    replaced_nodes = get_replaced_nodes(graph_module)
    resolved = replaced_nodes.get(mha_node, mha_node)
    for user in list(resolved.users.keys()):
        if user.op == "call_function" and user.target is operator.getitem:
            idx = user.args[1]
            if idx == 0:
                user.replace_all_uses_with(resolved)
                graph_module.graph.erase_node(user)
            elif idx == 1:
                with graph_module.graph.inserting_after(resolved):
                    none_node = graph_module.graph.call_function(
                        operator.getitem, ((None,), 0)
                    )
                user.replace_all_uses_with(none_node)
                graph_module.graph.erase_node(user)


def _get_attention_insert(
    attention: ScaledDotProductAttention,
) -> ReplacementFn:
    """Return a replacement function for Q/K/V splitting.

    Splits the in-projection output into Q/K/V via ``getitem`` and feeds
    them into `attention`.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        name = get_auto_name(graph_module, attention)
        graph_module.add_module(name, attention)
        (proj_node,) = prev_args
        graph = graph_module.graph
        with graph.inserting_after(proj_node):
            q = graph.call_function(operator.getitem, (proj_node, 0))
        with graph.inserting_after(q):
            k = graph.call_function(operator.getitem, (proj_node, 1))
        with graph.inserting_after(k):
            v = graph.call_function(operator.getitem, (proj_node, 2))
        with graph.inserting_after(v):
            attn = graph.call_module(name, (q, k, v))
        return [q, k, v, attn]

    return _insert


class DecomposeMultiheadAttentionPattern(Pattern):
    """Decompose ``nn.MultiheadAttention`` into explicit sub-modules.

    Replaces each ``MultiheadAttention`` node with three sub-modules visible
    in the FX graph:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    3. ``nn.Linear`` (output projection, reused from the original MHA)

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks is supported.  Unsupported configurations are skipped with
    a warning.
    """

    is_conversion = True
    tree: Tree = (_is_supported_mha,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        mha_node = tree_match.get_node(0)
        in_proj, attention, out_proj = _decompose_mha(mha_node)
        attention_insert = _get_attention_insert(attention)
        _unwrap_mha_getitems(mha_node)
        return replace_tree(
            pattern_match,
            [in_proj, attention_insert, out_proj],
        )
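

# --- Illustrative example (not part of the original module) -----------------
# Plain-PyTorch sketch of the math behind the decomposition: a packed QKV
# in-projection, per-head scaled dot-product attention, then the output
# projection.  It deliberately avoids ``MHAInProjection`` and
# ``ScaledDotProductAttention`` (whose exact interfaces live in
# ``modules.attention``) and only shows that the three-stage split reproduces
# ``nn.MultiheadAttention`` for the supported configuration (self-attention,
# ``batch_first=True``, no masks).  All sizes are illustrative.
def _example_mha_decomposition_math() -> None:
    import torch.nn.functional as F

    torch.manual_seed(0)
    embed_dim, num_heads, batch, seq = 64, 4, 2, 10
    head_dim = embed_dim // num_heads
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    mha.eval()
    x = torch.randn(batch, seq, embed_dim)
    reference, _ = mha(x, x, x, need_weights=False)

    # 1. In-projection: a single Linear producing packed Q/K/V.
    qkv = F.linear(x, mha.in_proj_weight, mha.in_proj_bias)
    q, k, v = qkv.chunk(3, dim=-1)

    def to_heads(t: torch.Tensor) -> torch.Tensor:
        return t.view(batch, seq, num_heads, head_dim).transpose(1, 2)

    # 2. Scaled dot-product attention per head.
    attn = F.scaled_dot_product_attention(to_heads(q), to_heads(k), to_heads(v))
    attn = attn.transpose(1, 2).reshape(batch, seq, embed_dim)

    # 3. Output projection (reused unchanged by the real decomposition).
    out = mha.out_proj(attn)
    torch.testing.assert_close(out, reference)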


# -- Swin Transformer attention decomposition -------------------------


def _is_shifted_window_attention(node: fx.Node) -> bool:
    """Return ``True`` for a supported ``shifted_window_attention`` call.

    Only Swin V1 is supported.  V2 nodes (which carry a ``logit_scale``
    kwarg) are rejected.  Returns ``False`` when *torchvision* is not
    installed.
    """
    if shifted_window_attention is None:
        return False
    if node.op != "call_function":
        return False
    if node.target is not shifted_window_attention:
        return False
    if node.kwargs.get("logit_scale") is not None:
        _LOG.warning(
            "Skipping %s: Swin V2 (logit_scale) not supported",
            node.name,
        )
        return False
    return True


def _resolve_get_attr(
    graph_module: fx.GraphModule,
    node: fx.Node,
) -> torch.Tensor | nn.Parameter:
    """Follow a ``get_attr`` node to the actual tensor."""
    obj: object = graph_module
    for part in node.target.split("."):
        obj = getattr(obj, part)
    assert isinstance(obj, torch.Tensor)
    return obj


def _decompose_swin_attention(  # pylint: disable=too-many-locals
    swa_node: fx.Node,
    graph_module: fx.GraphModule,
) -> tuple[
    SwinWindowPartition,
    MHAInProjection,
    SwinAttention,
    nn.Linear,
    SwinWindowReverse,
]:
    """Split a ``shifted_window_attention`` call into five sub-modules.

    Extracts weights and parameters from the ``get_attr`` nodes in the
    function call's arguments, since the ``fx.GraphModule`` stores
    traced-through modules as generic ``nn.Module`` wrappers.
    """
    # Positional args.
    qkv_weight = _resolve_get_attr(graph_module, swa_node.args[1])
    proj_weight = _resolve_get_attr(graph_module, swa_node.args[2])
    window_size: list[int] = list(swa_node.args[4])
    num_heads: int = swa_node.args[5]
    # Keyword args.
    shift_size: list[int] = list(swa_node.kwargs["shift_size"])
    attention_dropout: float = swa_node.kwargs["attention_dropout"]
    qkv_bias_node = swa_node.kwargs.get("qkv_bias")
    proj_bias_node = swa_node.kwargs.get("proj_bias")

    embed_dim = qkv_weight.shape[1]
    head_dim = embed_dim // num_heads

    # Relative position bias from the _get_relative_position_bias call.
    rel_bias_node = swa_node.args[3]
    bias_table = _resolve_get_attr(graph_module, rel_bias_node.args[0])
    bias_index = _resolve_get_attr(graph_module, rel_bias_node.args[1])

    # Reconstruct nn.Linear for QKV projection.
    has_qkv_bias = isinstance(qkv_bias_node, fx.Node)
    qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=has_qkv_bias)
    qkv_linear.weight = nn.Parameter(qkv_weight)
    if has_qkv_bias:
        qkv_linear.bias = nn.Parameter(
            _resolve_get_attr(graph_module, qkv_bias_node),
        )

    # Reconstruct nn.Linear for output projection.
    has_proj_bias = isinstance(proj_bias_node, fx.Node)
    out_proj = nn.Linear(embed_dim, embed_dim, bias=has_proj_bias)
    out_proj.weight = nn.Parameter(proj_weight)
    if has_proj_bias:
        out_proj.bias = nn.Parameter(
            _resolve_get_attr(graph_module, proj_bias_node),
        )

    spatial_state = SwinSpatialState()
    win_part = SwinWindowPartition(
        window_size,
        shift_size,
        spatial_state,
    )
    in_proj = MHAInProjection(qkv_linear, num_heads, head_dim)
    attention = SwinAttention(
        relative_position_bias_table=nn.Parameter(bias_table),
        relative_position_index=bias_index,
        window_size=window_size,
        num_heads=num_heads,
        head_dim=head_dim,
        attention_dropout=attention_dropout,
        spatial_state=spatial_state,
    )
    win_rev = SwinWindowReverse(window_size, spatial_state)
    return win_part, in_proj, attention, out_proj, win_rev


def _get_swin_replace(
    win_part: SwinWindowPartition,
    in_proj: MHAInProjection,
    attention: SwinAttention,
    out_proj: nn.Linear,
    win_rev: SwinWindowReverse,
) -> ReplacementFn:
    """Return a replacement function for the full Swin decomposition."""

    def _insert(  # pylint: disable=too-many-locals
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        # prev_args contains all fx.Node positional args of the
        # shifted_window_attention call.  Only the first (input) is
        # used; the rest (weight get_attr, rel-bias node) are superseded
        # by the new modules.
        input_node = prev_args[0]
        graph = graph_module.graph

        # 1. SwinWindowPartition
        wp_name = get_auto_name(graph_module, win_part)
        graph_module.add_module(wp_name, win_part)
        with graph.inserting_after(input_node):
            wp_node = graph.call_module(wp_name, (input_node,))

        # 2. MHAInProjection (self-attention: q=k=v)
        ip_name = get_auto_name(graph_module, in_proj)
        graph_module.add_module(ip_name, in_proj)
        with graph.inserting_after(wp_node):
            ip_node = graph.call_module(
                ip_name,
                (wp_node, wp_node, wp_node),
            )

        # 3. getitem nodes for Q, K, V
        with graph.inserting_after(ip_node):
            q = graph.call_function(
                operator.getitem,
                (ip_node, 0),
            )
        with graph.inserting_after(q):
            k = graph.call_function(
                operator.getitem,
                (ip_node, 1),
            )
        with graph.inserting_after(k):
            v = graph.call_function(
                operator.getitem,
                (ip_node, 2),
            )

        # 4. SwinAttention
        at_name = get_auto_name(graph_module, attention)
        graph_module.add_module(at_name, attention)
        with graph.inserting_after(v):
            at_node = graph.call_module(at_name, (q, k, v))

        # 5. nn.Linear (output projection)
        op_name = get_auto_name(graph_module, out_proj)
        graph_module.add_module(op_name, out_proj)
        with graph.inserting_after(at_node):
            op_node = graph.call_module(op_name, (at_node,))

        # 6. SwinWindowReverse
        wr_name = get_auto_name(graph_module, win_rev)
        graph_module.add_module(wr_name, win_rev)
        with graph.inserting_after(op_node):
            wr_node = graph.call_module(wr_name, (op_node,))

        return [
            wp_node,
            ip_node,
            q,
            k,
            v,
            at_node,
            op_node,
            wr_node,
        ]

    return _insert


def _cleanup_swin_dead_nodes(
    graph_module: fx.GraphModule,
    swa_node: fx.Node,
) -> None:
    """Erase orphaned ``get_attr`` and ``_get_relative_position_bias`` nodes.

    After :func:`replace_tree` erases the ``shifted_window_attention`` node,
    its input ``get_attr`` nodes and the preceding
    ``_get_relative_position_bias`` call may have no remaining users.
    """
    graph = graph_module.graph
    # Collect candidate dead nodes: positional args + kwargs values.
    candidates: list[fx.Node] = [
        a for a in swa_node.args if isinstance(a, fx.Node)
    ]
    candidates.extend(
        v for v in swa_node.kwargs.values() if isinstance(v, fx.Node)
    )
    # The _get_relative_position_bias node's inputs are also candidates.
    rel_bias_node = swa_node.args[3]
    if isinstance(rel_bias_node, fx.Node):
        candidates.extend(
            a for a in rel_bias_node.args if isinstance(a, fx.Node)
        )
    # Erase nodes with no remaining users (rel-bias first, then attrs).
    for node in candidates:
        if len(node.users) == 0:
            graph.erase_node(node)


class DecomposeSwinAttentionPattern(Pattern):
    """Decompose ``shifted_window_attention`` into explicit sub-modules.

    Matches the ``call_function`` node produced when
    ``torchvision.models.swin_transformer.ShiftedWindowAttention`` is
    FX-traced (``shifted_window_attention`` is ``torch.fx.wrap``-ped).
    Replaces each node with five sub-modules:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowPartition`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection` (reused)
    3. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinAttention`
    4. ``nn.Linear`` (output projection, reused from the original)
    5. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowReverse`

    Only Swin V1 (``ShiftedWindowAttention``) is supported.  V2 nodes that
    carry a ``logit_scale`` keyword argument are skipped.
    """

    is_conversion = True
    tree: Tree = (_is_shifted_window_attention,)

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        swa_node = tree_match.get_node(0)
        graph_module = pattern_match.graph_module
        modules = _decompose_swin_attention(swa_node, graph_module)
        swin_replace = _get_swin_replace(*modules)
        nodes = replace_tree(pattern_match, [swin_replace])
        _cleanup_swin_dead_nodes(graph_module, swa_node)
        return nodes
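

# --- Illustrative example (not part of the original module) -----------------
# Sketch of driving the Swin decomposition on an already-traced graph module.
# Producing ``graph_module`` (tracing torchvision's ``swin_t`` and propagating
# shapes, if required) is assumed to be handled by the surrounding pipeline;
# the no-argument ``Pattern()`` construction and the manual loop are likewise
# assumptions made for illustration.
def _example_decompose_swin(graph_module: fx.GraphModule) -> fx.GraphModule:
    pattern = DecomposeSwinAttentionPattern()
    matches = pattern.match(graph_module)
    # Swin V2 call sites (with ``logit_scale``) are already filtered out by
    # ``_is_shifted_window_attention``, so every match here is V1.
    _LOG.info("Decomposing %d shifted_window_attention call(s)", len(matches))
    for pattern_match in matches:
        pattern.replace(pattern_match)
    graph_module.graph.lint()
    graph_module.recompile()
    return graph_module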