# Copyright (C) 2026 Embedl AB
"""General-purpose conversion patterns.
Structural rewrites that are not attention-specific: assertion cleanup,
``Flatten → Linear`` rewritten as ``Conv2d(1×1)`` for TensorRT, and
``nn.Identity`` / identity-``AdaptiveAvgPool2d`` removal.
"""
import operator
from typing import TypeAlias
import torch
from torch import fx, nn
from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.replace import keep_node
from embedl_deploy._internal.core.tree.state import get_replaced_nodes
from embedl_deploy._internal.core.tree.types import (
Graft,
NodeInserter,
Tree,
TreeMatch,
Wildcard,
)
from embedl_deploy._internal.core.tree.utils import (
get_input_shape,
get_module,
resolve_module,
)
def _is_assert_noise(node: fx.Node) -> bool:
"""Return ``True`` for helper nodes typically used in assertions.
These nodes are graph noise in exported inference graphs and usually feed
directly into ``eq`` / ``_assert`` checks.
"""
if node.op == "call_function" and node.target is getattr:
return True
if node.op == "call_method" and node.target in {"dim", "size"}:
return True
return False
def _is_eq(node: fx.Node) -> bool:
    """Return ``True`` if `node` is a function call to ``operator.eq``."""
    if node.op != "call_function":
        return False
    return node.target is operator.eq
def _is_assert(node: fx.Node) -> bool:
    """Return ``True`` if `node` is a function call to ``torch._assert``.

    ``torch._assert`` is looked up defensively; if this torch build does not
    expose it, no node is ever reported as an assertion.
    """
    if node.op != "call_function":
        return False
    assert_fn = getattr(torch, "_assert", None)
    return node.target is assert_fn
#: Wildcard slot for assertion helper nodes (``getattr`` calls and
#: ``.dim()`` / ``.size()`` method calls, see ``_is_assert_noise``) that
#: precede the ``eq`` check in ``RemoveAssertPattern``.
#: NOTE(review): despite the ``_OPTIONAL`` prefix, ``quantifier="+"``
#: presumably means "one or more" (regex-style), i.e. the helpers are
#: required — confirm against ``Wildcard`` and rename or use ``"*"`` if
#: zero helpers should also match.
_OPTIONAL_ASSERT_NOISE = Wildcard(_is_assert_noise, quantifier="+")
[docs]
class RemoveAssertPattern(Pattern):
    """Erase traced assertion chains such as ``getattr → eq → _assert``.

    Models from timm frequently carry shape checks in their traced FX graph
    (``assert x.dim() == 4`` and similar). These are not runtime compute ops,
    so deployment graphs can drop them without changing results.

    The ``tree`` matches a wildcard over assertion helpers (``getattr``,
    ``.dim()``, ``.size()``), followed by an ``operator.eq`` node and a
    ``torch._assert`` node. The ``graft`` is empty: no replacement is
    produced for the matched chain.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_OPTIONAL_ASSERT_NOISE, _is_eq, _is_assert)
    graft: Graft = ()
def _is_flatten(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a 4D→2D flatten with shape metadata.

    Only matches flattens that collapse dims ``1..3`` of a 4-D input into a
    single axis, producing a 2-D output (the classification-head pattern).
    Recognized forms are ``torch.flatten(x, ...)``, ``x.flatten(...)`` and
    ``nn.Flatten`` modules. ``start_dim`` / ``end_dim`` are read from
    positional args or keywords, and negative values are normalized against
    the input rank (so ``flatten(-3)`` is accepted).

    Rejected, among others:

    * flattens with ``start_dim >= 2`` (e.g. the MHA head-merging
      ``flatten(2)`` that produces 3-D output);
    * flattens whose ``end_dim`` is not the last axis (e.g. ``flatten(1, 2)``
      or ``nn.Flatten(1, 2)``, which also produce 3-D output);
    * nodes whose input has no shape metadata or is not 4-D, and nodes whose
      dims are not plain ``int`` values.
    """
    if node.op == "call_function":
        is_flat = node.target is torch.flatten
    elif node.op == "call_method":
        is_flat = node.target == "flatten"
    else:
        is_flat = isinstance(get_module(node), nn.Flatten)
    if not is_flat:
        return False

    shape = get_input_shape(node)
    if shape is None or len(shape) != 4:
        return False
    rank = len(shape)

    start_dim: object
    end_dim: object
    mod = get_module(node)
    if isinstance(mod, nn.Flatten):
        start_dim, end_dim = mod.start_dim, mod.end_dim
    else:
        # Both ``torch.flatten(input, start_dim=0, end_dim=-1)`` and
        # ``Tensor.flatten(start_dim=0, end_dim=-1)`` record the input tensor
        # in ``args[0]`` of the FX node, so the dims are at positions 1 and 2
        # or passed as keywords.
        args, kwargs = node.args, node.kwargs
        start_dim = args[1] if len(args) > 1 else kwargs.get("start_dim", 0)
        end_dim = args[2] if len(args) > 2 else kwargs.get("end_dim", -1)
    if not isinstance(start_dim, int) or not isinstance(end_dim, int):
        return False

    # Normalize negative dims the way PyTorch does (``dim + rank``).
    if start_dim < 0:
        start_dim += rank
    if end_dim < 0:
        end_dim += rank
    # Collapsing exactly dims 1..rank-1 leaves a 2-D ``(N, C*H*W)`` tensor;
    # any other end_dim keeps extra axes and must not become a 1x1 conv.
    return start_dim == 1 and end_dim == rank - 1
