# Copyright (C) 2026 Embedl AB
"""Attention sub-modules introduced by MHA decomposition.
These plain ``nn.Module`` subclasses replace the opaque
``nn.MultiheadAttention`` in the FX graph. Phase 2 creates ``Fused*`` wrappers
around them for Q/DQ insertion.
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
from embedl_deploy._internal.core.quantize.stubs import QuantStub
from embedl_deploy._internal.tensorrt.modules.linear import (
attach_int8_weight_quant,
maybe_quantize_weight,
)
class MHAInProjection(ConvertedModule):
    """Packed Q/K/V input projection: ``Linear(E, 3E) → chunk → reshape``.

    One linear layer produces the concatenated Q, K and V projections,
    which are then split into three equal chunks and reshaped into
    per-head layout.

    :param linear:
        An ``nn.Linear(embed_dim, 3 * embed_dim)`` holding the packed Q/K/V
        projection weights and optional bias.
    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head (``embed_dim // num_heads``).
    """

    def __init__(
        self,
        linear: nn.Linear,
        num_heads: int,
        head_dim: int,
    ) -> None:
        super().__init__()
        self.linear = linear
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        Only ``query`` is read. ``_key`` and ``_value`` exist so the
        signature matches the ``MultiheadAttention(q, k, v)`` call-site;
        for self-attention all three are the same tensor.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
        """
        bsz, tgt_len, _ = query.shape
        per_head_shape = (bsz, tgt_len, self.num_heads, self.head_dim)

        packed = self.linear(query)
        q_flat, k_flat, v_flat = packed.chunk(3, dim=-1)

        # [B, S, E] -> [B, S, H, D] -> [B, H, S, D]
        q_heads = q_flat.view(per_head_shape).transpose(1, 2)
        k_heads = k_flat.view(per_head_shape).transpose(1, 2)
        v_heads = v_flat.view(per_head_shape).transpose(1, 2)
        return q_heads, k_heads, v_heads

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"MHAInProjection("
            f"embed_dim={self.num_heads * self.head_dim}, "
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim})"
        )
class ScaledDotProductAttention(ConvertedModule):
    """Core attention: ``softmax(Q · Kᵀ · scale) · V``.

    Module wrapper around ``F.scaled_dot_product_attention``. It stores
    the attention hyper-parameters as attributes so later passes can read
    them.

    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head.
    :param dropout:
        Dropout probability (applied during training only).
    :param is_causal:
        Whether to apply a causal mask. Mirrors the ``is_causal`` kwarg
        of ``F.scaled_dot_product_attention``.
    :param scale:
        Explicit attention score scale (multiplied on Q·Kᵀ). When
        ``None`` the PyTorch default ``1/√head_dim`` is used. Models
        that pre-scale Q themselves (e.g. chronos-2 + RoPE) must pass
        ``scale=1.0`` so the default scaling does not apply twice.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_causal: bool = False,
        scale: float | None = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.is_causal = is_causal
        self.scale = scale

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute scaled dot-product attention.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :param attn_mask:
            Optional attention mask, passed unchanged to
            ``F.scaled_dot_product_attention`` (``None`` means no mask).
            ``WrapFunctionalSDPAPattern`` forwards every positional arg of
            the source ``scaled_dot_product_attention`` call, including the
            optional 4th (mask) argument. Models that compile with
            mixed-mask attention (SAM3, masked-LM, ...) depend on this.
        :returns:
            Output tensor ``[B, num_heads, S, head_dim]``. Callers are
            responsible for any subsequent head-flattening reshape.
        """
        # Dropout is disabled outside training mode.
        effective_dropout = 0.0
        if self.training:
            effective_dropout = self.dropout

        # pylint: disable-next=not-callable
        return F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=effective_dropout,
            is_causal=self.is_causal,
            scale=self.scale,
        )

    def __repr__(self) -> str:  # pragma: no cover
        header = f"ScaledDotProductAttention("
        body = (
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim}, "
            f"dropout={self.dropout}, "
            f"is_causal={self.is_causal}, "
            f"scale={self.scale})"
        )
        return header + body
