# Source code for embedl_deploy._internal.tensorrt.modules.attention

# Copyright (C) 2026 Embedl AB

"""Attention sub-modules introduced by MHA decomposition.

These plain ``nn.Module`` subclasses replace the opaque
``nn.MultiheadAttention`` in the FX graph.  Phase 2 creates ``Fused*`` wrappers
around them for Q/DQ insertion.
"""

import torch
import torch.nn.functional as F
from torch import nn

from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
    QuantStub,
    WeightFakeQuantize,
)


class MHAInProjection(ConvertedModule):
    """Packed Q/K/V input projection: ``Linear(E, 3E) → chunk → reshape``.

    :param linear:
        An ``nn.Linear(embed_dim, 3 * embed_dim)`` holding the packed Q/K/V
        projection weights and optional bias.
    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head (``embed_dim // num_heads``).
    """

    def __init__(
        self,
        linear: nn.Linear,
        num_heads: int,
        head_dim: int,
    ) -> None:
        super().__init__()
        self.linear = linear
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        Only `query` is used; `_key` and `_value` are accepted so the signature
        matches the ``MultiheadAttention(q, k, v)`` call-site (they are
        identical for self-attention).

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
        """
        batch, seq, _ = query.shape
        # One packed projection, then split the last dim into Q, K and V
        # and move the head axis in front of the sequence axis.
        packed = self.linear(query)
        per_head = (batch, seq, self.num_heads, self.head_dim)
        return tuple(
            part.view(*per_head).transpose(1, 2)
            for part in packed.chunk(3, dim=-1)
        )

    def __repr__(self) -> str:
        embed_dim = self.num_heads * self.head_dim
        return (
            "MHAInProjection("
            f"embed_dim={embed_dim}, "
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim})"
        )


class ScaledDotProductAttention(ConvertedModule):
    """Core attention: ``softmax(Q · Kᵀ / √H) · V``.

    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head.
    :param dropout:
        Dropout probability (applied during training only).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Compute scaled dot-product attention.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :returns:
            Output tensor ``[B, S, embed_dim]``.
        """
        # Dropout is active only in training mode.
        drop_p = self.dropout if self.training else 0.0
        # pylint: disable-next=not-callable
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        # Merge the head axis back into the embedding dimension.
        num_batch, _, num_tokens, _ = out.shape
        out = out.transpose(1, 2).contiguous()
        return out.view(num_batch, num_tokens, -1)

    def __repr__(self) -> str:
        return (
            "ScaledDotProductAttention("
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim}, "
            f"dropout={self.dropout})"
        )


class FusedMHAInProjection(FusedModule):
    """Fused wrapper for :class:`MHAInProjection`.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs
    around the input projection and to attach a
    :class:`~embedl_deploy._internal.core.quantize.modules.WeightFakeQuantize`
    for the packed linear weight.

    :param in_proj:
        The :class:`MHAInProjection` from the decomposed MHA.
    """

    # Only the first forward argument (query) gets an input quantizer;
    # _key and _value are ignored for self-attention.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, in_proj: MHAInProjection) -> None:
        super().__init__()
        self.in_proj = in_proj
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        When :attr:`weight_fake_quant` is set, fake-quantizes the packed
        projection weight before the linear operation.

        Only `query` is used; `_key` and `_value` are accepted to match the
        call-site signature but ignored for self-attention.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape
            ``[B, num_heads, S, head_dim]``.
        """
        weight = self.weight_fake_quant(self.in_proj.linear.weight)
        batch, seq, _ = query.shape
        # pylint: disable-next=not-callable
        qkv = F.linear(query, weight, self.in_proj.linear.bias)
        q, k, v = qkv.chunk(3, dim=-1)
        num_heads = self.in_proj.num_heads
        head_dim = self.in_proj.head_dim
        q = q.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        k = k.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        return q, k, v

    def __repr__(self) -> str:
        embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
        return (
            f"FusedMHAInProjection("
            f"embed_dim={embed_dim}, "
            f"num_heads={self.in_proj.num_heads}, "
            f"head_dim={self.in_proj.head_dim})"
        )
class FusedScaledDotProductAttention(FusedModule):
    """Fused wrapper for :class:`ScaledDotProductAttention`.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs on
    each of the three inputs (Q, K, V).

    Additionally holds an internal
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`
    between the softmax output and the second batched matrix multiply
    (BMM2).  When disabled (before Q/DQ insertion) the forward pass is
    numerically identical to the unwrapped
    :class:`ScaledDotProductAttention`.

    :param attention:
        The :class:`ScaledDotProductAttention` from the decomposed MHA.
    """

    # No inputs are quantized by default; the pass fills this in.
    inputs_to_quantize: set[int] = set()

    def __init__(self, attention: ScaledDotProductAttention) -> None:
        super().__init__()
        self.attention = attention
        # Softmax output lies in [0, 1]; fixed symmetric 8-bit calibration
        # with scale 1/127 covers that range exactly.
        self.softmax_quant = QuantStub(
            consumers={self},
            n_bits=8,
            symmetric=True,
            fixed_calibration=(1.0 / 127, 0),
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Compute scaled dot-product attention.

        Performs manual attention with an internal quantization step
        between softmax and BMM2.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :returns:
            Output tensor ``[B, S, embed_dim]``.
        """
        # Before Q/DQ insertion, defer to the fast fused kernel.
        if not self.softmax_quant.enabled:
            return self.attention(q, k, v)
        batch, _, seq, head_dim = q.shape
        scale = head_dim ** (-0.5)
        attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = self.softmax_quant(attn_weight)
        if self.attention.training and self.attention.dropout > 0.0:
            attn_weight = F.dropout(
                attn_weight,
                p=self.attention.dropout,
            )
        attn = torch.matmul(attn_weight, v)
        return attn.transpose(1, 2).contiguous().view(batch, seq, -1)

    def __repr__(self) -> str:
        a = self.attention
        qdq = "yes" if self.softmax_quant.enabled else "no"
        return (
            f"FusedScaledDotProductAttention("
            f"num_heads={a.num_heads}, "
            f"head_dim={a.head_dim}, "
            f"dropout={a.dropout}, "
            f"internal_qdq={qdq})"
        )