Source code for embedl_deploy._internal.tensorrt.patterns.conversions.attention

# Copyright (C) 2026 Embedl AB

"""Attention-related conversion patterns.

Decomposes ``nn.MultiheadAttention`` and ``shifted_window_attention`` into
explicit sub-modules so downstream fusion and Q/DQ passes can match the
resulting :class:`~embedl_deploy._internal.tensorrt.modules.attention` and
:class:`~embedl_deploy._internal.tensorrt.modules.swin_attention` modules.
"""

import logging
import operator
from typing import TypeVar, cast

import torch
import torch.nn.functional as F
from torch import fx, nn

from embedl_deploy._internal.core.patterns.main import (
    Pattern,
    PatternMatch,
    Phase,
)
from embedl_deploy._internal.core.patterns.utils import is_list_of
from embedl_deploy._internal.core.tree.match import node_check
from embedl_deploy._internal.core.tree.replace import get_auto_name
from embedl_deploy._internal.core.tree.state import (
    SharedNodeCheck,
    get_replaced_nodes,
)
from embedl_deploy._internal.core.tree.types import (
    Fork,
    Graft,
    NodeInserter,
    Tree,
    TreeMatch,
    Wildcard,
)
from embedl_deploy._internal.core.tree.utils import get_module, resolve_module
from embedl_deploy._internal.tensorrt.modules.attention import (
    MHAInProjection,
    ScaledDotProductAttention,
)
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
    SwinAttention,
    SwinSpatialState,
    SwinWindowPartition,
    SwinWindowReverse,
)

# torchvision is an optional dependency.  ``_TORCHVISION_AVAILABLE`` records
# whether ``shifted_window_attention`` could be imported; the Swin matcher
# checks this flag before touching the (otherwise undefined) name.
try:
    from torchvision.models.swin_transformer import (
        shifted_window_attention,
    )

    _TORCHVISION_AVAILABLE = True
except ModuleNotFoundError:  # pragma: no cover
    _TORCHVISION_AVAILABLE = False

# Module logger; used to report configurations that are skipped.
_LOG = logging.getLogger(__name__)


def _mha_output_issue(node: fx.Node) -> str | None:
    """Return why the ``(output, weights)`` result of `node` can't be unwrapped.

    The decomposition turns the MHA into a single-output node and replaces
    ``getitem(mha, 1)`` with a ``None`` constant.  That is only sound when
    every user is a ``getitem`` and the weights are either unused or were
    already ``None`` (``need_weights=False``).  Returns ``None`` when the
    result can be unwrapped safely.
    """
    # Positional ``need_weights`` is rejected separately (extra args), so
    # only the keyword form needs to be read here.
    need_weights = node.kwargs.get("need_weights", True)
    for user in node.users:
        if user.op != "call_function" or user.target is not operator.getitem:
            return "(output, weights) tuple is consumed directly"
        if user.args[1] == 1 and user.users and need_weights is not False:
            return "attention weights are used downstream"
    return None


def _is_supported_mha(node: fx.Node) -> bool:
    """Return ``True`` when ``node`` is a supported ``MultiheadAttention``.

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks, ``add_zero_attn`` or ``add_bias_kv`` is supported.  The
    ``(output, weights)`` result must only be consumed through ``getitem``,
    and the weights may only be used downstream when ``need_weights=False``
    because the decomposition replaces them with ``None``.  Unsupported
    configurations are skipped with a warning.
    """
    mod = get_module(node)
    if not isinstance(mod, nn.MultiheadAttention):
        return False

    warnings: list[str] = []

    if not getattr(mod, '_qkv_same_embed_dim'):
        warnings.append(
            f"Skipping MHA at {node.target}: cross-attention "
            f"(_qkv_same_embed_dim=False) not supported"
        )

    if not mod.batch_first:
        warnings.append(
            f"Skipping MHA at {node.target}: batch_first=False not supported"
        )

    if mod.add_zero_attn:
        warnings.append(
            f"Skipping MHA at {node.target}: add_zero_attn=True not supported"
        )

    # ``bias_k``/``bias_v`` are extra learned parameters that the
    # decomposition (in-proj / SDPA / out-proj) never reads, so keeping
    # them would silently change the result.
    if mod.bias_k is not None or mod.bias_v is not None:
        warnings.append(
            f"Skipping MHA at {node.target}: add_bias_kv=True not supported"
        )

    if len(node.args) > 3:
        warnings.append(
            f"Skipping MHA at {node.target}: extra positional args "
            f"(possible mask arguments)"
        )

    for mask_key in ("attn_mask", "key_padding_mask"):
        val = node.kwargs.get(mask_key)
        if val is not None:
            warnings.append(
                f"Skipping MHA at {node.target}: {mask_key} is not None"
            )

    output_issue = _mha_output_issue(node)
    if output_issue is not None:
        warnings.append(f"Skipping MHA at {node.target}: {output_issue}")

    if warnings:
        for msg in warnings:
            _LOG.warning("%s", msg)
        return False

    return True


