# Copyright (C) 2026 Embedl AB
"""Fusion ``Pattern`` subclasses for TensorRT.
Each class wraps the core matching/replacement utilities and wires them to the
appropriate fused ``nn.Module`` from
:mod:`~embedl_deploy._internal.tensorrt.modules`.
"""
import operator
from typing import cast
from torch import fx, nn
from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.pattern import (
Fork,
Pattern,
PatternMatch,
QDQPoint,
Tree,
Wildcard,
get_module,
resolve_module,
)
from embedl_deploy._internal.core.replace import (
replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.conv import (
FusedConvBN,
FusedConvBNAct,
FusedConvBNActMaxPool,
FusedConvBNAddAct,
)
def _is_stem_conv(node: fx.Node) -> bool:
    """Return ``True`` for a 7×7 ``Conv2d`` with 3 input channels.

    Used as a predicate entry in a pattern ``Tree`` so that only the
    classification-network *stem* convolution (RGB input, 7×7 kernel) is
    matched, not arbitrary convolutions.

    :param node: FX node whose target module is inspected.
    :return: ``True`` iff the node resolves to a matching ``nn.Conv2d``.
    """
    module = get_module(node)
    return (
        isinstance(module, nn.Conv2d)
        and module.in_channels == 3
        and module.kernel_size == (7, 7)
    )
#: Wildcard entry for an optional ``BatchNorm2d`` between convolution and
#: activation (or end of chain).  The ``"?"`` quantifier makes the wildcard
#: match zero or one node (see the ``len(wild_nodes) <= 1`` assertions in the
#: ``replace`` implementations below).
_OPTIONAL_BN = Wildcard(nn.BatchNorm2d, quantifier="?")
class ConvBNActPattern(Pattern):
    """Match ``Conv2d → [BatchNorm2d] → Activation`` and fuse.

    Any activation included in
    :data:`~embedl_deploy._internal.core.modules.ActivationLike` is accepted.
    The ``BatchNorm2d`` is optional.
    """

    tree: Tree = (nn.Conv2d, _OPTIONAL_BN, ActivationLike)
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all matches of :attr:`tree` in *graph_module*."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched chain with a single :class:`FusedConvBNAct`.

        :param pattern_match: A match previously produced by :meth:`match`
            for this pattern instance.
        :return: The nodes inserted by :func:`replace_tree`.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        conv = resolve_module(tree_match.get_node(0), nn.Conv2d)
        # The optional-BN wildcard captures zero or one node.
        wild_nodes = tree_match.get_node(1, is_wildcard=True).nodes
        assert len(wild_nodes) <= 1
        bn = None
        if wild_nodes:
            bn = resolve_module(wild_nodes[0], nn.BatchNorm2d)
        act = cast(ActivationLike, resolve_module(tree_match.get_node(2)))
        # Fold BN into the convolution only when a BN was actually matched.
        fused_module = FusedConvBNAct(
            conv,
            bn,
            act,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused_module])
class ConvBNPattern(Pattern):
    """Match ``Conv2d → [BatchNorm2d]`` (no activation) and fuse.

    The ``BatchNorm2d`` is optional.
    """

    tree: Tree = (nn.Conv2d, _OPTIONAL_BN)
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all matches of :attr:`tree` in *graph_module*."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched chain with a single :class:`FusedConvBN`.

        :param pattern_match: A match previously produced by :meth:`match`
            for this pattern instance.
        :return: The nodes inserted by :func:`replace_tree`.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        conv = resolve_module(tree_match.get_node(0), nn.Conv2d)
        # The optional-BN wildcard captures zero or one node.
        wild_nodes = tree_match.get_node(1, is_wildcard=True).nodes
        assert len(wild_nodes) <= 1
        bn = None
        if wild_nodes:
            bn = resolve_module(wild_nodes[0], nn.BatchNorm2d)
        # Fold BN into the convolution only when a BN was actually matched.
        fused_module = FusedConvBN(
            conv,
            bn,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused_module])
class StemConvBNActMaxPoolPattern(Pattern):
    """Match ``Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d``.

    Captures the common classification-network stem. The convolution is
    constrained to ``in_channels == 3, kernel_size == (7, 7)`` so only the
    actual stem is matched, not arbitrary ``Conv→Act→Pool`` chains.
    The ``BatchNorm2d`` is optional.
    """

    tree: Tree = (
        _is_stem_conv,
        _OPTIONAL_BN,
        ActivationLike,
        nn.MaxPool2d,
    )
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all matches of :attr:`tree` in *graph_module*."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched stem with one :class:`FusedConvBNActMaxPool`.

        :param pattern_match: A match previously produced by :meth:`match`
            for this pattern instance.
        :return: The nodes inserted by :func:`replace_tree`.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        conv = resolve_module(tree_match.get_node(0), nn.Conv2d)
        # The optional-BN wildcard captures zero or one node.
        wild_nodes = tree_match.get_node(1, is_wildcard=True).nodes
        assert len(wild_nodes) <= 1
        bn = None
        if wild_nodes:
            bn = resolve_module(wild_nodes[0], nn.BatchNorm2d)
        act = cast(ActivationLike, resolve_module(tree_match.get_node(2)))
        maxpool = resolve_module(tree_match.get_node(3), nn.MaxPool2d)
        # Fold BN into the convolution only when a BN was actually matched.
        fused_module = FusedConvBNActMaxPool(
            conv,
            bn,
            act,
            maxpool,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused_module])
class ConvBNAddActPattern(Pattern):
    """Match ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    Captures the tail of ResNet-style bottleneck blocks where the convolution
    path is element-wise added to a skip connection before the final
    activation.
    """

    tree: Tree = Fork(
        (
            (nn.Conv2d, nn.BatchNorm2d),
            # Empty branch: the residual input enters the add unmodified.
            (),
        ),
        operator.add,
        (ActivationLike,),
    )
    # Both the conv-path input and the residual input are quantized.
    qdq_points = frozenset({QDQPoint.INPUT, QDQPoint.RESIDUAL_INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all matches of :attr:`tree` in *graph_module*."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched block tail with one :class:`FusedConvBNAddAct`.

        :param pattern_match: A match previously produced by :meth:`match`
            for this pattern instance.
        :return: The nodes inserted by :func:`replace_tree`.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        # Indices (0, 0)/(0, 1) address the Conv/BN branch of the fork.
        conv = resolve_module(tree_match.get_node(0, 0), nn.Conv2d)
        bn = resolve_module(tree_match.get_node(0, 1), nn.BatchNorm2d)
        # NOTE(review): the activation is fetched with get_node(0), unlike the
        # non-fork patterns which use a post-chain index — presumably Fork
        # matches index their tail segment from 0; confirm against TreeMatch.
        act = cast(ActivationLike, resolve_module(tree_match.get_node(0)))
        # BN is mandatory in this pattern, so it is always foldable.
        fused_module = FusedConvBNAddAct(
            conv,
            bn,
            act,
            bn_foldable=True,
        )
        return replace_tree(pattern_match, [fused_module])