# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.attention
# Copyright (C) 2026 Embedl AB
"""Fusion ``Pattern`` subclasses for attention sub-modules."""
from torch import fx
from embedl_deploy._internal.core.match import (
match_tree,
)
from embedl_deploy._internal.core.pattern import (
Pattern,
PatternMatch,
QDQPoint,
Tree,
resolve_module,
)
from embedl_deploy._internal.core.replace import (
replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.attention import (
FusedMHAInProjection,
FusedScaledDotProductAttention,
MHAInProjection,
ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
FusedSwinAttention,
SwinAttention,
)
class MHAInProjectionPattern(Pattern):
    """Match ``MHAInProjection`` and wrap it in a fused module.

    Finds every ``MHAInProjection`` call in the traced graph and swaps it
    for a ``FusedMHAInProjection`` built from the matched module.
    """

    # Single-node pattern: the projection module itself is the whole tree.
    tree: Tree = (MHAInProjection,)
    # Quantize/dequantize is inserted only on the module's input.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of this pattern in ``graph_module``."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched projection with its fused wrapper.

        Args:
            pattern_match: A match produced by :meth:`match`; must belong
                to this pattern instance.

        Returns:
            The new graph nodes created by the replacement.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        # Node 0 is the root (and only) node of the single-module tree.
        in_proj = resolve_module(tree_match.get_node(0), MHAInProjection)
        fused_module = FusedMHAInProjection(in_proj)
        return replace_tree(pattern_match, [fused_module])
class ScaledDotProductAttentionPattern(Pattern):
    """Match ``ScaledDotProductAttention`` and wrap it in a fused module.

    Finds every ``ScaledDotProductAttention`` call in the traced graph and
    swaps it for a ``FusedScaledDotProductAttention`` built from the
    matched module.
    """

    # Single-node pattern: the attention module itself is the whole tree.
    tree: Tree = (ScaledDotProductAttention,)
    # Quantize/dequantize is inserted on the query, key, and value inputs.
    qdq_points = frozenset(
        {
            QDQPoint.INPUT,
            QDQPoint.KEY_INPUT,
            QDQPoint.VALUE_INPUT,
        }
    )

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of this pattern in ``graph_module``."""
        return match_tree(
            graph_module,
            pattern=self,
        )

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched attention module with its fused wrapper.

        Args:
            pattern_match: A match produced by :meth:`match`; must belong
                to this pattern instance.

        Returns:
            The new graph nodes created by the replacement.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        # Node 0 is the root (and only) node of the single-module tree.
        attention = resolve_module(
            tree_match.get_node(0), ScaledDotProductAttention
        )
        fused_module = FusedScaledDotProductAttention(attention)
        return replace_tree(pattern_match, [fused_module])
class SwinAttentionPattern(Pattern):
    """Match ``SwinAttention`` and wrap it in a fused module.

    Finds every ``SwinAttention`` call in the traced graph and swaps it
    for a ``FusedSwinAttention`` built from the matched module.
    """

    # Single-node pattern: the Swin attention module is the whole tree.
    tree: Tree = (SwinAttention,)
    # Quantize/dequantize is inserted on the query, key, and value inputs.
    qdq_points = frozenset(
        {
            QDQPoint.INPUT,
            QDQPoint.KEY_INPUT,
            QDQPoint.VALUE_INPUT,
        }
    )

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of this pattern in ``graph_module``."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched Swin attention module with its fused wrapper.

        Args:
            pattern_match: A match produced by :meth:`match`; must belong
                to this pattern instance.

        Returns:
            The new graph nodes created by the replacement.
        """
        assert pattern_match.pattern is self
        tree_match = pattern_match.tree_match
        # Node 0 is the root (and only) node of the single-module tree.
        attention = resolve_module(
            tree_match.get_node(0),
            SwinAttention,
        )
        fused_module = FusedSwinAttention(attention)
        return replace_tree(pattern_match, [fused_module])