#: Module types treated as "element-wise" ops that may sit between a
#: ``flatten`` and a ``Linear`` (activations and dropout). Built as a
#: runtime union with ``|`` so it can be handed to ``Wildcard`` below.
#: NOTE(review): ``Dropout2d`` zeroes whole channels in training mode, so it
#: is only element-wise in eval mode — presumably fine for deployment
#: graphs; confirm.
ElementWiseLike: TypeAlias = (
    nn.Dropout
    | nn.Dropout2d
    | nn.ReLU
    | nn.ReLU6
    | nn.LeakyReLU
    | nn.ELU
    | nn.GELU
    | nn.SiLU
    | nn.Hardswish
    | nn.Hardsigmoid
)
#: Optional element-wise ops that can appear between a ``flatten`` and
#: a ``Linear`` (activations, dropout).
#: NOTE(review): optionality relies on ``Wildcard``'s default quantifier,
#: which is not visible here — confirm it allows zero matches.
_OPTIONAL_EW = Wildcard(ElementWiseLike)
def _ew_nodes(tree_match: TreeMatch) -> tuple[fx.Node, ...]:
    """Collect the nodes captured by the element-wise wildcard.

    Index ``1`` is the ``_OPTIONAL_EW`` slot of the flatten → Linear tree,
    sitting between the flatten (index 0) and the ``Linear`` (index 2).
    """
    wildcard_match = tree_match.get_node(1, is_wildcard=True)
    return wildcard_match.nodes
def _get_reshape_insert(
    flatten_node: fx.Node,
    linear: nn.Linear,
) -> NodeInserter:
    """Build a node inserter that reshapes the conv input to ``(-1, C, 1, 1)``.

    ``C`` is ``linear.in_features``. The returned callable reads the shape of
    the flatten's input. If every axis after the batch axis already equals
    ``(C, 1, 1)``, it inserts nothing and returns ``[]``. Otherwise it adds a
    single ``reshape`` call on ``prev_args[0]`` and returns it. That call is
    placed in front of the flatten node, or in front of whatever node
    ``get_replaced_nodes()`` reports as having replaced it.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        in_shape = get_input_shape(flatten_node)
        # ``_is_flatten`` only matches nodes that carry shape metadata.
        assert in_shape is not None
        target = torch.Size([-1, linear.in_features, 1, 1])
        if in_shape[1:] == target[1:]:
            # Already laid out as (N, C, 1, 1): feed the conv directly.
            return []

        anchor = get_replaced_nodes().get(flatten_node, flatten_node)
        graph = graph_module.graph
        with graph.inserting_before(anchor):
            reshaped = graph.call_method("reshape", (prev_args[0], *target))
        return [reshaped]

    return _insert
def _linear_to_conv1x1(linear: nn.Linear) -> nn.Conv2d:
    """Build a 1×1 ``Conv2d`` computing the same affine map as `linear`.

    The ``(out_features, in_features)`` weight is cloned and reshaped to
    ``(out_features, in_features, 1, 1)``. A bias, when present, is cloned
    unchanged. The conv therefore owns independent copies of the parameters.
    """
    n_in = linear.in_features
    n_out = linear.out_features
    conv = nn.Conv2d(n_in, n_out, kernel_size=1, bias=linear.bias is not None)

    kernel = linear.weight.data.clone()
    conv.weight.data = kernel.reshape(n_out, n_in, 1, 1)

    if linear.bias is not None:
        # Narrow ``conv.bias`` from Optional for type checkers.
        assert conv.bias is not None
        conv.bias.data = linear.bias.data.clone()
    return conv
def _reshape_and_conv(
    tree_match: TreeMatch,
) -> tuple[NodeInserter, nn.Module]:
    """Produce the reshape inserter and the replacement ``Conv2d(1×1)``.

    Node 0 of the match is the flatten and node 2 is the ``Linear``, whose
    module supplies both the reshape width and the conv weights.
    """
    flat = tree_match.get_node(0)
    lin = resolve_module(tree_match.get_node(2), nn.Linear)
    inserter = _get_reshape_insert(flat, lin)
    conv = _linear_to_conv1x1(lin)
    return inserter, conv
[docs]
class FlattenLinearToConv1x1Pattern(Pattern):
    """Turn a ``Flatten (4D→2D) → Linear`` head into ``Conv2d(1×1) → Flatten``.

    A typical classification tail looks like::

        AdaptiveAvgPool2d → flatten → [Dropout → ReLU →] Linear

    and is rewritten to::

        AdaptiveAvgPool2d → [Dropout → ReLU →] [reshape →] Conv2d(1×1) → flatten

    Element-wise ops found between the flatten and the ``Linear``
    (activations, dropout) are captured by a
    :class:`~embedl_deploy._internal.core.tree.types.Wildcard` and re-emitted
    ahead of the ``Conv2d``. A ``reshape`` to ``(-1, in_features, 1, 1)`` is
    added only when the feature map is not already in that layout. The
    original flatten node is kept and now follows the convolution.

    Because the result is a plain ``Conv2d``, downstream fusion and Q/DQ
    patterns can match it. This rewrite changes graph topology instead of
    collapsing a chain into one fused module. It must therefore run
    **before** the fusion patterns.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_flatten, _OPTIONAL_EW, nn.Linear)
    graft: Graft = (_ew_nodes, _reshape_and_conv, keep_node(0))