class FusedMHAInProjection(FusedModule):
    """Fused wrapper for ``MHAInProjection``.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs
    around the input projection and to attach a
    :class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
    for the packed linear weight.

    :param in_proj:
        The
        :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`
        from the decomposed MHA.
    """

    # Presumably the indices of ``forward``'s positional inputs that
    # receive input quant stubs: only ``query`` (index 0) is used.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, in_proj: MHAInProjection) -> None:
        super().__init__()
        self.in_proj = in_proj
        # Sets up INT8 fake-quant state for the packed Q/K/V weight (see
        # ``maybe_quantize_weight`` in ``forward``).
        attach_int8_weight_quant(self, in_proj.linear)

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        Fake-quantizes the packed projection weight when enabled, then
        performs the linear operation. Only ``query`` is used. ``_key``
        and ``_value`` are accepted to match the call-site signature but
        are ignored, as for self-attention.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
        """
        weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
        batch, seq, _ = query.shape
        # Use F.linear (not self.in_proj.linear) so the possibly
        # fake-quantized weight is used instead of the raw parameter.
        # pylint: disable-next=not-callable
        qkv = F.linear(query, weight, self.in_proj.linear.bias)
        q, k, v = qkv.chunk(3, dim=-1)
        num_heads = self.in_proj.num_heads
        head_dim = self.in_proj.head_dim
        # [B, S, E] -> [B, S, H, D] -> [B, H, S, D]
        q = q.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        k = k.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        return q, k, v

    def __repr__(self) -> str:  # pragma: no cover
        embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
        return (
            f"FusedMHAInProjection("
            f"embed_dim={embed_dim}, "
            f"num_heads={self.in_proj.num_heads}, "
            f"head_dim={self.in_proj.head_dim})"
        )
class FusedScaledDotProductAttention(FusedModule):
    """Fused wrapper for ``ScaledDotProductAttention``.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs
    on each of the three inputs (Q, K, V).

    Additionally holds an internal
    :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between
    the softmax output and the second batched matrix multiply (BMM2). The
    forward pass delegates to the unwrapped
    :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
    in either of two cases: that stub is disabled, or the inputs are not
    surrounded by quant stubs. Otherwise it performs manual attention with
    the quantization step.

    :param attention:
        The
        :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
        from the decomposed MHA.
    """

    inputs_to_quantize: set[int] = set()

    def __init__(self, attention: ScaledDotProductAttention) -> None:
        super().__init__()
        self.attention = attention
        # Softmax outputs lie in [0, 1]. ``fixed_calibration`` is
        # presumably ``(scale, zero_point)``; a 1/127 scale then covers
        # that range with symmetric INT8 and no calibration — confirm
        # against QuantStub.
        self.softmax_quant = QuantStub(
            consumers={self},
            n_bits=8,
            symmetric=True,
            fixed_calibration=(1.0 / 127, 0),
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        r"""Compute scaled dot-product attention.

        Manual attention runs only when both conditions hold:

        - the SDPA has been surrounded by ``QuantStub``\ s on its Q/K/V
          inputs;
        - the internal softmax quant stub is enabled.

        In that case a quantization step is inserted between softmax and
        BMM2. Otherwise the call is delegated to the wrapped attention
        module so TensorRT can fuse it into its native FP16 MHA kernel.

        The manual path honours the wrapped module's settings so both
        paths compute the same function:

        - its explicit ``scale`` (if any);
        - its ``is_causal`` flag, as a top-left aligned causal mask like
          ``F.scaled_dot_product_attention``;
        - its dropout probability while training.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :param attn_mask:
            Optional mask forwarded to the inner attention. Either an
            additive float mask broadcastable to ``[B, num_heads, S, S]``
            or a bool mask where ``True`` means "attend".
        :returns:
            Output tensor ``[B, num_heads, S, head_dim]``. Callers are
            responsible for any subsequent head-flattening reshape.
        """
        # Manual attention only helps when this SDPA was surrounded with
        # input ``QuantStub``s (i.e. Q/K/V arrive in INT8). Without the
        # surround, ``configure`` may still have left ``softmax_quant``
        # enabled. Running manual attention then adds a softmax Q/DQ pair
        # that pushes TensorRT off its FP16 fused MHA kernel onto the
        # slower INT8-aware variant for no gain.
        if not self.surrounded or not self.softmax_quant.enabled:
            return self.attention(q, k, v, attn_mask)

        # Honour the wrapped attention module's explicit ``scale`` if set.
        # Models that pre-scale Q themselves (chronos-2 + RoPE, for
        # example) build with ``scale=1.0`` to disable the default
        # ``1/sqrt(head_dim)`` scaling; falling back to the default here
        # would apply it twice and collapse softmax.
        # ``1/math.sqrt`` is used instead of ``head_dim ** -0.5``: the
        # tensor Pow with a negative float exponent traces to ONNX as a
        # ``Cast → complex128`` node that TRT 10.x can't parse.
        if self.attention.scale is not None:
            scale = self.attention.scale
        else:
            scale = 1.0 / math.sqrt(q.shape[-1])
        attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale

        if self.attention.is_causal:
            # The delegated path forwards ``is_causal`` to
            # ``F.scaled_dot_product_attention``, so the manual path must
            # mask too, or quantization would silently let tokens attend
            # to future positions. Like PyTorch, the mask is top-left
            # aligned: query position i may attend key positions j <= i.
            q_pos = torch.arange(q.shape[-2], device=q.device)
            k_pos = torch.arange(k.shape[-2], device=q.device)
            causal_mask = q_pos.unsqueeze(-1) >= k_pos.unsqueeze(0)
            attn_weight = attn_weight.masked_fill(
                ~causal_mask,
                float("-inf"),
            )

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Bool mask: ``True`` means "attend", so block the rest.
                attn_weight = attn_weight.masked_fill(
                    ~attn_mask,
                    float("-inf"),
                )
            else:
                # Float mask: additive bias on the attention scores.
                attn_weight = attn_weight + attn_mask

        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = self.softmax_quant(attn_weight)
        if self.attention.training and self.attention.dropout > 0.0:
            attn_weight = F.dropout(
                attn_weight,
                p=self.attention.dropout,
                training=True,
            )
        return torch.matmul(attn_weight, v)

    def __repr__(self) -> str:  # pragma: no cover
        a = self.attention
        qdq = "yes" if self.softmax_quant.enabled else "no"
        return (
            f"FusedScaledDotProductAttention("
            f"num_heads={a.num_heads}, "
            f"head_dim={a.head_dim}, "
            f"dropout={a.dropout}, "
            f"internal_qdq={qdq})"
        )