def _decompose_mha(
    tree_match: TreeMatch,
) -> tuple[MHAInProjection, NodeInserter, nn.Linear]:
    """Turn the matched ``MultiheadAttention`` into three graft elements.

    Returns, in order:

    * an :class:`MHAInProjection` wrapping an ``nn.Linear`` that shares the
      MHA's ``in_proj_weight`` (and ``in_proj_bias`` when present),
    * a node inserter that splits Q/K/V and calls a
      :class:`ScaledDotProductAttention` built from the MHA's head count
      and dropout,
    * the MHA's own ``out_proj`` linear, reused unchanged.
    """
    mha: nn.MultiheadAttention = resolve_module(
        tree_match.get_node(0), nn.MultiheadAttention
    )
    width = mha.embed_dim
    heads = mha.num_heads
    per_head = width // heads
    has_bias = mha.in_proj_bias is not None

    qkv = nn.Linear(width, 3 * width, bias=has_bias)
    # Point the new linear at the original parameters instead of copying.
    qkv.weight = mha.in_proj_weight
    if has_bias:
        qkv.bias = mha.in_proj_bias
    projection = MHAInProjection(qkv, heads, per_head)

    sdpa = ScaledDotProductAttention(heads, per_head, dropout=mha.dropout)
    return projection, _get_attention_insert(sdpa), mha.out_proj


def _unwrap_mha_getitems(
    mha_node: fx.Node,
) -> None:
    """Collapse the ``(output, weights)`` tuple of `mha_node` to one value.

    The node actually rewritten is `mha_node`'s current replacement (as
    recorded by ``get_replaced_nodes``), or `mha_node` itself.  Among its
    ``operator.getitem`` users:

    * ``getitem(mha, 0)`` users are bypassed: their consumers now read the
      MHA node directly;
    * ``getitem(mha, 1)`` users are replaced by a fresh node that evaluates
      to ``None``;
    * any other index is left alone.

    Each replaced ``getitem`` node is erased.  Afterwards the MHA node acts
    as a single-output node, which lets
    :func:`~embedl_deploy._internal.core.replace.replace_tree` handle it
    normally.
    """
    owner = mha_node.graph.owning_module
    assert isinstance(owner, fx.GraphModule)
    graph = owner.graph
    target = get_replaced_nodes().get(mha_node, mha_node)

    # Snapshot first: the loop below rewires and erases users.
    getitem_users = [
        candidate
        for candidate in list(target.users.keys())
        if candidate.op == "call_function"
        and candidate.target is operator.getitem
    ]
    for user in getitem_users:
        index = user.args[1]
        if index == 0:
            substitute = target
        elif index == 1:
            with graph.inserting_after(target):
                substitute = graph.call_function(
                    operator.getitem, ((None,), 0)
                )
        else:
            continue
        user.replace_all_uses_with(substitute)
        graph.erase_node(user)


