# Copyright (C) 2026 Embedl AB
"""Attention-related conversion patterns.
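
Decomposes ``nn.MultiheadAttention`` and ``shifted_window_attention`` into
explicit sub-modules, composes raw ``matmul → scale → softmax → matmul``
chains and parallel Q/K/V ``nn.Linear`` projections, and wraps functional
``F.scaled_dot_product_attention`` calls. Downstream fusion and Q/DQ passes
can then match the resulting modules from
:mod:`~embedl_deploy._internal.tensorrt.modules.attention` and
:mod:`~embedl_deploy._internal.tensorrt.modules.swin_attention`.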
"""
import logging
import math
import operator
from typing import TypeVar, cast

import torch
import torch.nn.functional as F
from torch import fx, nn
from embedl_deploy._internal.core.patterns.main import (
Pattern,
PatternMatch,
Phase,
)
from embedl_deploy._internal.core.patterns.utils import is_list_of
from embedl_deploy._internal.core.tree.match import node_check
from embedl_deploy._internal.core.tree.replace import get_auto_name
from embedl_deploy._internal.core.tree.state import (
SharedNodeCheck,
get_replaced_nodes,
)
from embedl_deploy._internal.core.tree.types import (
Fork,
Graft,
NodeInserter,
Tree,
TreeMatch,
Wildcard,
)
from embedl_deploy._internal.core.tree.utils import get_module, resolve_module
from embedl_deploy._internal.tensorrt.modules.attention import (
MHAInProjection,
ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
SwinAttention,
SwinSpatialState,
SwinWindowPartition,
SwinWindowReverse,
)
try:
    from torchvision.models.swin_transformer import (
        shifted_window_attention,
    )
except ModuleNotFoundError:  # pragma: no cover
    # torchvision is optional: without it the Swin decomposition pattern
    # never matches (``_is_shifted_window_attention`` returns ``False``).
    _TORCHVISION_AVAILABLE = False
else:
    _TORCHVISION_AVAILABLE = True

#: Module logger; unsupported attention configurations are reported here.
_LOG = logging.getLogger(__name__)
def _is_supported_mha(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is a supported ``MultiheadAttention``.

    A supported MHA has ``_qkv_same_embed_dim=True`` and
    ``batch_first=True``, is called without masks (and without extra
    positional arguments), and uses neither ``add_zero_attn`` nor
    ``add_bias_kv``. Every unsupported property found is logged as a
    warning and the node is skipped.

    ``add_bias_kv`` must be rejected because :func:`_decompose_mha` only
    carries over ``in_proj_weight``/``in_proj_bias``; the learned
    ``bias_k``/``bias_v`` that MHA appends to the key/value sequences
    would otherwise be silently dropped.

    Args:
        node: Candidate FX node.

    Returns:
        ``True`` if the node can be decomposed, ``False`` otherwise.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.MultiheadAttention):
        return False
    warnings: list[str] = []
    if not getattr(mod, '_qkv_same_embed_dim'):
        warnings.append(
            f"Skipping MHA at {node.target}: cross-attention "
            f"(_qkv_same_embed_dim=False) not supported"
        )
    if not mod.batch_first:
        warnings.append(
            f"Skipping MHA at {node.target}: batch_first=False not supported"
        )
    if mod.add_zero_attn:
        warnings.append(
            f"Skipping MHA at {node.target}: add_zero_attn=True not supported"
        )
    if mod.bias_k is not None or mod.bias_v is not None:
        warnings.append(
            f"Skipping MHA at {node.target}: add_bias_kv=True not supported"
        )
    # Positional args beyond (query, key, value) may be masks.
    if len(node.args) > 3:
        warnings.append(
            f"Skipping MHA at {node.target}: extra positional args "
            f"(possible mask arguments)"
        )
    for mask_key in ("attn_mask", "key_padding_mask"):
        val = node.kwargs.get(mask_key)
        if val is not None:
            warnings.append(
                f"Skipping MHA at {node.target}: {mask_key} is not None"
            )
    if warnings:
        for msg in warnings:
            _LOG.warning("%s", msg)
        return False
    return True
def _decompose_mha(
    tree_match: TreeMatch,
) -> tuple[MHAInProjection, NodeInserter, nn.Linear]:
    """Break a matched ``nn.MultiheadAttention`` into three graft pieces.

    Returns ``(in_proj, attention_insert, out_proj)``:

    * ``in_proj`` wraps a fresh ``nn.Linear(embed_dim, 3 * embed_dim)``
      whose parameters are the MHA's own ``in_proj_weight`` (and
      ``in_proj_bias`` when present), shared rather than copied;
    * ``attention_insert`` is the node inserter from
      :func:`_get_attention_insert` around a new
      ``ScaledDotProductAttention`` carrying the MHA's dropout;
    * ``out_proj`` is the MHA's existing output projection module.
    """
    mha: nn.MultiheadAttention = resolve_module(
        tree_match.get_node(0), nn.MultiheadAttention
    )
    width = mha.embed_dim
    heads = mha.num_heads
    per_head = width // heads
    has_bias = mha.in_proj_bias is not None

    packed = nn.Linear(width, 3 * width, bias=has_bias)
    packed.weight = mha.in_proj_weight
    if has_bias:
        packed.bias = mha.in_proj_bias

    projection = MHAInProjection(packed, heads, per_head)
    sdpa = ScaledDotProductAttention(heads, per_head, dropout=mha.dropout)
    return projection, _get_attention_insert(sdpa), mha.out_proj
def _unwrap_mha_getitems(
    mha_node: fx.Node,
) -> None:
    """Turn the tuple-returning MHA node into a single-output node.

    ``nn.MultiheadAttention`` yields ``(attn_output, attn_weights)``. Every
    ``operator.getitem`` user of the MHA node (or of its replacement, if
    ``get_replaced_nodes()`` has one recorded) is rewritten:

    * ``getitem(mha, 0)``: its uses are redirected to the MHA node itself;
    * ``getitem(mha, 1)``: its uses receive a new node evaluating to
      ``None`` (``getitem((None,), 0)``).

    The rewritten getitem nodes are erased. Getitems with any other index,
    and non-getitem users, are left untouched. Afterwards the generic
    ``replace_tree`` machinery can treat the node as single-output.
    """
    owner = mha_node.graph.owning_module
    assert isinstance(owner, fx.GraphModule)
    graph = owner.graph
    # Presumably maps nodes rewritten by earlier replacements to their
    # substitutes; fall back to the node itself.
    target = get_replaced_nodes().get(mha_node, mha_node)

    tuple_reads = [
        user
        for user in list(target.users)
        if user.op == "call_function" and user.target is operator.getitem
    ]
    for user in tuple_reads:
        index = user.args[1]
        if index == 0:
            substitute = target
        elif index == 1:
            with graph.inserting_after(target):
                substitute = graph.call_function(
                    operator.getitem, ((None,), 0)
                )
        else:
            continue
        user.replace_all_uses_with(substitute)
        graph.erase_node(user)
def _get_attention_insert(
    attention: ScaledDotProductAttention,
) -> NodeInserter:
    """Return a node inserter that splits Q/K/V and runs `attention`.

    The returned callable registers `attention` on the graph module and
    emits, directly after the single in-projection node in ``prev_args``:

    1. three ``getitem`` nodes selecting Q, K and V (indices 0, 1, 2);
    2. a ``call_module`` node running `attention` on ``(q, k, v)``;
    3. ``transpose(1, 2)`` followed by ``flatten(2)``.

    The attention module returns a 4-D ``[B, num_heads, S, head_dim]``
    tensor. The transpose/flatten chain turns it into
    ``[B, S, num_heads * head_dim]`` (i.e. ``[B, S, embed_dim]``), the
    layout expected by the downstream out-projection. All six new nodes
    are returned in graph order.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        name = get_auto_name(graph_module, attention)
        graph_module.add_module(name, attention)
        (proj_node,) = prev_args
        graph = graph_module.graph
        qkv: list[fx.Node] = []
        anchor = proj_node
        for idx in range(3):
            with graph.inserting_after(anchor):
                anchor = graph.call_function(
                    operator.getitem, (proj_node, idx)
                )
            qkv.append(anchor)
        with graph.inserting_after(anchor):
            attn = graph.call_module(name, tuple(qkv))
        with graph.inserting_after(attn):
            attn_t = graph.call_method("transpose", (attn, 1, 2))
        with graph.inserting_after(attn_t):
            # flatten(2) merges (num_heads, head_dim) back into embed_dim.
            attn_flat = graph.call_method("flatten", (attn_t, 2))
        return [*qkv, attn, attn_t, attn_flat]

    return _insert
class DecomposeMultiheadAttentionPattern(Pattern):
    """Decompose ``nn.MultiheadAttention`` into explicit sub-modules.

    Replaces each ``MultiheadAttention`` node with three sub-modules visible in
    the FX graph:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    3. ``nn.Linear`` (output projection, reused from the original MHA)

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks or ``add_zero_attn`` is supported; the complete list of
    checks lives in ``_is_supported_mha``. Unsupported configurations are
    skipped with a warning.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_supported_mha,)
    graft: Graft = (_decompose_mha,)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Flatten the MHA's ``(output, weights)`` tuple, then graft.

        ``getitem`` users of the matched MHA node are rewired first (via
        ``_unwrap_mha_getitems``) so the base-class replacement sees a
        single-output node.
        """
        _unwrap_mha_getitems(pattern_match.tree_match.get_node(0))
        return super().replace(pattern_match)
# -- Swin Transformer attention decomposition -------------------------
def _is_shifted_window_attention(node: fx.Node) -> bool:
    """Return ``True`` for a supported ``shifted_window_attention`` call.

    Only Swin V1 is supported. V2 calls pass a non-``None``
    ``logit_scale``, either as a keyword or in positional slot 11 of
    torchvision's signature. Such calls are rejected with a warning.
    Returns ``False`` when *torchvision* is not installed.
    """
    if not _TORCHVISION_AVAILABLE:
        return False
    if node.op != "call_function":
        return False
    if node.target is not shifted_window_attention:
        return False
    logit_scale = node.kwargs.get("logit_scale")
    if logit_scale is None and len(node.args) > 11:
        # ``logit_scale`` is the 12th parameter of torchvision's
        # ``shifted_window_attention``. Without this, a fully positional
        # V2-style call would slip past the keyword check.
        logit_scale = node.args[11]
    if logit_scale is not None:
        _LOG.warning(
            "Skipping %s: Swin V2 (logit_scale) not supported",
            node.name,
        )
        return False
    return True
def _resolve_get_attr(
graph_module: fx.GraphModule,
node: fx.Node,
) -> torch.Tensor | nn.Parameter:
"""Follow a ``get_attr`` node to the actual tensor."""
assert isinstance(node.target, str)
obj: object = graph_module
for part in node.target.split("."):
obj = getattr(obj, part)
assert isinstance(obj, torch.Tensor)
return obj
def _decompose_swin_attention(  # pylint: disable=too-many-locals
    swa_node: fx.Node,
    graph_module: fx.GraphModule,
) -> tuple[
    SwinWindowPartition,
    MHAInProjection,
    SwinAttention,
    nn.Linear,
    SwinWindowReverse,
]:
    """Split a ``shifted_window_attention`` call into five sub-modules.

    Extracts weights and parameters from the ``get_attr`` nodes in the
    function call's arguments, since the ``fx.GraphModule`` stores
    traced-through modules as generic ``nn.Module`` wrappers.

    Arguments follow torchvision's signature ``shifted_window_attention(
    input, qkv_weight, proj_weight, relative_position_bias, window_size,
    num_heads, shift_size, attention_dropout=0.0, dropout=0.0,
    qkv_bias=None, proj_bias=None, logit_scale=None, training=True)``.
    Each is read from its keyword if present, otherwise from its
    positional slot, otherwise from the signature default. This covers
    both the keyword-heavy call traced from ``ShiftedWindowAttention`` and
    a positional direct call. ``attention_dropout`` may be an ``int``.
    The output ``dropout`` and ``training`` arguments are not carried
    over to the new modules.

    Returns:
        ``(window_partition, in_projection, attention, out_projection,
        window_reverse)``. The partition, attention and reverse modules
        share one ``SwinSpatialState`` instance.
    """

    def _arg(index: int, name: str, default: object = None) -> object:
        """Return argument `name` by keyword, position `index`, or default."""
        if name in swa_node.kwargs:
            return swa_node.kwargs[name]
        if len(swa_node.args) > index:
            return swa_node.args[index]
        return default

    qkv_weight_node = _arg(1, "qkv_weight")
    assert isinstance(qkv_weight_node, fx.Node)
    qkv_weight = _resolve_get_attr(graph_module, qkv_weight_node)
    proj_weight_node = _arg(2, "proj_weight")
    assert isinstance(proj_weight_node, fx.Node)
    proj_weight = _resolve_get_attr(graph_module, proj_weight_node)
    window_size = _arg(4, "window_size")
    assert is_list_of(window_size, int)
    num_heads = _arg(5, "num_heads")
    assert isinstance(num_heads, int)
    shift_size = _arg(6, "shift_size")
    assert is_list_of(shift_size, int)
    raw_dropout = _arg(7, "attention_dropout", 0.0)
    # Accept an int probability such as ``0``, but not a bool.
    assert isinstance(raw_dropout, (int, float)) and not isinstance(
        raw_dropout, bool
    )
    attention_dropout = float(raw_dropout)
    qkv_bias_node = _arg(9, "qkv_bias")
    proj_bias_node = _arg(10, "proj_bias")

    # qkv_weight is [3 * embed_dim, embed_dim]; dim 1 is the input width.
    embed_dim = qkv_weight.shape[1]
    head_dim = embed_dim // num_heads

    # Relative position bias from the _get_relative_position_bias call.
    rel_bias_node = _arg(3, "relative_position_bias")
    assert isinstance(rel_bias_node, fx.Node)
    assert isinstance(rel_bias_node.args[0], fx.Node)
    bias_table = _resolve_get_attr(graph_module, rel_bias_node.args[0])
    assert isinstance(rel_bias_node.args[1], fx.Node)
    bias_index = _resolve_get_attr(graph_module, rel_bias_node.args[1])

    def _rebuild_linear(
        weight: torch.Tensor,
        bias_node: object,
        out_features: int,
    ) -> nn.Linear:
        """Wrap `weight` (and the bias behind `bias_node`) in an nn.Linear."""
        bias = None
        if isinstance(bias_node, fx.Node):
            bias = nn.Parameter(_resolve_get_attr(graph_module, bias_node))
        linear = nn.Linear(embed_dim, out_features, bias=bias is not None)
        linear.weight = nn.Parameter(weight)
        if bias is not None:
            linear.bias = bias
        return linear

    qkv_linear = _rebuild_linear(qkv_weight, qkv_bias_node, 3 * embed_dim)
    out_proj = _rebuild_linear(proj_weight, proj_bias_node, embed_dim)

    spatial_state = SwinSpatialState()
    win_part = SwinWindowPartition(
        window_size,
        shift_size,
        spatial_state,
    )
    in_proj = MHAInProjection(qkv_linear, num_heads, head_dim)
    attention = SwinAttention(
        relative_position_bias_table=nn.Parameter(bias_table),
        relative_position_index=bias_index,
        window_size=window_size,
        num_heads=num_heads,
        head_dim=head_dim,
        attention_dropout=attention_dropout,
        spatial_state=spatial_state,
    )
    win_rev = SwinWindowReverse(window_size, spatial_state)
    return win_part, in_proj, attention, out_proj, win_rev
def _get_swin_replace(
    tree_match: TreeMatch,
) -> tuple[NodeInserter]:
    """Build the single node inserter for a Swin attention decomposition.

    The five sub-modules are created eagerly from the matched
    ``shifted_window_attention`` node. The returned inserter registers
    each of them on the target graph module and wires
    ``input → partition → in-proj → getitem ×3 → attention → out-proj →
    reverse``, returning the eight new nodes in graph order.
    """
    swa_node = tree_match.get_node(0)
    owner = swa_node.graph.owning_module
    assert isinstance(owner, fx.GraphModule)
    win_part, in_proj, attention, out_proj, win_rev = (
        _decompose_swin_attention(swa_node, owner)
    )

    def _insert(  # pylint: disable=too-many-locals
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        # prev_args holds every fx.Node positional arg of the original
        # call. Only the first (the attention input) is consumed; the
        # weight get_attrs and the relative-bias node are superseded by
        # parameters baked into the new modules.
        graph = graph_module.graph

        def _add_call(
            module: nn.Module,
            anchor: fx.Node,
            call_args: tuple[fx.Node, ...],
        ) -> fx.Node:
            """Register `module` and call it right after `anchor`."""
            mod_name = get_auto_name(graph_module, module)
            graph_module.add_module(mod_name, module)
            with graph.inserting_after(anchor):
                return graph.call_module(mod_name, call_args)

        source = prev_args[0]
        partitioned = _add_call(win_part, source, (source,))
        # Self-attention: the same windows feed Q, K and V.
        projected = _add_call(
            in_proj,
            partitioned,
            (partitioned, partitioned, partitioned),
        )
        qkv: list[fx.Node] = []
        anchor = projected
        for index in range(3):
            with graph.inserting_after(anchor):
                anchor = graph.call_function(
                    operator.getitem,
                    (projected, index),
                )
            qkv.append(anchor)
        attended = _add_call(attention, anchor, tuple(qkv))
        merged = _add_call(out_proj, attended, (attended,))
        restored = _add_call(win_rev, merged, (merged,))
        return [partitioned, projected, *qkv, attended, merged, restored]

    return (_insert,)
class DecomposeSwinAttentionPattern(Pattern):
    """Decompose ``shifted_window_attention`` into explicit sub-modules.

    Matches the ``call_function`` node produced when
    ``torchvision.models.swin_transformer.ShiftedWindowAttention`` is
    FX-traced (``shifted_window_attention`` is ``torch.fx.wrap``-ped).
    Replaces each node with five sub-modules, plus three ``getitem`` nodes
    that split the in-projection output into Q, K and V:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowPartition`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
       (the class shared with the MHA decomposition), built around the
       original ``qkv`` weight and bias
    3. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinAttention`
    4. ``nn.Linear`` output projection, rebuilt around the original
       ``proj`` weight and bias
    5. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowReverse`

    Only Swin V1 (``ShiftedWindowAttention``) is supported. V2 calls,
    which pass a ``logit_scale``, are skipped.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_shifted_window_attention,)
    graft: Graft = (_get_swin_replace,)
# -- Compose raw SDPA operations ----------------------------------------
def _is_transpose_last_two(node: fx.Node) -> bool:
    """Return ``True`` for a transpose of the two trailing dimensions.

    ``tensor.transpose(-1, -2)`` (in either order) is always accepted.
    When the node carries ``tensor_meta`` shape metadata, the dimensions
    are normalised against the tensor rank, so the positive spelling of
    the same swap (e.g. ``transpose(2, 3)`` on a 4-D tensor) is accepted
    too. Dimensions may be given positionally or as ``dim0``/``dim1``.
    """
    if node.op != "call_method" or node.target != "transpose":
        return False
    dims = [a for a in node.args if not isinstance(a, fx.Node)]
    dims += [node.kwargs[key] for key in ("dim0", "dim1") if key in node.kwargs]
    if len(dims) != 2 or not all(isinstance(d, int) for d in dims):
        return False
    if set(dims) == {-1, -2}:
        return True
    meta = node.meta.get("tensor_meta")
    rank = len(getattr(meta, "shape", ()))
    if rank < 2:
        return False
    normalised = {d % rank for d in dims if -rank <= d < rank}
    return normalised == {rank - 2, rank - 1}
@node_check
def _is_matmul(node: fx.Node) -> bool:
    """Return ``True`` for a matrix multiplication call.

    Accepts ``torch.matmul(a, b)`` as well as ``a @ b``, which FX symbolic
    tracing records as a ``call_function`` on ``operator.matmul``.
    """
    return node.op == "call_function" and node.target in (
        torch.matmul,
        operator.matmul,
    )
def _is_scale(node: fx.Node) -> bool:
    """Return ``True`` for a scalar rescale ``x * c``, ``c * x`` or ``x / c``.

    Exactly one operand must be a graph node (the tensor) and the other a
    Python number. For division the tensor must be the numerator:
    ``c / x`` is an element-wise reciprocal, not a scaling, and must not
    be folded into scaled dot-product attention.
    """
    if node.op != "call_function" or len(node.args) != 2:
        return False
    lhs, rhs = node.args

    def _is_number(value: object) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if node.target is operator.truediv:
        return isinstance(lhs, fx.Node) and _is_number(rhs)
    if node.target is operator.mul:
        return (isinstance(lhs, fx.Node) and _is_number(rhs)) or (
            isinstance(rhs, fx.Node) and _is_number(lhs)
        )
    return False
def _is_softmax(node: fx.Node) -> bool:
    """Return ``True`` for a softmax over the last dimension.

    Accepts ``F.softmax`` / ``torch.softmax`` calls and the
    ``tensor.softmax`` method form. Scaled dot-product attention normalises
    over the key axis, i.e. the last one, so ``dim`` must be ``-1``, or
    ``rank - 1`` when ``tensor_meta`` shape metadata is available. A
    missing ``dim`` is rejected: ``F.softmax``'s implicit dimension is
    not the last axis for 4-D inputs.
    """
    if node.op == "call_function":
        if node.target not in {F.softmax, torch.softmax}:
            return False
    elif node.op != "call_method" or node.target != "softmax":
        return False
    # For both the function and the method form, dim is argument 1.
    if "dim" in node.kwargs:
        dim = node.kwargs["dim"]
    elif len(node.args) > 1:
        dim = node.args[1]
    else:
        return False
    if not isinstance(dim, int) or isinstance(dim, bool):
        return False
    if dim == -1:
        return True
    meta = node.meta.get("tensor_meta")
    rank = len(getattr(meta, "shape", ()))
    return rank > 0 and dim == rank - 1
def _is_dropout(node: fx.Node) -> bool:
    """Return ``True`` when `node` resolves to an ``nn.Dropout`` module."""
    module = get_module(node)
    return isinstance(module, nn.Dropout)
def _is_transpose_0213(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor.permute(0, 2, 1, 3)``.

    The permutation may be written as varargs (``permute(0, 2, 1, 3)``),
    as one sequence (``permute((0, 2, 1, 3))``), or through the ``dims``
    keyword. All spellings describe the same permutation.
    """
    if node.op != "call_method" or node.target != "permute":
        return False
    dims: list[object] = [a for a in node.args if not isinstance(a, fx.Node)]
    if "dims" in node.kwargs:
        dims.append(node.kwargs["dims"])
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        dims = list(dims[0])
    return tuple(dims) == (0, 2, 1, 3)
@node_check
def _is_view_or_reshape(node: fx.Node) -> bool:
    """Return ``True`` for a ``.view(...)`` or ``.reshape(...)`` method call."""
    if node.op != "call_method":
        return False
    return node.target in {"view", "reshape"}
@node_check
def _is_attention_matmul(node: fx.Node) -> bool:
    """Return ``True`` for the outer ``matmul(attn_probs, V)``.

    Accepts ``torch.matmul`` as well as the ``@`` operator (recorded by FX
    as ``operator.matmul``). Requires 4-D ``tensor_meta`` shape metadata
    so ``num_heads`` and ``head_dim`` can be read from the output shape
    in :func:`_make_sdpa`.
    """
    if node.op != "call_function" or node.target not in (
        torch.matmul,
        operator.matmul,
    ):
        return False
    meta = node.meta.get("tensor_meta")
    return meta is not None and len(meta.shape) == 4
def _scale_constant(node: fx.Node) -> float | None:
    """Return the multiplier a scale node applies, or ``None`` if unknown.

    ``x * c`` and ``c * x`` yield ``c``; ``x / c`` yields ``1 / c``. Any
    other form (including ``c / x`` or a zero divisor) is not a plain
    rescale and yields ``None``.
    """
    if node.op != "call_function" or len(node.args) != 2:
        return None
    lhs, rhs = node.args

    def _number(value: object) -> float | None:
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return None

    if node.target is operator.mul:
        if isinstance(lhs, fx.Node):
            return _number(rhs)
        if isinstance(rhs, fx.Node):
            return _number(lhs)
        return None
    if node.target is operator.truediv and isinstance(lhs, fx.Node):
        divisor = _number(rhs)
        if divisor is not None and divisor != 0.0:
            return 1.0 / divisor
    return None


def _make_sdpa(
    tree_match: TreeMatch,
) -> tuple[ScaledDotProductAttention]:
    """Build a ``ScaledDotProductAttention`` from matched attention ops.

    ``num_heads``/``head_dim`` come from the 4-D output shape of the outer
    matmul. The dropout probability comes from the optional matched
    ``nn.Dropout``. The matched scale constant is honoured. The module's
    default scaling is ``1/sqrt(head_dim)``, so when the model scales by
    a different factor it is passed explicitly; otherwise the module is
    built exactly as before, with the default scale.
    """
    # The pattern's pre-trunk node is matmul(attn_probs, V) with
    # 4-D meta [B, H, S, D] (guaranteed by ``_is_attention_matmul``).
    attn_matmul = tree_match.pre_trunk_nodes[0]
    shape = attn_matmul.meta["tensor_meta"].shape
    num_heads = int(shape[1])
    head_dim = int(shape[3])

    # nested[0] is the inner matmul fork; its trunk is
    # (scale, softmax, dropout?), so trunk[0] is the scale node and
    # trunk[2] the optional dropout wildcard.
    scale = None
    factor = _scale_constant(tree_match.get_node(0, 0))
    if (
        factor is not None
        and head_dim > 0
        and not math.isclose(factor, 1.0 / math.sqrt(head_dim), rel_tol=1e-6)
    ):
        scale = factor

    dropout_wild = tree_match.get_node(0, 2, is_wildcard=True)
    dropout = 0.0
    if dropout_wild.nodes:
        drop_mod = resolve_module(
            dropout_wild.nodes[0],
            nn.Dropout,
        )
        dropout = drop_mod.p

    if scale is None:
        return (ScaledDotProductAttention(num_heads, head_dim, dropout),)
    # The raw pattern contains no mask, hence is_causal=False.
    return (
        ScaledDotProductAttention(num_heads, head_dim, dropout, False, scale),
    )
class ComposeScaledDotProductAttentionPattern(Pattern):
    """Compose raw ``matmul/scale/softmax/matmul`` into ``ScaledDotProductAttention``.

    Matches the hand-written form of scaled dot-product attention:

    .. code-block:: text

        Q ───────────────────┐
        K → transpose(-1,-2) ┴→ matmul → scale → softmax → [dropout] ─┐
        V ─────────────────────────────────────────────────────────────┴→ matmul

    The chain is replaced by a single
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    module that consumes ``(Q, K, V)`` and produces
    ``[B, num_heads, S, head_dim]``. A trailing
    ``permute → contiguous → view`` that merges the heads is not part of
    the match; it stays in the graph and reads the module's 4-D output.
    """

    phase = Phase.CONVERSION

    # Outer matmul(attn_probs, V). Its first input is the inner
    # matmul(Q, K^T) → scale → softmax → optional dropout chain; its
    # second input is V.
    tree: Tree = Fork(
        operator=_is_attention_matmul,
        inputs=(
            Fork(
                operator=_is_matmul,
                inputs=((), (_is_transpose_last_two,)),
                output=(_is_scale, _is_softmax, Wildcard(_is_dropout, "?")),
            ),
            (),
        ),
        output=(),
    )
    graft: Graft = (_make_sdpa,)
# -- Compose parallel linears into MHAInProjection ----------------------
def _is_transpose_1_2(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor.transpose(1, 2)`` (either argument order)."""
    is_transpose = node.op == "call_method" and node.target == "transpose"
    if not is_transpose:
        return False
    literal_dims = {arg for arg in node.args if not isinstance(arg, fx.Node)}
    return literal_dims == {1, 2}
def _is_transpose_1_2_or_permute_0213(node: fx.Node) -> bool:
    """Return ``True`` when `node` swaps axes 1 and 2.

    Both ``transpose(1, 2)`` and the 4-D spelling ``permute(0, 2, 1, 3)``
    are accepted.
    """
    if _is_transpose_1_2(node):
        return True
    return _is_transpose_0213(node)
@node_check
def _is_sdpa_module(node: fx.Node) -> bool:
    """Return ``True`` if `node` invokes a ``ScaledDotProductAttention`` module."""
    module = get_module(node)
    return isinstance(module, ScaledDotProductAttention)
#: Source-tensor check shared by the Q, K and V branches. The same
#: ``SharedNodeCheck`` instance appears in all three branches: the first
#: branch to run records the node it matched, and the other two succeed
#: only on that exact node. This forces all three projections to read one
#: physical tensor.
_parallel_linears_shared_input = SharedNodeCheck(lambda _: True)

#: A single Q/K/V projection branch. The operator is a ``view``/``reshape``
#: whose first input is the ``nn.Linear`` fed from
#: :data:`_parallel_linears_shared_input`. The two remaining ``()`` inputs
#: cover the shape arguments. The output must be a head-ordering swap,
#: either ``transpose(1, 2)`` or ``permute(0, 2, 1, 3)``.
_parallel_linears_branch: Tree = Fork(
    operator=_is_view_or_reshape,
    inputs=(
        (_parallel_linears_shared_input, nn.Linear),
        (),
        (),
    ),
    output=(_is_transpose_1_2_or_permute_0213,),
    perms_override=((0,), (0, 1), (0, 1, 2)),
)
def _get_parallel_linears_insert(
    shared_input: fx.Node,
    in_proj: MHAInProjection,
    sdpa_node: fx.Node,
) -> NodeInserter:
    """Build an inserter that feeds `sdpa_node` from a packed in-projection.

    The inserter registers `in_proj` and calls it on the shared pre-linear
    tensor (or on its recorded replacement), followed by three ``getitem``
    nodes for Q, K and V. It then points the first three positional args
    of `sdpa_node` at those getitems. The old
    ``transpose → view → nn.Linear`` chains are part of the matched tree
    and are erased by the generic ``replace_tree`` machinery once the
    SDPA's args have been rewired.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        # The packed projection reads the shared pre-linear tensor
        # directly, so the matched branch inputs are not needed.
        del prev_args
        source = get_replaced_nodes().get(shared_input, shared_input)
        graph = graph_module.graph
        proj_name = get_auto_name(graph_module, in_proj)
        graph_module.add_module(proj_name, in_proj)
        with graph.inserting_after(source):
            proj_node = graph.call_module(proj_name, (source,) * 3)
        qkv: list[fx.Node] = []
        anchor = proj_node
        for index in range(3):
            with graph.inserting_after(anchor):
                anchor = graph.call_function(
                    operator.getitem,
                    (proj_node, index),
                )
            qkv.append(anchor)
        # Keep positional args after Q/K/V (e.g. ``attn_mask``) so masked
        # SDPA call sites retain their mask.
        sdpa_node.args = (*qkv, *sdpa_node.args[3:])
        return [source, proj_node, *qkv, sdpa_node]

    return _insert
def _branch_linear(tree_match: TreeMatch, branch: int) -> nn.Linear:
    """Resolve the ``nn.Linear`` of Q/K/V arm number `branch` (0, 1 or 2).

    Path ``(branch, 0, 1)`` selects SDPA input `branch`, then the
    ``(shared input, nn.Linear)`` chain feeding that arm's view/reshape,
    then position 1 of that chain: the linear itself.
    """
    node = tree_match.get_node(branch, 0, 1)
    return resolve_module(node, nn.Linear)
def _make_parallel_linears(
    tree_match: TreeMatch,
) -> tuple[NodeInserter]:
    """Pack the three matched Q/K/V linears into an ``MHAInProjection``.

    The packed ``nn.Linear`` stacks the Q, K and V weights (and biases,
    when present) along the output dimension. Its ``in_features`` and
    ``out_features`` come from the matched linears rather than from
    ``num_heads * head_dim``. Models whose input width differs from
    ``num_heads * head_dim`` would otherwise get a packed module whose
    feature attributes disagree with its actual weight shape.
    ``ComposeParallelLinearsPattern`` has already verified that the three
    linears agree in shape and bias presence.
    """
    sdpa_node = tree_match.pre_trunk_nodes[0]
    sdpa_mod = resolve_module(sdpa_node, ScaledDotProductAttention)
    num_heads = sdpa_mod.num_heads
    head_dim = sdpa_mod.head_dim
    shared_input = tree_match.get_node(0, 0, 0)
    linears = tuple(_branch_linear(tree_match, i) for i in range(3))
    q_lin = linears[0]
    has_bias = q_lin.bias is not None
    packed = nn.Linear(
        q_lin.in_features,
        sum(lin.out_features for lin in linears),
        bias=has_bias,
    )
    packed.weight = nn.Parameter(
        torch.cat([lin.weight for lin in linears], dim=0)
    )
    if has_bias:
        packed.bias = nn.Parameter(
            torch.cat([lin.bias for lin in linears], dim=0)
        )
    in_proj = MHAInProjection(packed, num_heads, head_dim)
    return (_get_parallel_linears_insert(shared_input, in_proj, sdpa_node),)
class ComposeParallelLinearsPattern(Pattern):
    """Compose three parallel ``nn.Linear`` Q/K/V into ``MHAInProjection``.

    Matches a
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    node whose three inputs each trace back through
    ``transpose(1, 2)`` (or ``permute(0, 2, 1, 3)``) ``→ view/reshape →
    nn.Linear`` from the same source tensor. The three branches are tied
    to a single source node by a
    :class:`~embedl_deploy._internal.core.tree.state.SharedNodeCheck`
    shared across their data sub-trunks.

    Packs the three separate linear weights into a single
    ``nn.Linear(in_features, 3 * out_features)`` and wraps it in an
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`.

    Supports both the no-mask SDPA call shape ``(Q, K, V)`` and the
    masked shape ``(Q, K, V, attn_mask)`` produced when
    :class:`WrapFunctionalSDPAPattern` forwards an ``attn_mask`` positional
    from the source ``torch.nn.functional.scaled_dot_product_attention`` call.
    The 4th branch ``()`` accepts any single node as the mask edge; the
    rewriter preserves it on the rewired SDPA call.

    Depends on a ``ScaledDotProductAttention`` module already being in the
    graph, produced by :class:`ComposeScaledDotProductAttentionPattern` or
    :class:`WrapFunctionalSDPAPattern`. The iterative conversion loop
    handles this ordering automatically.
    """

    phase = Phase.CONVERSION
    tree: Tree = Fork(
        inputs=(
            _parallel_linears_branch,
            _parallel_linears_branch,
            _parallel_linears_branch,
            (),
        ),
        operator=_is_sdpa_module,
        output=(),
        perms_override=((0, 1, 2), (0, 1, 2, 3)),
    )
    graft: Graft = (_make_parallel_linears,)

    def match(
        self,
        graph_module: fx.GraphModule,
    ) -> list[PatternMatch]:
        """Return structural matches whose three linears can be packed."""
        matches = super().match(graph_module)
        return [m for m in matches if self._linears_compatible(m)]

    @staticmethod
    def _linears_compatible(pattern_match: PatternMatch) -> bool:
        """Return ``True`` when all three matched Linears are shape-compatible.

        Required for weight packing. Shape and bias constraints can't be
        expressed in the tree grammar, so they are checked here to reject
        otherwise-structural matches before replacement runs.
        """
        tree_match = pattern_match.tree_match
        first = _branch_linear(tree_match, 0)
        return all(
            other.in_features == first.in_features
            and other.out_features == first.out_features
            and (other.bias is None) == (first.bias is None)
            for other in (_branch_linear(tree_match, i) for i in (1, 2))
        )
# -- Wrap F.scaled_dot_product_attention into ScaledDotProductAttention --
@node_check
def _is_functional_sdpa(node: fx.Node) -> bool:
    """Return ``True`` for an ``F.scaled_dot_product_attention`` call.

    The node must also carry 4-D ``tensor_meta`` shape metadata, from which
    :func:`_wrap_functional_sdpa` reads ``num_heads`` and ``head_dim``.

    NOTE(review): ``enable_gqa`` is inspected neither here nor in
    ``_wrap_functional_sdpa``. Confirm ``ScaledDotProductAttention``
    handles K/V with fewer heads than Q before relying on GQA models.
    """
    is_sdpa_call = (
        node.op == "call_function"
        and node.target is F.scaled_dot_product_attention
    )
    if not is_sdpa_call:
        return False
    meta = node.meta.get("tensor_meta")
    if meta is None:
        return False
    return len(meta.shape) == 4
def _wrap_functional_sdpa(
    tree_match: TreeMatch,
) -> tuple[ScaledDotProductAttention]:
    """Build a ``ScaledDotProductAttention`` from a functional SDPA node.

    ``F.scaled_dot_product_attention``'s signature is
    ``(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
    scale=None, enable_gqa=False)``. The replacement ``call_module`` node
    forwards the first four args, so only the scalar options
    ``dropout_p``, ``is_causal`` and ``scale`` are lifted onto the
    :class:`ScaledDotProductAttention` instance.

    ``scale`` and ``is_causal`` must survive the rewrite. Models that
    pre-scale Q themselves (e.g. chronos-2's RoPE path, which multiplies Q
    by ``1/√d_head`` before SDPA) call
    ``F.scaled_dot_product_attention(..., scale=1.0)`` to disable the
    default ``1/√d_head`` scaling. Dropping ``scale`` would apply the
    default on top of the model's own scaling, collapsing softmax toward
    uniform attention.
    """
    sdpa_node = tree_match.get_node(0)
    out_shape = sdpa_node.meta["tensor_meta"].shape
    num_heads, head_dim = int(out_shape[1]), int(out_shape[3])
    call_args = sdpa_node.args
    call_kwargs = sdpa_node.kwargs

    _T = TypeVar("_T")

    def _read_scalar_kwarg(
        idx: int, name: str, default: _T
    ) -> float | bool | _T:
        """Fetch option `name` from positional slot `idx` or the kwargs.

        A graph node in the positional slot is skipped, falling through to
        the keyword lookup and then to `default`.
        """
        if idx < len(call_args):
            candidate = call_args[idx]
            if not isinstance(candidate, fx.Node):
                return cast("float | bool", candidate)
        if name in call_kwargs:
            return cast("float | bool", call_kwargs[name])
        return default

    dropout_p = float(_read_scalar_kwarg(4, "dropout_p", 0.0))
    is_causal = bool(_read_scalar_kwarg(5, "is_causal", False))
    raw_scale = _read_scalar_kwarg(6, "scale", None)
    scale = None if raw_scale is None else float(raw_scale)
    module = ScaledDotProductAttention(
        num_heads,
        head_dim,
        dropout_p,
        is_causal,
        scale,
    )
    return (module,)
class WrapFunctionalSDPAPattern(Pattern):
    """Wrap functional SDPA into ``ScaledDotProductAttention``.

    Matches a ``torch.nn.functional.scaled_dot_product_attention`` call.
    The call may come straight from user code, or from
    :class:`~embedl_deploy._internal.core.patterns.recompositions.functional.AtenScaledDotProductAttentionPattern`
    rewriting an exported ``aten.scaled_dot_product_attention`` node:

    .. code-block:: text

        Q ─┐
        K ─┼→ F.scaled_dot_product_attention
        V ─┘

    The node is swapped 1:1 for a
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    module call. Because the module keeps the 4-D
    ``[B, num_heads, S, head_dim]`` output, any following ops (head-flatten
    reshape, batch-first transpose, ...) stay untouched whatever layout
    the source model used.
    """

    phase = Phase.CONVERSION
    graft: Graft = (_wrap_functional_sdpa,)
    tree: Tree = (_is_functional_sdpa,)