Source code for embedl_deploy._internal.tensorrt.patterns.conversions.general

# Copyright (C) 2026 Embedl AB

"""General-purpose conversion patterns.

Structural rewrites that are not attention-specific: assertion cleanup,
``Flatten → Linear`` rewritten as ``Conv2d(1×1)`` for TensorRT, and
``nn.Identity`` / identity-``AdaptiveAvgPool2d`` removal.
"""

import operator
from typing import TypeAlias

import torch
from torch import fx, nn

from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.replace import keep_node
from embedl_deploy._internal.core.tree.state import get_replaced_nodes
from embedl_deploy._internal.core.tree.types import (
    Graft,
    NodeInserter,
    Tree,
    TreeMatch,
    Wildcard,
)
from embedl_deploy._internal.core.tree.utils import (
    get_input_shape,
    get_module,
    resolve_module,
)


def _is_assert_noise(node: fx.Node) -> bool:
    """Return ``True`` for helper nodes typically used in assertions.

    These nodes are graph noise in exported inference graphs and usually feed
    directly into ``eq`` / ``_assert`` checks.
    """
    if node.op == "call_function" and node.target is getattr:
        return True
    if node.op == "call_method" and node.target in {"dim", "size"}:
        return True
    return False


def _is_eq(node: fx.Node) -> bool:
    """Return ``True`` when `node` is an equality comparison."""
    return node.op == "call_function" and node.target is operator.eq


def _is_assert(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a torch assertion op."""
    target = getattr(torch, "_assert", None)
    return node.op == "call_function" and node.target is target


#: Wildcard over the helper nodes accepted by ``_is_assert_noise``; it is the
#: leading element of the assertion tree, ahead of the ``eq`` and ``_assert``
#: matchers.
#: NOTE(review): the name says "optional" but the quantifier is ``"+"`` --
#: confirm against ``Wildcard`` whether that means one-or-more.
_OPTIONAL_ASSERT_NOISE = Wildcard(_is_assert_noise, quantifier="+")


class RemoveAssertPattern(Pattern):
    """Strip assertion chains of the form ``getattr → eq → _assert``.

    Traced FX graphs of timm models frequently carry shape checks (for
    instance ``assert x.dim() == 4``).  Such checks do no runtime compute,
    so a deployment graph can drop them safely.

    The match is a run of assertion-helper nodes, followed by an
    ``operator.eq`` comparison, followed by the ``torch._assert`` call.
    """

    phase = Phase.CONVERSION

    # Helper nodes, then the comparison, then the assertion itself.
    tree: Tree = (_OPTIONAL_ASSERT_NOISE, _is_eq, _is_assert)

    # Nothing is grafted in place of the matched nodes.
    graft: Graft = ()
def _is_flatten(node: fx.Node) -> bool: """Return ``True`` when `node` is a 4D→2D flatten with shape metadata. Only matches flattens with ``start_dim=1`` on a 4-D input, which produces a 2-D output (the classification-head pattern). Flattens with ``start_dim >= 2`` (e.g. the MHA head-merging ``flatten(2)`` that produces 3-D output) are rejected. """ if node.op == "call_function": is_flat = node.target is torch.flatten elif node.op == "call_method": is_flat = node.target == "flatten" else: is_flat = isinstance(get_module(node), nn.Flatten) if not is_flat: return False shape = get_input_shape(node) if shape is None or len(shape) != 4: return False mod = get_module(node) if isinstance(mod, nn.Flatten): start_dim: int = mod.start_dim elif len(node.args) > 1 and isinstance(node.args[1], int): start_dim = node.args[1] else: start_dim = 0 return start_dim == 1 ElementWiseLike: TypeAlias = ( nn.Dropout | nn.Dropout2d | nn.ReLU | nn.ReLU6 | nn.LeakyReLU | nn.ELU | nn.GELU | nn.SiLU | nn.Hardswish | nn.Hardsigmoid ) #: Optional element-wise ops that can appear between a ``flatten`` and #: a ``Linear`` (activations, dropout). 
_OPTIONAL_EW = Wildcard(ElementWiseLike)


def _ew_nodes(tree_match: TreeMatch) -> tuple[fx.Node, ...]:
    """Fetch the nodes captured by the element-wise wildcard (tree slot 1)."""
    wildcard_match = tree_match.get_node(1, is_wildcard=True)
    return wildcard_match.nodes


def _get_reshape_insert(
    flatten_node: fx.Node,
    linear: nn.Linear,
) -> NodeInserter:
    """Build the inserter that reshapes the input ahead of the 1×1 conv.

    The returned callable looks up the input shape of `flatten_node` and
    compares its trailing dimensions with ``(linear.in_features, 1, 1)``.
    When they already agree it inserts nothing and returns an empty list.
    Otherwise it adds one ``reshape(-1, in_features, 1, 1)`` call on the
    first previous argument, positioned before `flatten_node` (or before
    its entry in the replaced-node mapping, if it has one), and returns
    that node in a list.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        input_shape = get_input_shape(flatten_node)
        assert input_shape is not None
        target_shape = torch.Size([-1, linear.in_features, 1, 1])

        # Input is already (N, in_features, 1, 1): no reshape required.
        if input_shape[1:] == target_shape[1:]:
            return []

        # Anchor on the replacement if the flatten node has one recorded.
        anchor = get_replaced_nodes().get(flatten_node, flatten_node)
        graph = graph_module.graph
        with graph.inserting_before(anchor):
            reshape_node = graph.call_method(
                "reshape", (prev_args[0], *target_shape)
            )
        return [reshape_node]

    return _insert


def _linear_to_conv1x1(linear: nn.Linear) -> nn.Conv2d:
    """Build a ``Conv2d(1×1)`` carrying the parameters of `linear`.

    The ``(out_features, in_features)`` weight is cloned and reshaped to
    ``(out_features, in_features, 1, 1)``.  The bias, when `linear` has
    one, is cloned as well; otherwise the conv is created without bias.
    """
    out_channels = linear.out_features
    in_channels = linear.in_features

    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        bias=linear.bias is not None,
    )

    kernel = linear.weight.data.clone()
    conv.weight.data = kernel.reshape(out_channels, in_channels, 1, 1)

    if linear.bias is not None:
        assert conv.bias is not None
        conv.bias.data = linear.bias.data.clone()
    return conv


