Source code for embedl_deploy._internal.tensorrt.patterns.fusions.pool
# Copyright (C) 2026 Embedl AB
"""Fusion ``Pattern`` subclasses for TensorRT.
Each class wraps the core matching/replacement utilities and wires them to the
appropriate fused ``nn.Module`` from
:mod:`~embedl_deploy._internal.tensorrt.modules`.
"""
from torch import fx, nn
from embedl_deploy._internal.core.match import (
match_tree,
)
from embedl_deploy._internal.core.pattern import (
Pattern,
PatternMatch,
QDQPoint,
Tree,
resolve_module,
)
from embedl_deploy._internal.core.replace import (
replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.pool import (
FusedAdaptiveAvgPool2d,
)
[docs]
class AdaptiveAvgPoolPattern(Pattern):
    """Wrap a standalone ``AdaptiveAvgPool2d`` in a fused module.

    Nothing is actually merged here: the point of the wrapper is to turn the
    pool into a module the Q/DQ insertion pass recognizes, so quantize /
    dequantize stubs can be placed on its input and output.
    """

    tree: Tree = (nn.AdaptiveAvgPool2d,)
    # Request Q/DQ stubs on both sides of the wrapped pool.
    qdq_points = frozenset({QDQPoint.INPUT, QDQPoint.OUTPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Locate occurrences of this pattern in *graph_module*.

        Matching is delegated entirely to :func:`match_tree`, with this
        instance supplied as the pattern.
        """
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Rewrite one match so the pool is wrapped in FusedAdaptiveAvgPool2d.

        Args:
            pattern_match: A match previously produced by this very pattern
                instance.

        Returns:
            The node list that :func:`replace_tree` returns for the rewrite.
        """
        # Guard against being handed a match that belongs to another pattern.
        assert pattern_match.pattern is self
        # The tree holds a single element, so node 0 is the pool itself.
        pool_node = pattern_match.tree_match.get_node(0)
        original_pool = resolve_module(pool_node, nn.AdaptiveAvgPool2d)
        return replace_tree(
            pattern_match,
            [FusedAdaptiveAvgPool2d(original_pool)],
        )