Source code for embedl_deploy._internal.tensorrt.modules.attention

# Copyright (C) 2026 Embedl AB

"""Attention sub-modules introduced by MHA decomposition.

The converted modules defined here replace the opaque
``nn.MultiheadAttention`` in the FX graph.  Phase 2 wraps them in the
``Fused*`` modules (also defined in this file), which give the Q/DQ insertion
pass places to attach quantize / dequantize stubs.
"""

import math

import torch
import torch.nn.functional as F
from torch import nn

from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
from embedl_deploy._internal.core.quantize.stubs import QuantStub
from embedl_deploy._internal.tensorrt.modules.linear import (
    attach_int8_weight_quant,
    maybe_quantize_weight,
)


class MHAInProjection(ConvertedModule):
    """Packed Q/K/V input projection: ``Linear(E, 3E) → chunk → reshape``.

    The packed linear output ``[B, S, 3E]`` is cut into three equal
    ``[B, S, E]`` slices along the last axis.  Each slice is then split into
    heads and transposed to ``[B, num_heads, S, head_dim]``.

    :param linear:
        An ``nn.Linear(embed_dim, 3 * embed_dim)`` holding the packed Q/K/V
        projection weights and optional bias.
    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head (``embed_dim // num_heads``).
    """

    def __init__(
        self,
        linear: nn.Linear,
        num_heads: int,
        head_dim: int,
    ) -> None:
        super().__init__()
        self.linear = linear
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Split the packed projection of `query` into per-head Q, K and V.

        For self-attention, key and value are the same tensor as the query,
        so only `query` is projected.  `_key` and `_value` exist only so the
        signature matches the ``MultiheadAttention(q, k, v)`` call-site; they
        are ignored.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            ``(Q, K, V)``, each of shape ``[B, num_heads, S, head_dim]``.
        """
        batch, seq, _ = query.shape
        q, k, v = self.linear(query).chunk(3, dim=-1)
        # Split the embedding axis into (heads, head_dim), then move heads
        # in front of the sequence axis.
        per_head = (batch, seq, self.num_heads, self.head_dim)
        return tuple(part.view(per_head).transpose(1, 2) for part in (q, k, v))

    def __repr__(self) -> str:  # pragma: no cover
        return (
            "MHAInProjection("
            f"embed_dim={self.num_heads * self.head_dim}, "
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim})"
        )


class ScaledDotProductAttention(ConvertedModule):
    """Core attention: ``softmax(Q · Kᵀ · scale) · V``.

    A thin module around ``F.scaled_dot_product_attention`` that stores the
    attention hyper-parameters as attributes.

    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head.
    :param dropout:
        Dropout probability; only used while the module is in training mode.
    :param is_causal:
        Forwarded as the ``is_causal`` kwarg of
        ``F.scaled_dot_product_attention``.
    :param scale:
        Explicit multiplier applied to Q·Kᵀ.  ``None`` selects the PyTorch
        default ``1/√head_dim``.  Models that already scale Q themselves
        (e.g. chronos-2 + RoPE) must pass ``scale=1.0`` so the scaling is not
        applied twice.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_causal: bool = False,
        scale: float | None = None,
    ) -> None:
        super().__init__()
        # Shape bookkeeping (used for repr / downstream passes).
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Hyper-parameters forwarded to F.scaled_dot_product_attention.
        self.dropout = dropout
        self.is_causal = is_causal
        self.scale = scale

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run ``F.scaled_dot_product_attention`` with the stored settings.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :param attn_mask:
            Optional attention mask, passed through unchanged (``None``
            means no mask).  ``F.scaled_dot_product_attention`` accepts the
            mask as an optional 4th positional argument, and
            ``WrapFunctionalSDPAPattern`` forwards every positional argument
            of the source node.  This module therefore accepts it as well.
            SAM3, masked-LM and similar models compiled with mixed-mask
            attention rely on this.
        :returns:
            Output tensor ``[B, num_heads, S, head_dim]``.  Any later
            head-flattening reshape is the caller's responsibility.
        """
        # Dropout is disabled outside training mode.
        drop_p = 0.0
        if self.training:
            drop_p = self.dropout
        # pylint: disable-next=not-callable
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=drop_p,
            is_causal=self.is_causal,
            scale=self.scale,
        )

    def __repr__(self) -> str:  # pragma: no cover
        fields = (
            f"num_heads={self.num_heads}",
            f"head_dim={self.head_dim}",
            f"dropout={self.dropout}",
            f"is_causal={self.is_causal}",
            f"scale={self.scale}",
        )
        return "ScaledDotProductAttention(" + ", ".join(fields) + ")"


class FusedMHAInProjection(FusedModule):
    """Fused wrapper for ``MHAInProjection``.

    Gives the Q/DQ insertion pass a place to put quantize / dequantize stubs
    around the input projection.  It also attaches a
    :class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
    for the packed linear weight.

    :param in_proj:
        The
        :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
        produced by the MHA decomposition.
    """

    inputs_to_quantize: set[int] = {0}

    def __init__(self, in_proj: MHAInProjection) -> None:
        super().__init__()
        self.in_proj = in_proj
        attach_int8_weight_quant(self, in_proj.linear)

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project `query` into per-head ``(Q, K, V)`` tensors.

        The packed projection weight goes through ``maybe_quantize_weight``
        (fake-quantized when enabled) before the linear operation.  Only
        `query` is used; `_key` and `_value` mirror the call-site signature
        and are ignored for self-attention.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            ``(Q, K, V)``, each of shape ``[B, num_heads, S, head_dim]``.
        """
        proj = self.in_proj
        weight = maybe_quantize_weight(self, proj.linear.weight)
        batch, seq, _ = query.shape
        # pylint: disable-next=not-callable
        packed = F.linear(query, weight, proj.linear.bias)
        q, k, v = packed.chunk(3, dim=-1)
        # [B, S, E] -> [B, S, H, D] -> [B, H, S, D] for each of q, k, v.
        per_head = (batch, seq, proj.num_heads, proj.head_dim)
        return tuple(part.view(per_head).transpose(1, 2) for part in (q, k, v))

    def __repr__(self) -> str:  # pragma: no cover
        heads = self.in_proj.num_heads
        dim = self.in_proj.head_dim
        return (
            "FusedMHAInProjection("
            f"embed_dim={heads * dim}, "
            f"num_heads={heads}, "
            f"head_dim={dim})"
        )