def _reshape_and_conv(
    tree_match: TreeMatch,
) -> tuple[NodeInserter, nn.Module]:
    """Produce the reshape inserter and the ``Conv2d(1×1)`` for a match.

    Tree slot 0 is the flatten node and slot 2 the ``Linear`` node.
    """
    flatten_node = tree_match.get_node(0)
    linear_node = tree_match.get_node(2)
    linear = resolve_module(linear_node, nn.Linear)

    inserter = _get_reshape_insert(flatten_node, linear)
    conv = _linear_to_conv1x1(linear)
    return inserter, conv
class FlattenLinearToConv1x1Pattern(Pattern):
    """Rewrite ``Flatten (4D→2D) → Linear`` as ``Conv2d(1×1) → Flatten``.

    Classification heads commonly end like this::

        ``AdaptiveAvgPool2d → flatten → [Dropout → ReLU →] Linear``

    After the conversion the tail reads::

        ``AdaptiveAvgPool2d → [Dropout → ReLU →] Conv2d(1×1) → flatten``

    Any element-wise ops sitting between the flatten and the ``Linear``
    (activations, dropout) are captured by a
    :class:`~embedl_deploy._internal.core.tree.types.Wildcard` and placed
    ahead of the ``Conv2d`` in the replacement.  Downstream fusion and
    Q/DQ patterns can then match the resulting ``Conv2d``.

    This is a *structural* conversion: it alters graph topology instead
    of collapsing a chain into a fused module, and therefore must run
    **before** the fusion patterns.
    """

    phase = Phase.CONVERSION

    # Slot 0: flatten, slot 1: element-wise wildcard, slot 2: Linear.
    tree: Tree = (_is_flatten, _OPTIONAL_EW, nn.Linear)

    # Element-wise ops first, then reshape + 1×1 conv, then the original
    # flatten node (slot 0).
    graft: Graft = (_ew_nodes, _reshape_and_conv, keep_node(0))
def _is_identity_adaptive_avg_pool(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is an identity ``AdaptiveAvgPool2d``.

    The pool is an identity when its ``output_size`` equals the spatial
    dimensions ``(H, W)`` of the input tensor.  The check needs a known
    4-D input shape on the node; without it the node is not matched.

    ``output_size`` may be a single value (used for both axes) or a
    two-element tuple/list.  A ``None`` entry means "keep the input
    dimension" and is therefore always an identity for that axis.  Only
    ``None`` is treated that way: an explicit ``0`` is compared as-is and
    never counts as "keep".
    """
    mod = get_module(node)
    if not isinstance(mod, nn.AdaptiveAvgPool2d):
        return False

    shape = get_input_shape(node)
    if shape is None or len(shape) != 4:
        return False
    in_h, in_w = shape[2], shape[3]

    output_size = mod.output_size
    if isinstance(output_size, (tuple, list)):
        # Per-axis sizes; a list is accepted the same way as a tuple.
        out_h, out_w = output_size
    else:
        out_h = out_w = output_size

    # ``None`` keeps the input extent.  Test identity, not truthiness, so
    # that a literal 0 is not mistaken for ``None``.
    if out_h is None:
        out_h = in_h
    if out_w is None:
        out_w = in_w
    return (in_h == out_h) and (in_w == out_w)
class RemoveIdentityAdaptiveAvgPoolPattern(Pattern):
    """Drop ``AdaptiveAvgPool2d`` layers whose output size equals the input.

    If ``output_size`` is the same ``(H, W)`` as the incoming feature map,
    the pooling is a mathematical identity and can be erased safely.
    ConvNeXt-style architectures commonly contain such layers.

    Shapes must already be available in ``node.meta['tensor_meta']``
    (e.g. from an earlier ``ShapeProp`` pass); nodes without shape
    metadata are not matched.
    """

    phase = Phase.CONVERSION

    # Single-node match.
    tree: Tree = (_is_identity_adaptive_avg_pool,)

    # Nothing is grafted in place of the matched node.
    graft: Graft = ()
def _is_identity_passthrough(node: fx.Node) -> bool:
    """Tell whether the module behind `node` is an ``nn.Identity``."""
    return isinstance(get_module(node), nn.Identity)
class RemoveIdentityPattern(Pattern):
    """Strip ``nn.Identity`` modules out of the graph.

    ``nn.Identity`` forwards its input unchanged, so removing it does not
    alter model behavior.  Dropping these no-ops leaves a simpler graph
    for the optimization and fusion patterns that run afterwards.
    """

    phase = Phase.CONVERSION

    # Single-node match on any ``nn.Identity`` module.
    tree: Tree = (_is_identity_passthrough,)

    # Nothing is grafted in place of the matched node.
    graft: Graft = ()