def _get_attention_insert(
    attention: ScaledDotProductAttention,
) -> NodeInserter:
    """Return a replacement function for Q/K/V splitting.

    The returned inserter registers `attention` on the graph module under
    an auto-generated name.  It splits the single in-projection output into
    Q/K/V via ``getitem`` (indices 0, 1, 2) and feeds them into
    `attention`.  The attention module returns a 4-D
    ``[B, num_heads, S, head_dim]`` tensor, so the inserter appends a
    ``transpose(1, 2) → flatten(2)`` chain.  That chain flattens the result
    back to the ``[B, S, embed_dim]`` layout expected by the downstream
    out-projection.

    The inserter returns the new nodes in graph order, with the flattened
    output last.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        name = get_auto_name(graph_module, attention)
        graph_module.add_module(name, attention)
        (proj_node,) = prev_args
        graph = graph_module.graph
        # Each node is inserted after the previous one, so the final graph
        # order is q, k, v, attn, transpose, flatten.
        with graph.inserting_after(proj_node):
            q = graph.call_function(operator.getitem, (proj_node, 0))
        with graph.inserting_after(q):
            k = graph.call_function(operator.getitem, (proj_node, 1))
        with graph.inserting_after(k):
            v = graph.call_function(operator.getitem, (proj_node, 2))
        with graph.inserting_after(v):
            attn = graph.call_module(name, (q, k, v))
        # [B, H, S, D] -> [B, S, H, D]
        with graph.inserting_after(attn):
            attn_t = graph.call_method("transpose", (attn, 1, 2))
        # [B, S, H, D] -> [B, S, H * D]
        with graph.inserting_after(attn_t):
            attn_flat = graph.call_method("flatten", (attn_t, 2))
        return [q, k, v, attn, attn_t, attn_flat]

    return _insert


class DecomposeMultiheadAttentionPattern(Pattern):
    """Decompose ``nn.MultiheadAttention`` into explicit sub-modules.

    Each matched ``MultiheadAttention`` call is replaced by three modules
    that appear as separate nodes in the FX graph:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    3. ``nn.Linear`` — the original MHA's output projection, reused as-is.

    Only self-attention (``_qkv_same_embed_dim=True``, ``batch_first=True``)
    without masks is supported.  Unsupported configurations are skipped
    with a warning.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_supported_mha,)
    graft: Graft = (_decompose_mha,)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Flatten the MHA ``(output, weights)`` tuple, then graft as usual."""
        mha_node = pattern_match.tree_match.get_node(0)
        # The base replacement expects a single-output node.
        _unwrap_mha_getitems(mha_node)
        return super().replace(pattern_match)
# -- Swin Transformer attention decomposition -------------------------


def _is_get_attr(arg: object) -> bool:
    """Return ``True`` when `arg` is an ``fx.Node`` with ``op == "get_attr"``."""
    return isinstance(arg, fx.Node) and arg.op == "get_attr"


def _swin_unsupported_reason(node: fx.Node) -> str | None:
    """Return why a ``shifted_window_attention`` call cannot be decomposed.

    :func:`_decompose_swin_attention` resolves the QKV/projection weights,
    the relative-position-bias inputs and the optional biases through
    ``get_attr`` nodes.  It also reads ``shift_size`` and
    ``attention_dropout`` from the keyword arguments.  Checking that
    structure here lets a non-conforming call be skipped at match time
    instead of failing an assertion during replacement.  Returns ``None``
    when the call has the expected structure.
    """
    args, kwargs = node.args, node.kwargs
    if len(args) < 6:
        return "expected at least 6 positional arguments"
    if not (_is_get_attr(args[1]) and _is_get_attr(args[2])):
        return "QKV/projection weights are not get_attr nodes"
    rel_bias = args[3]
    if not (
        isinstance(rel_bias, fx.Node)
        and len(rel_bias.args) >= 2
        and _is_get_attr(rel_bias.args[0])
        and _is_get_attr(rel_bias.args[1])
    ):
        return "relative position bias is not built from get_attr nodes"
    for key in ("shift_size", "attention_dropout"):
        if key not in kwargs:
            return f"missing keyword argument {key!r}"
    if not isinstance(kwargs["attention_dropout"], (int, float)):
        return "attention_dropout is not a number"
    for key in ("qkv_bias", "proj_bias"):
        bias = kwargs.get(key)
        if bias is not None and not _is_get_attr(bias):
            return f"{key} is not a get_attr node"
    return None


def _is_shifted_window_attention(node: fx.Node) -> bool:
    """Return ``True`` for a supported ``shifted_window_attention`` call.

    Only Swin V1 is supported.  V2 nodes (which carry a ``logit_scale``
    kwarg) are rejected, as are calls whose argument structure does not
    match what :func:`_decompose_swin_attention` reads.  Returns ``False``
    when *torchvision* is not installed.
    """
    if not _TORCHVISION_AVAILABLE:
        return False
    if node.op != "call_function":
        return False
    if node.target is not shifted_window_attention:
        return False
    if node.kwargs.get("logit_scale") is not None:
        _LOG.warning(
            "Skipping %s: Swin V2 (logit_scale) not supported",
            node.name,
        )
        return False
    reason = _swin_unsupported_reason(node)
    if reason is not None:
        _LOG.warning("Skipping %s: %s", node.name, reason)
        return False
    return True


def _resolve_get_attr(
    graph_module: fx.GraphModule,
    node: fx.Node,
) -> torch.Tensor | nn.Parameter:
    """Follow a ``get_attr`` node to the actual tensor."""
    assert isinstance(node.target, str)
    obj: object = graph_module
    for part in node.target.split("."):
        obj = getattr(obj, part)
    assert isinstance(obj, torch.Tensor)
    return obj


def _decompose_swin_attention(  # pylint: disable=too-many-locals
    swa_node: fx.Node,
    graph_module: fx.GraphModule,
) -> tuple[
    SwinWindowPartition,
    MHAInProjection,
    SwinAttention,
    nn.Linear,
    SwinWindowReverse,
]:
    """Split a ``shifted_window_attention`` call into five sub-modules.

    Extracts weights and parameters from the ``get_attr`` nodes in the
    function call's arguments, since the ``fx.GraphModule`` stores
    traced-through modules as generic ``nn.Module`` wrappers.  The call
    structure is validated by :func:`_is_shifted_window_attention`; the
    assertions below restate those preconditions.  The partition, attention
    and reverse modules share one :class:`SwinSpatialState` instance.
    """
    args = swa_node.args
    kwargs = swa_node.kwargs

    assert isinstance(args[1], fx.Node)
    qkv_weight = _resolve_get_attr(graph_module, args[1])
    assert isinstance(args[2], fx.Node)
    proj_weight = _resolve_get_attr(graph_module, args[2])
    assert is_list_of(args[4], int)
    window_size = args[4]
    assert isinstance(args[5], int)
    num_heads = args[5]

    assert is_list_of(kwargs["shift_size"], int)
    shift_size = kwargs["shift_size"]
    raw_dropout = kwargs["attention_dropout"]
    # Accept an integer literal such as ``0`` as well as a float.
    assert isinstance(raw_dropout, (int, float))
    attention_dropout = float(raw_dropout)
    # NOTE(review): only ``shift_size``, ``attention_dropout``, ``qkv_bias``
    # and ``proj_bias`` are read from the kwargs; any other kwarg on the
    # call (e.g. an output ``dropout``) is not carried over — confirm this
    # is intended for inference-only deployment.
    qkv_bias_node = kwargs.get("qkv_bias")
    proj_bias_node = kwargs.get("proj_bias")

    embed_dim = qkv_weight.shape[1]
    head_dim = embed_dim // num_heads

    # Relative position bias from the _get_relative_position_bias call.
    assert isinstance(args[3], fx.Node)
    rel_bias_node = args[3]
    assert isinstance(rel_bias_node.args[0], fx.Node)
    bias_table = _resolve_get_attr(graph_module, rel_bias_node.args[0])
    assert isinstance(rel_bias_node.args[1], fx.Node)
    bias_index = _resolve_get_attr(graph_module, rel_bias_node.args[1])

    # Reconstruct nn.Linear for QKV projection.
    bias = None
    if isinstance(qkv_bias_node, fx.Node):
        bias = nn.Parameter(_resolve_get_attr(graph_module, qkv_bias_node))
    qkv_linear = nn.Linear(embed_dim, 3 * embed_dim, bias=bias is not None)
    qkv_linear.weight = nn.Parameter(qkv_weight)
    if bias is not None:
        qkv_linear.bias = bias

    # Reconstruct nn.Linear for output projection.
    bias = None
    if isinstance(proj_bias_node, fx.Node):
        bias = nn.Parameter(_resolve_get_attr(graph_module, proj_bias_node))
    out_proj = nn.Linear(embed_dim, embed_dim, bias=bias is not None)
    out_proj.weight = nn.Parameter(proj_weight)
    if bias is not None:
        out_proj.bias = bias

    spatial_state = SwinSpatialState()
    win_part = SwinWindowPartition(
        window_size,
        shift_size,
        spatial_state,
    )
    in_proj = MHAInProjection(qkv_linear, num_heads, head_dim)
    attention = SwinAttention(
        relative_position_bias_table=nn.Parameter(bias_table),
        relative_position_index=bias_index,
        window_size=window_size,
        num_heads=num_heads,
        head_dim=head_dim,
        attention_dropout=attention_dropout,
        spatial_state=spatial_state,
    )
    win_rev = SwinWindowReverse(window_size, spatial_state)
    return win_part, in_proj, attention, out_proj, win_rev


def _get_swin_replace(
    tree_match: TreeMatch,
) -> tuple[NodeInserter]:
    """Return a replacement function for the full Swin decomposition."""
    swa_node = tree_match.get_node(0)
    graph_module = swa_node.graph.owning_module
    assert isinstance(graph_module, fx.GraphModule)
    win_part, in_proj, attention, out_proj, win_rev = (
        _decompose_swin_attention(swa_node, graph_module)
    )

    def _insert(  # pylint: disable=too-many-locals
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        # prev_args contains all fx.Node positional args of the
        # shifted_window_attention call.  Only the first (input) is
        # used; the rest (weight get_attr, rel-bias node) are superseded
        # by the new modules.
        input_node = prev_args[0]
        graph = graph_module.graph

        # 1. SwinWindowPartition
        wp_name = get_auto_name(graph_module, win_part)
        graph_module.add_module(wp_name, win_part)
        with graph.inserting_after(input_node):
            wp_node = graph.call_module(wp_name, (input_node,))

        # 2. MHAInProjection (self-attention: q=k=v)
        ip_name = get_auto_name(graph_module, in_proj)
        graph_module.add_module(ip_name, in_proj)
        with graph.inserting_after(wp_node):
            ip_node = graph.call_module(
                ip_name,
                (wp_node, wp_node, wp_node),
            )

        # 3. getitem nodes for Q, K, V
        with graph.inserting_after(ip_node):
            q = graph.call_function(
                operator.getitem,
                (ip_node, 0),
            )
        with graph.inserting_after(q):
            k = graph.call_function(
                operator.getitem,
                (ip_node, 1),
            )
        with graph.inserting_after(k):
            v = graph.call_function(
                operator.getitem,
                (ip_node, 2),
            )

        # 4. SwinAttention
        at_name = get_auto_name(graph_module, attention)
        graph_module.add_module(at_name, attention)
        with graph.inserting_after(v):
            at_node = graph.call_module(at_name, (q, k, v))

        # 5. nn.Linear (output projection)
        op_name = get_auto_name(graph_module, out_proj)
        graph_module.add_module(op_name, out_proj)
        with graph.inserting_after(at_node):
            op_node = graph.call_module(op_name, (at_node,))

        # 6. SwinWindowReverse
        wr_name = get_auto_name(graph_module, win_rev)
        graph_module.add_module(wr_name, win_rev)
        with graph.inserting_after(op_node):
            wr_node = graph.call_module(wr_name, (op_node,))

        return [
            wp_node,
            ip_node,
            q,
            k,
            v,
            at_node,
            op_node,
            wr_node,
        ]

    return (_insert,)


class DecomposeSwinAttentionPattern(Pattern):
    """Decompose ``shifted_window_attention`` into explicit sub-modules.

    Matches the ``call_function`` node produced when
    ``torchvision.models.swin_transformer.ShiftedWindowAttention`` is
    FX-traced (``shifted_window_attention`` is ``torch.fx.wrap``-ped).

    Replaces each node with five sub-modules:

    1. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowPartition`
    2. :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection` (reused)
    3. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinAttention`
    4. ``nn.Linear`` (output projection, reused from the original)
    5. :class:`~embedl_deploy._internal.tensorrt.modules.swin_attention.SwinWindowReverse`

    Only Swin V1 (``ShiftedWindowAttention``) is supported.  V2 nodes that
    carry a ``logit_scale`` keyword argument are skipped.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_shifted_window_attention,)
    graft: Graft = (_get_swin_replace,)


# -- Compose raw SDPA operations ----------------------------------------

#: Call targets FX records for a tensor matrix product: an explicit
#: ``torch.matmul`` call and the ``@`` operator (``operator.matmul``).
_MATMUL_TARGETS = (torch.matmul, operator.matmul)


def _is_transpose_last_two(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor.transpose(-1, -2)``."""
    if node.op != "call_method" or node.target != "transpose":
        return False
    non_node = [a for a in node.args if not isinstance(a, fx.Node)]
    return set(non_node) == {-1, -2}


@node_check
def _is_matmul(node: fx.Node) -> bool:
    """Return ``True`` for a ``torch.matmul`` call or an ``a @ b`` node."""
    return node.op == "call_function" and node.target in _MATMUL_TARGETS


def _is_numeric_scalar(value: object) -> bool:
    """Return ``True`` for a non-zero Python ``int``/``float`` (not ``bool``)."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    return value != 0


def _is_scale(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor / c``, ``tensor * c`` or ``c * tensor``.

    ``c`` must be a non-zero Python number so the factor can be read back
    by :func:`_scale_factor`.  ``c / tensor`` is rejected: it computes a
    reciprocal, not a scale.
    """
    if node.op != "call_function":
        return False
    if node.target not in {operator.truediv, operator.mul}:
        return False
    if len(node.args) != 2:
        return False
    node_args = [a for a in node.args if isinstance(a, fx.Node)]
    if len(node_args) != 1:
        return False
    if node.target is operator.truediv and not isinstance(
        node.args[0], fx.Node
    ):
        return False
    (scalar,) = [a for a in node.args if not isinstance(a, fx.Node)]
    return _is_numeric_scalar(scalar)


def _scale_factor(scale_node: fx.Node) -> float | None:
    """Return the factor a matched scale node multiplies its tensor by.

    Returns ``None`` when `scale_node` is not a ``truediv``/``mul`` by a
    single non-zero Python number.
    """
    if scale_node.target not in {operator.truediv, operator.mul}:
        return None
    scalars = [a for a in scale_node.args if not isinstance(a, fx.Node)]
    if len(scalars) != 1 or not _is_numeric_scalar(scalars[0]):
        return None
    value = float(scalars[0])
    if scale_node.target is operator.truediv:
        return 1.0 / value
    return value


def _is_softmax(node: fx.Node) -> bool:
    """Return ``True`` for ``F.softmax``/``torch.softmax`` over the last dim.

    Attention normalises over the key axis (the last dimension).  A softmax
    over any other axis, or with ``dim=None``, cannot be folded into
    ``ScaledDotProductAttention``.
    """
    if node.op != "call_function":
        return False
    if node.target not in {F.softmax, torch.softmax}:
        return False
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim")
    if isinstance(dim, bool) or not isinstance(dim, int):
        return False
    if dim == -1:
        return True
    meta = node.meta.get("tensor_meta")
    return meta is not None and dim == len(meta.shape) - 1


def _is_dropout(node: fx.Node) -> bool:
    """Return ``True`` for an ``nn.Dropout`` call."""
    return isinstance(get_module(node), nn.Dropout)


def _is_transpose_0213(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor.permute(0, 2, 1, 3)``.

    Both the varargs form and the single-sequence form
    ``tensor.permute((0, 2, 1, 3))`` are accepted.
    """
    if node.op != "call_method" or node.target != "permute":
        return False
    dims = [a for a in node.args if not isinstance(a, fx.Node)]
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        dims = list(dims[0])
    return tuple(dims) == (0, 2, 1, 3)


@node_check
def _is_view_or_reshape(node: fx.Node) -> bool:
    """Return ``True`` for view or reshape."""
    return node.op == "call_method" and node.target in {"view", "reshape"}


@node_check
def _is_attention_matmul(node: fx.Node) -> bool:
    """Return ``True`` for the outer ``matmul(attn_probs, V)``.

    Accepts both ``torch.matmul`` and ``attn_probs @ V``.  Requires 4-D
    shape metadata so ``num_heads`` and ``head_dim`` can be read in
    :func:`_make_sdpa`.
    """
    if node.op != "call_function" or node.target not in _MATMUL_TARGETS:
        return False
    meta = node.meta.get("tensor_meta")
    return meta is not None and len(meta.shape) == 4


def _make_sdpa(
    tree_match: TreeMatch,
) -> tuple[ScaledDotProductAttention]:
    """Build a ``ScaledDotProductAttention`` from matched attention ops.

    ``num_heads``/``head_dim`` come from the outer matmul's 4-D metadata,
    and the dropout probability comes from the optional ``nn.Dropout``.
    The scale factor is read back from the matched scale node.  It is
    passed explicitly only when it differs from the default
    ``1/sqrt(head_dim)``.  Otherwise a model that scales by e.g.
    ``1/head_dim`` would silently switch to the default scale.
    """
    # The pattern's pre-trunk node is matmul(attn_probs, V) with
    # 4-D meta [B, H, S, D] (guaranteed by ``_is_attention_matmul``).
    attn_matmul = tree_match.pre_trunk_nodes[0]
    shape = attn_matmul.meta["tensor_meta"].shape
    num_heads = int(shape[1])
    head_dim = int(shape[3])

    # Optional dropout: nested[0] is the inner matmul fork whose
    # trunk[2] is the dropout wildcard.
    dropout_wild = tree_match.get_node(0, 2, is_wildcard=True)
    dropout = 0.0
    if dropout_wild.nodes:
        drop_mod = resolve_module(
            dropout_wild.nodes[0],
            nn.Dropout,
        )
        dropout = drop_mod.p

    # trunk[0] of the same inner fork is the scale node.
    factor = _scale_factor(tree_match.get_node(0, 0))
    scale: float | None = None
    if factor is not None and abs(factor * head_dim**0.5 - 1.0) > 1e-6:
        scale = factor

    sdpa = ScaledDotProductAttention(
        num_heads, head_dim, dropout, False, scale
    )
    return (sdpa,)


class ComposeScaledDotProductAttentionPattern(Pattern):
    """Compose raw ``matmul/scale/softmax/matmul`` into ``ScaledDotProductAttention``.

    Matches the raw operations that implement scaled dot-product attention:

    .. code-block:: text

        Q ──────→ matmul → scale → softmax → [dropout] → matmul
        K.T ────→↗                                       ↗
        V ─────────────────────────────────────────────→↗

    Both ``torch.matmul`` and the ``@`` operator are recognised.  The
    softmax must run over the last dimension.  The scale must be a
    division or multiplication by a Python number; its factor is carried
    over to the new module.

    Replaces them with a single
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    module that takes ``(Q, K, V)`` and returns
    ``[B, num_heads, S, head_dim]``.  Any trailing
    ``permute → contiguous → view`` that flattens the heads is left
    untouched in the graph and consumes the module's 4-D output.
    """

    phase = Phase.CONVERSION
    tree: Tree = Fork(
        inputs=(
            Fork(
                inputs=(
                    (),
                    (_is_transpose_last_two,),
                ),
                operator=_is_matmul,
                output=(
                    _is_scale,
                    _is_softmax,
                    Wildcard(_is_dropout, "?"),
                ),
            ),
            (),
        ),
        operator=_is_attention_matmul,
        output=(),
    )
    graft: Graft = (_make_sdpa,)


# -- Compose parallel linears into MHAInProjection ----------------------


def _is_transpose_1_2(node: fx.Node) -> bool:
    """Return ``True`` for ``tensor.transpose(1, 2)``."""
    if node.op != "call_method" or node.target != "transpose":
        return False
    non_node = [a for a in node.args if not isinstance(a, fx.Node)]
    return set(non_node) == {1, 2}


def _is_transpose_1_2_or_permute_0213(node: fx.Node) -> bool:
    """Return ``True`` for either ``transpose(1, 2)`` or ``permute(0, 2, 1, 3)``."""
    return _is_transpose_1_2(node) or _is_transpose_0213(node)


@node_check
def _is_sdpa_module(node: fx.Node) -> bool:
    """Return ``True`` for a ``ScaledDotProductAttention`` call_module node."""
    return isinstance(get_module(node), ScaledDotProductAttention)


#: Shared across the three Q/K/V branches so they are all constrained to
#: the same physical source tensor.  Re-using one instance means the
#: first branch to run caches the source node and the other two succeed
#: only when they see that exact node.
_parallel_linears_shared_input = SharedNodeCheck(lambda _: True)

#: One of the three Q/K/V projection branches: a ``view``/``reshape`` with
#: shape arguments, a head-ordering permutation (``transpose(1,2)`` or
#: ``permute(0,2,1,3)``), and the ``nn.Linear`` whose input is constrained
#: by :data:`_parallel_linears_shared_input`.
_parallel_linears_branch: Tree = Fork(
    inputs=(
        (_parallel_linears_shared_input, nn.Linear),
        (),
        (),
    ),
    operator=_is_view_or_reshape,
    output=(_is_transpose_1_2_or_permute_0213,),
    perms_override=((0,), (0, 1), (0, 1, 2)),
)


def _get_parallel_linears_insert(
    shared_input: fx.Node,
    in_proj: MHAInProjection,
    sdpa_node: fx.Node,
) -> NodeInserter:
    """Return a replacement that rewires SDPA inputs through an in-projection.

    The three old ``transpose → view → nn.Linear`` input chains are part of
    the matched tree and erased by
    :func:`~embedl_deploy._internal.core.replace.replace_tree` once the
    SDPA's args are rewired.
    """

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        del prev_args  # inputs are derived from the shared pre-linear tensor
        replaced = get_replaced_nodes()
        resolved_input = replaced.get(shared_input, shared_input)
        graph = graph_module.graph

        ip_name = get_auto_name(graph_module, in_proj)
        graph_module.add_module(ip_name, in_proj)
        with graph.inserting_after(resolved_input):
            ip_node = graph.call_module(
                ip_name,
                (resolved_input, resolved_input, resolved_input),
            )

        gis: list[fx.Node] = []
        prev = ip_node
        for i in range(3):
            with graph.inserting_after(prev):
                gi = graph.call_function(
                    operator.getitem,
                    (ip_node, i),
                )
            gis.append(gi)
            prev = gi

        # Preserve any positional args after Q/K/V (e.g. ``attn_mask``)
        # so masked-SDPA call sites keep their mask after rewiring.
        sdpa_node.args = tuple(gis) + tuple(sdpa_node.args[3:])
        return [resolved_input, ip_node, *gis, sdpa_node]

    return _insert


def _branch_linear(tree_match: TreeMatch, branch: int) -> nn.Linear:
    """Return the ``nn.Linear`` module matched in the *branch*-th Q/K/V arm."""
    linear_node = tree_match.get_node(branch, 0, 1)
    return resolve_module(linear_node, nn.Linear)


def _make_parallel_linears(
    tree_match: TreeMatch,
) -> tuple[NodeInserter]:
    """Pack three Q/K/V linears into an ``MHAInProjection``.

    The packed ``nn.Linear`` takes its ``in_features``/``out_features``
    from the matched linears (which
    :meth:`ComposeParallelLinearsPattern._linears_compatible` guarantees
    are identical).  They are not derived from ``num_heads * head_dim``,
    because the model width can differ from the attention width (e.g.
    ``d_model != num_heads * d_head``).
    """
    sdpa_node = tree_match.pre_trunk_nodes[0]
    sdpa_mod = resolve_module(sdpa_node, ScaledDotProductAttention)
    num_heads = sdpa_mod.num_heads
    head_dim = sdpa_mod.head_dim

    shared_input = tree_match.get_node(0, 0, 0)
    q_lin = _branch_linear(tree_match, 0)
    k_lin = _branch_linear(tree_match, 1)
    v_lin = _branch_linear(tree_match, 2)

    has_bias = q_lin.bias is not None
    packed = nn.Linear(
        q_lin.in_features,
        3 * q_lin.out_features,
        bias=has_bias,
    )
    packed.weight = nn.Parameter(
        torch.cat(
            [q_lin.weight, k_lin.weight, v_lin.weight],
            dim=0,
        )
    )
    if has_bias:
        packed.bias = nn.Parameter(
            torch.cat(
                [
                    q_lin.bias,
                    k_lin.bias,
                    v_lin.bias,
                ],
                dim=0,
            )
        )

    in_proj = MHAInProjection(packed, num_heads, head_dim)
    return (_get_parallel_linears_insert(shared_input, in_proj, sdpa_node),)


class ComposeParallelLinearsPattern(Pattern):
    """Compose three parallel ``nn.Linear`` Q/K/V into ``MHAInProjection``.

    Matches a
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    node whose three inputs each trace back through a
    ``transpose(1, 2)``/``permute(0, 2, 1, 3)`` → ``view``/``reshape`` →
    ``nn.Linear`` chain from the same source tensor.  The three branches
    are tied to a single source node by a
    :class:`~embedl_deploy._internal.core.tree.state.SharedNodeCheck`
    shared across their data sub-trunks.

    Packs the three separate linear weights into a single
    ``nn.Linear(in_features, 3 * out_features)`` and wraps it in an
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`.

    Supports both the no-mask SDPA call shape ``(Q, K, V)`` and the masked
    shape ``(Q, K, V, attn_mask)``.  The masked shape is produced when
    :class:`WrapFunctionalSDPAPattern` forwards an ``attn_mask`` positional
    from the source ``torch.nn.functional.scaled_dot_product_attention``
    call.  The 4th branch ``()`` accepts any single node as the mask edge;
    the rewriter preserves it on the rewired SDPA call.

    Depends on :class:`ComposeScaledDotProductAttentionPattern` having run
    first (handled automatically by the iterative conversion loop).
    """

    phase = Phase.CONVERSION
    tree: Tree = Fork(
        inputs=(
            _parallel_linears_branch,
            _parallel_linears_branch,
            _parallel_linears_branch,
            (),
        ),
        operator=_is_sdpa_module,
        output=(),
        perms_override=((0, 1, 2), (0, 1, 2, 3)),
    )
    graft: Graft = (_make_parallel_linears,)

    def match(
        self,
        graph_module: fx.GraphModule,
    ) -> list[PatternMatch]:
        """Return structural matches whose three linears can be packed."""
        matches = super().match(graph_module)
        return [m for m in matches if self._linears_compatible(m)]

    @staticmethod
    def _linears_compatible(pattern_match: PatternMatch) -> bool:
        """Return ``True`` when all three matched Linears are shape-compatible.

        Required for weight packing: shape/bias constraints can't be
        expressed in the tree grammar, so they are checked here to reject
        otherwise-structural matches before replacement runs.
        """
        first = _branch_linear(pattern_match.tree_match, 0)
        for i in (1, 2):
            lin = _branch_linear(pattern_match.tree_match, i)
            if lin.in_features != first.in_features:
                return False
            if lin.out_features != first.out_features:
                return False
            if (lin.bias is None) != (first.bias is None):
                return False
        return True


# -- Wrap F.scaled_dot_product_attention into ScaledDotProductAttention --


@node_check
def _is_functional_sdpa(node: fx.Node) -> bool:
    """Return ``True`` for a supported ``F.scaled_dot_product_attention`` call.

    Requires 4-D shape metadata so ``num_heads`` and ``head_dim`` can be
    read from the output tensor in :func:`_wrap_functional_sdpa`.  Calls
    with ``enable_gqa`` set are skipped with a warning, because
    :func:`_wrap_functional_sdpa` does not forward that flag.
    """
    if (
        node.op != "call_function"
        or node.target is not F.scaled_dot_product_attention
    ):
        return False
    meta = node.meta.get("tensor_meta")
    if meta is None or len(meta.shape) != 4:
        return False
    # ``enable_gqa`` is positional slot 7 or a keyword argument.
    enable_gqa = (
        node.args[7]
        if len(node.args) > 7
        else node.kwargs.get("enable_gqa", False)
    )
    if enable_gqa:
        _LOG.warning("Skipping %s: enable_gqa=True not supported", node.name)
        return False
    return True


def _wrap_functional_sdpa(
    tree_match: TreeMatch,
) -> tuple[ScaledDotProductAttention]:
    """Build a ``ScaledDotProductAttention`` from a functional SDPA node.

    ``F.scaled_dot_product_attention``'s signature is ``(q, k, v,
    attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
    enable_gqa=False)``.  The first four args are forwarded by the
    replacement ``call_module`` node.  Of the trailing options,
    ``dropout_p``, ``is_causal`` and ``scale`` are lifted onto the
    :class:`ScaledDotProductAttention` instance; ``enable_gqa`` calls are
    rejected by :func:`_is_functional_sdpa`.

    Models that pre-scale Q themselves (e.g. chronos-2's RoPE path, where
    Q is multiplied by ``1/√d_head`` before SDPA) explicitly call
    ``F.scaled_dot_product_attention(..., scale=1.0)`` to disable the
    default ``1/√d_head`` scaling.  Dropping ``scale`` here would apply the
    default on top of the model's pre-scaling, collapsing softmax toward
    uniform attention.  The same reasoning applies to ``is_causal``.
    """
    sdpa_node = tree_match.get_node(0)
    shape = sdpa_node.meta["tensor_meta"].shape
    num_heads = int(shape[1])
    head_dim = int(shape[3])

    _T = TypeVar("_T")

    def _read_scalar_kwarg(
        idx: int, name: str, default: _T
    ) -> float | bool | _T:
        """Read kwarg ``name`` from positional slot ``idx`` or kwargs dict."""
        if len(sdpa_node.args) > idx and not isinstance(
            sdpa_node.args[idx], fx.Node
        ):
            return cast("float | bool", sdpa_node.args[idx])
        if name in sdpa_node.kwargs:
            return cast("float | bool", sdpa_node.kwargs[name])
        return default

    dropout_p = float(_read_scalar_kwarg(4, "dropout_p", 0.0))
    is_causal = bool(_read_scalar_kwarg(5, "is_causal", False))
    scale_val = _read_scalar_kwarg(6, "scale", None)
    scale = float(scale_val) if scale_val is not None else None

    return (
        ScaledDotProductAttention(
            num_heads,
            head_dim,
            dropout_p,
            is_causal,
            scale,
        ),
    )


class WrapFunctionalSDPAPattern(Pattern):
    """Wrap functional SDPA into ``ScaledDotProductAttention``.

    Matches the ``torch.nn.functional.scaled_dot_product_attention`` call
    produced either directly by user code or by
    :class:`~embedl_deploy._internal.core.patterns.recompositions.functional.AtenScaledDotProductAttentionPattern`
    rewriting an exported ``aten.scaled_dot_product_attention`` node:

    .. code-block:: text

        Q ─┐
        K ─┼→ F.scaled_dot_product_attention
        V ─┘

    Replaces the single node 1:1 with a
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    module call.  The module preserves the 4-D
    ``[B, num_heads, S, head_dim]`` output, so any post-op chain
    (head-flatten reshape, batch-first transpose, etc.) is left untouched
    in the graph regardless of the originating layout.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_functional_sdpa,)
    graft: Graft = (_wrap_functional_sdpa,)