# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.conv

# Copyright (C) 2026 Embedl AB

"""Fusion ``Pattern`` subclasses for TensorRT.

Each class wraps the core matching/replacement utilities and wires them to the
appropriate fused ``nn.Module`` from
:mod:`~embedl_deploy._internal.tensorrt.modules`.
"""

import operator
from typing import cast

from torch import fx, nn

from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.pattern import (
    Fork,
    Pattern,
    PatternMatch,
    QDQPoint,
    Tree,
    Wildcard,
    get_module,
    resolve_module,
)
from embedl_deploy._internal.core.replace import (
    replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.conv import (
    FusedConvBN,
    FusedConvBNAct,
    FusedConvBNActMaxPool,
    FusedConvBNAddAct,
)


def _is_stem_conv(node: fx.Node) -> bool:
    """Return ``True`` for a 7×7 ``Conv2d`` with 3 input channels.

    Used as a tree-entry predicate so the stem pattern only matches the
    very first convolution of a classification network.
    """
    candidate = get_module(node)
    # Reject anything that is not a Conv2d before touching Conv2d-only attrs.
    if not isinstance(candidate, nn.Conv2d):
        return False
    return candidate.in_channels == 3 and candidate.kernel_size == (7, 7)


#: Wildcard entry for an optional ``BatchNorm2d`` between convolution and
#: activation (or end of chain).  The ``"?"`` quantifier matches zero or one
#: occurrence, so patterns embedding this entry fuse both plain ``Conv2d``
#: chains and ``Conv2d → BatchNorm2d`` chains.
_OPTIONAL_BN = Wildcard(nn.BatchNorm2d, quantifier="?")


class ConvBNActPattern(Pattern):
    """Fuse a ``Conv2d → [BatchNorm2d] → Activation`` chain.

    Every activation listed in
    :data:`~embedl_deploy._internal.core.modules.ActivationLike` is accepted;
    the ``BatchNorm2d`` stage may be absent.
    """

    tree: Tree = (nn.Conv2d, _OPTIONAL_BN, ActivationLike)
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find every occurrence of the chain in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap one matched chain for a single :class:`FusedConvBNAct`."""
        assert pattern_match.pattern is self
        tm = pattern_match.tree_match
        conv = resolve_module(tm.get_node(0), nn.Conv2d)
        # The optional BN is a wildcard slot: zero or one captured node.
        bn_nodes = tm.get_node(1, is_wildcard=True).nodes
        assert len(bn_nodes) <= 1
        bn = resolve_module(bn_nodes[0], nn.BatchNorm2d) if bn_nodes else None
        act = cast(ActivationLike, resolve_module(tm.get_node(2)))
        fused = FusedConvBNAct(
            conv,
            bn,
            act,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused])
class ConvBNPattern(Pattern):
    """Fuse a ``Conv2d → [BatchNorm2d]`` chain with no trailing activation.

    The ``BatchNorm2d`` stage may be absent.
    """

    tree: Tree = (nn.Conv2d, _OPTIONAL_BN)
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find every occurrence of the chain in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap one matched chain for a single :class:`FusedConvBN`."""
        assert pattern_match.pattern is self
        tm = pattern_match.tree_match
        conv = resolve_module(tm.get_node(0), nn.Conv2d)
        # The optional BN is a wildcard slot: zero or one captured node.
        bn_nodes = tm.get_node(1, is_wildcard=True).nodes
        assert len(bn_nodes) <= 1
        bn = resolve_module(bn_nodes[0], nn.BatchNorm2d) if bn_nodes else None
        fused = FusedConvBN(
            conv,
            bn,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused])
class StemConvBNActMaxPoolPattern(Pattern):
    """Fuse ``Conv2d(3in, 7×7) → [BatchNorm2d] → Activation → MaxPool2d``.

    Captures the common classification-network stem.  The convolution slot
    uses :func:`_is_stem_conv`, which restricts matches to
    ``in_channels == 3, kernel_size == (7, 7)`` — so only the real stem is
    fused, never an arbitrary ``Conv→Act→Pool`` chain.  The ``BatchNorm2d``
    stage may be absent.
    """

    tree: Tree = (
        _is_stem_conv,
        _OPTIONAL_BN,
        ActivationLike,
        nn.MaxPool2d,
    )
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find every occurrence of the stem chain in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap one matched chain for a :class:`FusedConvBNActMaxPool`."""
        assert pattern_match.pattern is self
        tm = pattern_match.tree_match
        conv = resolve_module(tm.get_node(0), nn.Conv2d)
        # The optional BN is a wildcard slot: zero or one captured node.
        bn_nodes = tm.get_node(1, is_wildcard=True).nodes
        assert len(bn_nodes) <= 1
        bn = resolve_module(bn_nodes[0], nn.BatchNorm2d) if bn_nodes else None
        act = cast(ActivationLike, resolve_module(tm.get_node(2)))
        pool = resolve_module(tm.get_node(3), nn.MaxPool2d)
        fused = FusedConvBNActMaxPool(
            conv,
            bn,
            act,
            pool,
            bn_foldable=bn is not None,
        )
        return replace_tree(pattern_match, [fused])
class ConvBNAddActPattern(Pattern):
    """Match ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    Captures the tail of ResNet-style bottleneck blocks where the
    convolution path is element-wise added to a skip connection before the
    final activation.
    """

    # Fork: two incoming branches into ``operator.add`` — the Conv+BN path
    # and an empty branch (the residual/skip input) — followed by the
    # activation after the add.
    tree: Tree = Fork(
        (
            (nn.Conv2d, nn.BatchNorm2d),
            (),
        ),
        operator.add,
        (ActivationLike,),
    )
    # Both the convolution input and the residual input receive Q/DQ nodes.
    qdq_points = frozenset({QDQPoint.INPUT, QDQPoint.RESIDUAL_INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Find every occurrence of the forked chain in *graph_module*."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap one matched fork for a single :class:`FusedConvBNAddAct`."""
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        # Two-index addressing selects nodes inside the first fork branch.
        conv = resolve_module(tree_match.get_node(0, 0), nn.Conv2d)
        bn = resolve_module(tree_match.get_node(0, 1), nn.BatchNorm2d)
        # NOTE(review): sibling patterns fetch the activation with its
        # positional tree index (e.g. ``get_node(2)``), whereas here
        # ``get_node(0)`` is used while ``get_node(0, 0)``/``(0, 1)`` address
        # the fork branch.  This may be the Fork-tree convention, but confirm
        # the index against the ``TreeMatch.get_node`` API.
        act = cast(ActivationLike, resolve_module(tree_match.get_node(0)))
        # BN is always present in this pattern, so it is always foldable.
        fused_module = FusedConvBNAddAct(
            conv,
            bn,
            act,
            bn_foldable=True,
        )
        return replace_tree(pattern_match, [fused_module])