# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.attention

# Copyright (C) 2026 Embedl AB

"""Fusion ``Pattern`` subclasses for attention sub-modules."""

from torch import fx

from embedl_deploy._internal.core.match import (
    match_tree,
)
from embedl_deploy._internal.core.pattern import (
    Pattern,
    PatternMatch,
    QDQPoint,
    Tree,
    resolve_module,
)
from embedl_deploy._internal.core.replace import (
    replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.attention import (
    FusedMHAInProjection,
    FusedScaledDotProductAttention,
    MHAInProjection,
    ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
    FusedSwinAttention,
    SwinAttention,
)


class MHAInProjectionPattern(Pattern):
    """Match ``MHAInProjection`` and wrap it in a fused module."""

    # Single-node tree: any MHAInProjection submodule is a candidate.
    tree: Tree = (MHAInProjection,)
    # Quantize/dequantize is inserted on the module input only.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return every occurrence of this pattern in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap the matched projection for its fused counterpart."""
        assert pattern_match.pattern is self
        matched_node = pattern_match.tree_match.get_node(0)
        projection = resolve_module(matched_node, MHAInProjection)
        return replace_tree(pattern_match, [FusedMHAInProjection(projection)])
class ScaledDotProductAttentionPattern(Pattern):
    """Match ``ScaledDotProductAttention`` and wrap it in a fused module."""

    # Single-node tree: any ScaledDotProductAttention submodule matches.
    tree: Tree = (ScaledDotProductAttention,)
    # Attention takes three quantizable streams: query, key, and value.
    qdq_points = frozenset(
        {QDQPoint.INPUT, QDQPoint.KEY_INPUT, QDQPoint.VALUE_INPUT}
    )

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return every occurrence of this pattern in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap the matched attention for its fused counterpart."""
        assert pattern_match.pattern is self
        matched_node = pattern_match.tree_match.get_node(0)
        attention_module = resolve_module(
            matched_node, ScaledDotProductAttention
        )
        fused = FusedScaledDotProductAttention(attention_module)
        return replace_tree(pattern_match, [fused])
class SwinAttentionPattern(Pattern):
    """Match ``SwinAttention`` and wrap it in a fused module."""

    # Single-node tree: any SwinAttention submodule matches.
    tree: Tree = (SwinAttention,)
    # Attention takes three quantizable streams: query, key, and value.
    qdq_points = frozenset(
        {QDQPoint.INPUT, QDQPoint.KEY_INPUT, QDQPoint.VALUE_INPUT}
    )

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return every occurrence of this pattern in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap the matched Swin attention for its fused counterpart."""
        assert pattern_match.pattern is self
        matched_node = pattern_match.tree_match.get_node(0)
        attention_module = resolve_module(matched_node, SwinAttention)
        fused = FusedSwinAttention(attention_module)
        return replace_tree(pattern_match, [fused])