class FusedScaledDotProductAttention(FusedModule):
    """Fused wrapper for ``ScaledDotProductAttention``.

    Lets the Q/DQ insertion pass place quantize / dequantize stubs on each of
    the three inputs (Q, K, V).  It also owns an internal
    :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` that sits
    between the softmax output and the second batched matrix multiply (BMM2).

    When that stub is disabled, or the module was not surrounded by input
    stubs, the forward pass delegates to the wrapped
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`.
    Otherwise it computes attention manually with the quantization step
    inserted.

    :param attention:
        The
        :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
        produced by the MHA decomposition.
    """

    inputs_to_quantize: set[int] = set()

    def __init__(self, attention: ScaledDotProductAttention) -> None:
        super().__init__()
        self.attention = attention
        # Softmax outputs lie in [0, 1], so a fixed symmetric INT8 scale of
        # 1/127 is used instead of a calibrated one.
        self.softmax_quant = QuantStub(
            consumers={self},
            n_bits=8,
            symmetric=True,
            fixed_calibration=(1.0 / 127, 0),
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        r"""Compute scaled dot-product attention.

        When the SDPA has been surrounded by ``QuantStub``\ s on its Q/K/V
        inputs *and* the internal softmax quant stub is enabled, attention is
        computed manually with a quantization step between softmax and BMM2.
        The manual path honours the wrapped module's ``scale``, ``is_causal``
        and ``dropout`` settings.  Otherwise the call is delegated to the
        wrapped attention module so TensorRT can fuse it into its native
        FP16 MHA kernel.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :param attn_mask:
            Optional mask forwarded to the inner attention.  Either an
            additive float mask broadcastable to ``[B, num_heads, S, S]`` or
            a bool mask where ``True`` means "attend".
        :returns:
            Output tensor ``[B, num_heads, S, head_dim]``.  Callers are
            responsible for any subsequent head-flattening reshape.
        """
        # Manual attention only pays off when this SDPA was surrounded with
        # input ``QuantStub``s (i.e. Q/K/V arrive in INT8).  Without the
        # surround, ``configure`` may still have left ``softmax_quant``
        # enabled.  Running manual attention then would add a softmax Q/DQ
        # pair that pushes TensorRT off its FP16 fused MHA kernel onto the
        # slower INT8-aware variant for no gain.
        if not self.surrounded or not self.softmax_quant.enabled:
            return self.attention(q, k, v, attn_mask)
        return self._manual_attention(q, k, v, attn_mask)

    def _manual_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Unfused attention with a Q/DQ step between softmax and BMM2."""
        attention = self.attention
        # Honour the wrapped module's explicit ``scale`` if set.  Models that
        # pre-scale Q themselves (chronos-2 + RoPE, for example) are built
        # with ``scale=1.0`` to disable the default ``1/sqrt(head_dim)``
        # scaling; falling back to the default here would apply it twice and
        # collapse softmax.
        # ``1/sqrt(head_dim)`` is used rather than ``head_dim ** -0.5``: a
        # tensor Pow with a negative float exponent traces to ONNX as a
        # ``Cast → complex128`` node that TRT 10.x can't parse.
        if attention.scale is not None:
            scale = attention.scale
        else:
            scale = 1.0 / math.sqrt(q.shape[-1])

        attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attention.is_causal:
            # Match F.scaled_dot_product_attention's ``is_causal``: a
            # top-left aligned lower-triangular mask over (query, key).
            # Previously this path silently dropped causality, diverging
            # from the delegate path.
            q_len, k_len = attn_weight.shape[-2], attn_weight.shape[-1]
            causal = torch.ones(
                q_len, k_len, dtype=torch.bool, device=attn_weight.device
            ).tril()
            attn_weight = attn_weight.masked_fill(~causal, float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Bool mask: ``True`` means attend, so fill the rest.
                attn_weight = attn_weight.masked_fill(
                    ~attn_mask,
                    float("-inf"),
                )
            else:
                # Float mask is additive.
                attn_weight = attn_weight + attn_mask

        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = self.softmax_quant(attn_weight)
        if attention.training and attention.dropout > 0.0:
            attn_weight = F.dropout(
                attn_weight,
                p=attention.dropout,
            )
        return torch.matmul(attn_weight, v)

    def __repr__(self) -> str:  # pragma: no cover
        a = self.attention
        qdq = "yes" if self.softmax_quant.enabled else "no"
        return (
            f"FusedScaledDotProductAttention("
            f"num_heads={a.num_heads}, "
            f"head_dim={a.head_dim}, "
            f"dropout={a.dropout}, "
            f"is_causal={a.is_causal}, "
            f"scale={a.scale}, "
            f"internal_qdq={qdq})"
        )