# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.pool

# Copyright (C) 2026 Embedl AB

"""Fusion ``Pattern`` subclasses for TensorRT.

Each class wraps the core matching/replacement utilities and wires them to the
appropriate fused ``nn.Module`` from
:mod:`~embedl_deploy._internal.tensorrt.modules`.
"""

from torch import fx, nn

from embedl_deploy._internal.core.match import (
    match_tree,
)
from embedl_deploy._internal.core.pattern import (
    Pattern,
    PatternMatch,
    QDQPoint,
    Tree,
    resolve_module,
)
from embedl_deploy._internal.core.replace import (
    replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.pool import (
    FusedAdaptiveAvgPool2d,
)


class AdaptiveAvgPoolPattern(Pattern):
    """Wrap a lone ``AdaptiveAvgPool2d`` in its fused counterpart.

    No real fusion happens here: the pattern tree holds a single
    ``nn.AdaptiveAvgPool2d``.  Swapping it for a recognized fused module
    gives the Q/DQ insertion pass a place to put quantize / dequantize
    stubs around the pool.
    """

    # Single-node tree: the pattern is just the pooling layer itself.
    tree: Tree = (nn.AdaptiveAvgPool2d,)
    # Q/DQ points are declared on both the input and the output side.
    qdq_points = frozenset({QDQPoint.INPUT, QDQPoint.OUTPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return the matches of this pattern in *graph_module*.

        Matching is delegated entirely to :func:`match_tree`, called with
        this instance as the pattern.
        """
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Substitute the matched pool with a ``FusedAdaptiveAvgPool2d``.

        *pattern_match* must have been produced by this very pattern
        instance.  The module behind node ``0`` of the tree match is
        resolved as an ``nn.AdaptiveAvgPool2d``, wrapped in the fused
        module, and handed to :func:`replace_tree`, whose result is
        returned unchanged.
        """
        # NOTE: this identity guard is an ``assert`` and is therefore
        # skipped when Python runs with ``-O``.
        assert pattern_match.pattern is self

        pool_node = pattern_match.tree_match.get_node(0)
        original_pool = resolve_module(pool_node, nn.AdaptiveAvgPool2d)
        return replace_tree(
            pattern_match,
            [FusedAdaptiveAvgPool2d(original_pool)],
        )