def _is_identity_adaptive_avg_pool(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is an identity ``AdaptiveAvgPool2d``.

    The pool is identity when its ``output_size`` matches the spatial
    dimensions (the last two axes) of the input tensor. Both batched
    ``(N, C, H, W)`` and unbatched ``(C, H, W)`` inputs are handled. Requires
    shape metadata (``tensor_meta``) on the input node; without it the node
    is not matched.

    ``None`` in ``output_size`` means "keep the input dimension", which is
    always identity for that axis. Only ``None`` is treated that way: any
    other value (including ``0``) is compared literally with the input size.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.AdaptiveAvgPool2d):
        return False
    shape = get_input_shape(node)
    if shape is None or len(shape) not in (3, 4):
        return False
    in_h = shape[-2]
    in_w = shape[-1]

    size = mod.output_size
    # ``output_size`` may be an int, None, or a 2-element tuple/list.
    if isinstance(size, (tuple, list)):
        out_h, out_w = size
    else:
        out_h = out_w = size
    # Use an explicit ``is None`` test: ``out or in`` would also replace 0.
    if out_h is None:
        out_h = in_h
    if out_w is None:
        out_w = in_w
    return (in_h == out_h) and (in_w == out_w)
[docs]
class RemoveIdentityAdaptiveAvgPoolPattern(Pattern):
    """Drop ``AdaptiveAvgPool2d`` layers whose output size equals the input size.

    If ``output_size`` equals the ``(H, W)`` of the incoming feature map, the
    pooling is a mathematical identity, and erasing it leaves results
    unchanged. ConvNeXt-style architectures commonly contain such layers.

    Shapes are expected to be recorded in ``node.meta['tensor_meta']``
    already (for instance by running ``ShapeProp`` beforehand). A node whose
    input carries no shape metadata is simply not matched by the predicate.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_identity_adaptive_avg_pool,)
    graft: Graft = ()
def _is_identity_passthrough(node: fx.Node) -> bool:
    """Return ``True`` when `node` resolves to an ``nn.Identity`` module.

    Whatever ``get_module`` returns for `node` is tested; any object that is
    not an ``nn.Identity`` instance yields ``False``.
    """
    return isinstance(get_module(node), nn.Identity)
[docs]
class RemoveIdentityPattern(Pattern):
    """Strip ``nn.Identity`` modules out of the graph.

    ``nn.Identity`` forwards its input untouched, so erasing it cannot change
    what the model computes. Removing these no-op nodes leaves a simpler
    graph for the optimization and fusion patterns that run afterwards.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_identity_passthrough,)
    graft: Graft = ()