Source code for embedl_deploy._internal.tensorrt.patterns.fusions

# Copyright (C) 2026 Embedl AB

"""Fusion ``Pattern`` subclasses for TensorRT.

Each class declares a ``tree`` (what to match) and a ``graft`` (what to
replace with).  The base ``Pattern`` class handles matching and replacement.
"""

import operator

from torch import fx, nn

from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.match import node_check
from embedl_deploy._internal.core.tree.types import (
    Fork,
    Graft,
    Tree,
    Wildcard,
)
from embedl_deploy._internal.core.tree.utils import get_module
from embedl_deploy._internal.tensorrt.modules.attention import (
    FusedMHAInProjection,
    FusedScaledDotProductAttention,
    MHAInProjection,
    ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.conv import (
    FusedConvBN,
    FusedConvBNAct,
    FusedConvBNActMaxPool,
    FusedConvBNAddAct,
)
from embedl_deploy._internal.tensorrt.modules.linear import (
    FusedLayerNorm,
    FusedLinear,
    FusedLinearAct,
)
from embedl_deploy._internal.tensorrt.modules.pointwise import FusedActAdd
from embedl_deploy._internal.tensorrt.modules.pool import (
    FusedAdaptiveAvgPool2d,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
    FusedSwinAttention,
    SwinAttention,
)


@node_check
def _is_add(node: fx.Node) -> bool:
    """Return ``True`` if *node* is a call to an addition operator.

    Only ``call_function`` nodes qualify. Their target must be either the
    out-of-place ``operator.add`` or the in-place ``operator.iadd``.
    """
    if node.op != "call_function":
        return False
    add_targets = (operator.add, operator.iadd)
    return node.target in add_targets


# -- Conv helpers ----------------------------------------------------


# NOTE(review): unlike ``_is_add`` this predicate is not wrapped in
# ``@node_check``; confirm the tree matcher accepts bare callables.
def _is_stem_conv(node: fx.Node) -> bool:
    """Tell whether *node* is the network's stem convolution.

    A stem convolution is an ``nn.Conv2d`` that reads 3 input channels and
    uses a ``(7, 7)`` kernel. Any other module type is rejected first.
    """
    conv = get_module(node)
    if not isinstance(conv, nn.Conv2d):
        return False
    return conv.in_channels == 3 and conv.kernel_size == (7, 7)


#: Optional ``BatchNorm2d`` slot shared by the Conv fusion trees. It may
#: sit between the convolution and the activation, or close the chain
#: when no activation follows.
_OPTIONAL_BN = Wildcard(
    nn.BatchNorm2d,
    quantifier="?",
)


# -- Conv fusions ----------------------------------------------------


class ConvBNActPattern(Pattern):
    """Fuse ``Conv2d → [BatchNorm2d] → Activation`` into one module.

    The ``BatchNorm2d`` step is optional. The trailing activation may be
    any type covered by
    :data:`~embedl_deploy._internal.core.modules.ActivationLike`.
    Matched chains are replaced with :class:`FusedConvBNAct`.
    """

    phase = Phase.FUSION
    tree: Tree = (
        nn.Conv2d,
        _OPTIONAL_BN,
        ActivationLike,
    )
    graft: Graft = FusedConvBNAct


class ConvBNPattern(Pattern):
    """Fuse ``Conv2d → [BatchNorm2d]`` without a trailing activation.

    The ``BatchNorm2d`` step is optional. Matched chains are replaced with
    :class:`FusedConvBN`.
    """

    phase = Phase.FUSION
    tree: Tree = (
        nn.Conv2d,
        _OPTIONAL_BN,
    )
    graft: Graft = FusedConvBN


class StemConvBNActMaxPoolPattern(Pattern):
    """Fuse the stem ``Conv2d → [BatchNorm2d] → Activation → MaxPool2d``.

    This is the opening block common to classification networks. The
    convolution must have ``in_channels == 3`` and
    ``kernel_size == (7, 7)`` (checked by ``_is_stem_conv``). Because of
    that check, arbitrary ``Conv → Act → Pool`` chains are not matched by
    this pattern, only the real stem. The ``BatchNorm2d`` step is
    optional. Matched chains are replaced with
    :class:`FusedConvBNActMaxPool`.
    """

    phase = Phase.FUSION
    tree: Tree = (
        _is_stem_conv,
        _OPTIONAL_BN,
        ActivationLike,
        nn.MaxPool2d,
    )
    graft: Graft = FusedConvBNActMaxPool


class ConvBNAddActPattern(Pattern):
    """Fuse ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    This is the tail of a ResNet-style bottleneck block. The convolution
    branch is summed element-wise with a skip connection, and the sum
    then passes through the final activation. Unlike the other Conv
    patterns, ``BatchNorm2d`` is a plain (non-optional) entry here.
    Matches are replaced with :class:`FusedConvBNAddAct`.
    """

    phase = Phase.FUSION
    tree: Tree = Fork(
        (
            # Convolution branch feeding the add.
            (nn.Conv2d, nn.BatchNorm2d),
            # Residual operand (empty sub-tree).
            (),
        ),
        _is_add,
        # What must follow the add.
        (ActivationLike,),
    )
    graft: Graft = FusedConvBNAddAct


# -- Pointwise fusions -----------------------------------------------
class ActAddPattern(Pattern):
    """Fuse ``Act → add(·, residual)`` into :class:`FusedActAdd`.

    This pattern is intended to be placed ahead of the Conv-based fusion
    patterns. When an activation's output feeds a residual add, TensorRT
    could otherwise try to merge the upstream convolution into an
    activation-fused kernel. Absorbing ``Act → add`` into a single
    pointwise leaf prevents that. The upstream ``Conv → BN`` is then
    matched by :class:`ConvBNPattern` and quantized independently.
    """

    phase = Phase.FUSION
    tree: Tree = Fork(
        (
            # Activation branch feeding the add.
            (ActivationLike,),
            # Residual operand (empty sub-tree).
            (),
        ),
        _is_add,
        # Nothing is required after the add.
        (),
    )
    graft: Graft = FusedActAdd


# -- Linear / LayerNorm fusions --------------------------------------
class LinearActPattern(Pattern):
    """Fuse ``Linear → Activation`` into one module.

    The activation may be any type covered by
    :data:`~embedl_deploy._internal.core.modules.ActivationLike`.
    Matched pairs are replaced with :class:`FusedLinearAct`.
    """

    phase = Phase.FUSION
    tree: Tree = (
        nn.Linear,
        ActivationLike,
    )
    graft: Graft = FusedLinearAct


class LinearPattern(Pattern):
    """Wrap a lone ``Linear`` in :class:`FusedLinear`."""

    phase = Phase.FUSION
    tree: Tree = (nn.Linear,)
    graft: Graft = FusedLinear


class LayerNormPattern(Pattern):
    """Wrap a ``LayerNorm`` in :class:`FusedLayerNorm`."""

    phase = Phase.FUSION
    tree: Tree = (nn.LayerNorm,)
    graft: Graft = FusedLayerNorm


# -- Attention fusions ------------------------------------------------
class MHAInProjectionPattern(Pattern):
    """Wrap an ``MHAInProjection`` in :class:`FusedMHAInProjection`."""

    phase = Phase.FUSION
    tree: Tree = (MHAInProjection,)
    graft: Graft = FusedMHAInProjection


class ScaledDotProductAttentionPattern(Pattern):
    """Wrap a ``ScaledDotProductAttention`` in its fused counterpart.

    Matches are replaced with :class:`FusedScaledDotProductAttention`.
    """

    phase = Phase.FUSION
    tree: Tree = (ScaledDotProductAttention,)
    graft: Graft = FusedScaledDotProductAttention


class SwinAttentionPattern(Pattern):
    """Wrap a ``SwinAttention`` in :class:`FusedSwinAttention`."""

    phase = Phase.FUSION
    tree: Tree = (SwinAttention,)
    graft: Graft = FusedSwinAttention


# -- Pool fusions -----------------------------------------------------


class AdaptiveAvgPoolPattern(Pattern):
    """Wrap an ``AdaptiveAvgPool2d`` in :class:`FusedAdaptiveAvgPool2d`.

    No operators are actually fused. The wrapper exists so that the Q/DQ
    insertion pass sees a recognized module and can place quantize /
    dequantize stubs around the pool.
    """

    phase = Phase.FUSION
    tree: Tree = (nn.AdaptiveAvgPool2d,)
    graft: Graft = FusedAdaptiveAvgPool2d