# Copyright (C) 2026 Embedl AB
"""Fusion ``Pattern`` subclasses for TensorRT.
Each class declares a ``tree`` (what to match) and a ``graft`` (what to
replace with). The base ``Pattern`` class handles matching and replacement.
"""
import operator
from torch import fx, nn
from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.match import node_check
from embedl_deploy._internal.core.tree.types import (
Fork,
Graft,
Tree,
Wildcard,
)
from embedl_deploy._internal.core.tree.utils import get_module
from embedl_deploy._internal.tensorrt.modules.attention import (
FusedMHAInProjection,
FusedScaledDotProductAttention,
MHAInProjection,
ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.conv import (
FusedConvBN,
FusedConvBNAct,
FusedConvBNActMaxPool,
FusedConvBNAddAct,
)
from embedl_deploy._internal.tensorrt.modules.linear import (
FusedLayerNorm,
FusedLinear,
FusedLinearAct,
)
from embedl_deploy._internal.tensorrt.modules.pointwise import FusedActAdd
from embedl_deploy._internal.tensorrt.modules.pool import (
FusedAdaptiveAvgPool2d,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
FusedSwinAttention,
SwinAttention,
)
@node_check
def _is_add(node: fx.Node) -> bool:
    """Accept ``call_function`` nodes targeting ``operator.add``/``operator.iadd``."""
    # Anything that is not a plain function call can never be an add.
    if node.op != "call_function":
        return False
    return node.target in (operator.add, operator.iadd)
# -- Conv helpers ----------------------------------------------------
# NOTE(review): unlike ``_is_add`` this predicate is not wrapped in
# ``@node_check`` although it is also used as a tree entry — confirm that
# this is intentional.
def _is_stem_conv(node: fx.Node) -> bool:
    """Tell whether *node* resolves to a 7×7 ``Conv2d`` with 3 input channels.

    The module is looked up through ``get_module``; anything that is not an
    ``nn.Conv2d`` is rejected before its attributes are inspected.
    """
    conv = get_module(node)
    if not isinstance(conv, nn.Conv2d):
        return False
    return conv.in_channels == 3 and conv.kernel_size == (7, 7)
#: Wildcard entry for an optional ``BatchNorm2d`` (``quantifier="?"``) shared
#: by the convolution fusion trees below. It sits between the convolution and
#: the activation, or at the end of the chain when no activation follows.
_OPTIONAL_BN = Wildcard(nn.BatchNorm2d, quantifier="?")
# -- Conv fusions ----------------------------------------------------
[docs]
class ConvBNActPattern(Pattern):
    """Fuse a ``Conv2d → [BatchNorm2d] → Activation`` chain.

    The ``BatchNorm2d`` stage is optional. The trailing activation may be
    anything included in
    :data:`~embedl_deploy._internal.core.modules.ActivationLike`.
    """

    phase = Phase.FUSION
    # Chain to match: convolution, optional batch norm, activation.
    tree: Tree = (
        nn.Conv2d,
        _OPTIONAL_BN,
        ActivationLike,
    )
    # Module that replaces the matched chain.
    graft: Graft = FusedConvBNAct
[docs]
class ConvBNPattern(Pattern):
    """Fuse a ``Conv2d → [BatchNorm2d]`` chain that has no activation.

    The ``BatchNorm2d`` stage is optional.
    """

    phase = Phase.FUSION
    # Chain to match: convolution followed by an optional batch norm.
    tree: Tree = (
        nn.Conv2d,
        _OPTIONAL_BN,
    )
    # Module that replaces the matched chain.
    graft: Graft = FusedConvBN
[docs]
class StemConvBNActMaxPoolPattern(Pattern):
    """Fuse ``Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d``.

    This targets the common classification-network stem. The convolution
    must satisfy ``in_channels == 3`` and ``kernel_size == (7, 7)``, so that
    only the actual stem is matched and arbitrary ``Conv→Act→Pool`` chains
    are left alone. The ``BatchNorm2d`` stage is optional.
    """

    phase = Phase.FUSION
    # The first entry is a node predicate rather than a module type; it is
    # what restricts the match to the 3-channel 7×7 stem convolution.
    tree: Tree = (_is_stem_conv, _OPTIONAL_BN, ActivationLike, nn.MaxPool2d)
    # Module that replaces the matched chain.
    graft: Graft = FusedConvBNActMaxPool
[docs]
class ConvBNAddActPattern(Pattern):
    """Fuse ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    This is the tail of ResNet-style bottleneck blocks: the convolution path
    is added element-wise to a skip connection and the sum then goes through
    the final activation.
    """

    phase = Phase.FUSION
    # Fork arguments, in order: the branches feeding the add (the Conv→BN
    # path and an empty entry for the residual input), the add check, and the
    # chain that follows the add.
    tree: Tree = Fork(
        ((nn.Conv2d, nn.BatchNorm2d), ()),
        _is_add,
        (ActivationLike,),
    )
    # Module that replaces the matched subgraph.
    graft: Graft = FusedConvBNAddAct
# -- Pointwise fusions -----------------------------------------------
[docs]
class ActAddPattern(Pattern):
    """Fuse ``Act → add(·, residual)`` into :class:`FusedActAdd`.

    This pattern is meant to be placed ahead of the Conv-based fusion
    patterns. Doing so keeps TensorRT from trying to merge the upstream
    convolution into an activation-fused kernel when the activation output
    feeds a residual add. Once ``Act → add`` has been absorbed into a single
    pointwise leaf, the upstream ``Conv → BN`` is matched by
    :class:`ConvBNPattern` and quantized independently.
    """

    phase = Phase.FUSION
    # Fork arguments, in order: the branches feeding the add (the activation
    # path and an empty entry for the residual input), the add check, and an
    # empty chain after the add.
    tree: Tree = Fork(
        ((ActivationLike,), ()),
        _is_add,
        (),
    )
    # Module that replaces the matched subgraph.
    graft: Graft = FusedActAdd
# -- Linear / LayerNorm fusions --------------------------------------
[docs]
class LinearActPattern(Pattern):
    """Fuse a ``Linear → Activation`` chain.

    The activation may be anything included in
    :data:`~embedl_deploy._internal.core.modules.ActivationLike`.
    """

    phase = Phase.FUSION
    # Chain to match: linear layer followed by an activation.
    tree: Tree = (
        nn.Linear,
        ActivationLike,
    )
    # Module that replaces the matched chain.
    graft: Graft = FusedLinearAct
[docs]
class LinearPattern(Pattern):
    """Wrap a standalone ``Linear`` in a fused module."""

    phase = Phase.FUSION
    # Single-entry chain: the linear layer on its own.
    tree: Tree = (nn.Linear,)
    # Module that replaces the matched layer.
    graft: Graft = FusedLinear
[docs]
class LayerNormPattern(Pattern):
    """Wrap a ``LayerNorm`` in a fused module."""

    phase = Phase.FUSION
    # Single-entry chain: the layer norm on its own.
    tree: Tree = (nn.LayerNorm,)
    # Module that replaces the matched layer.
    graft: Graft = FusedLayerNorm
# -- Attention fusions ------------------------------------------------
[docs]
class MHAInProjectionPattern(Pattern):
    """Wrap an ``MHAInProjection`` in a fused module."""

    phase = Phase.FUSION
    # Single-entry chain: the in-projection module on its own.
    tree: Tree = (MHAInProjection,)
    # Module that replaces the matched module.
    graft: Graft = FusedMHAInProjection
[docs]
class ScaledDotProductAttentionPattern(Pattern):
    """Wrap a ``ScaledDotProductAttention`` in a fused module."""

    phase = Phase.FUSION
    # Single-entry chain: the attention module on its own.
    tree: Tree = (ScaledDotProductAttention,)
    # Module that replaces the matched module.
    graft: Graft = FusedScaledDotProductAttention
class SwinAttentionPattern(Pattern):
    """Wrap a ``SwinAttention`` in a fused module."""

    phase = Phase.FUSION
    # Single-entry chain: the Swin attention module on its own.
    tree: Tree = (SwinAttention,)
    # Module that replaces the matched module.
    graft: Graft = FusedSwinAttention
# -- Pool fusions -----------------------------------------------------
[docs]
class AdaptiveAvgPoolPattern(Pattern):
    """Wrap an ``AdaptiveAvgPool2d`` in a fused module.

    Nothing is actually *fused* here. Wrapping the pool in a recognized
    module allows the Q/DQ insertion pass to place quantize / dequantize
    stubs around it.
    """

    phase = Phase.FUSION
    # Single-entry chain: the pooling layer on its own.
    tree: Tree = (nn.AdaptiveAvgPool2d,)
    # Module that replaces the matched layer.
    graft: Graft = FusedAdaptiveAvgPool2d