# Copyright (C) 2026 Embedl AB
"""Attention sub-modules introduced by MHA decomposition.
These plain ``nn.Module`` subclasses replace the opaque
``nn.MultiheadAttention`` in the FX graph. Phase 2 creates ``Fused*`` wrappers
around them for Q/DQ insertion.
"""
import torch
import torch.nn.functional as F
from torch import nn
from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
QuantStub,
WeightFakeQuantize,
)
class MHAInProjection(ConvertedModule):
    """Packed Q/K/V input projection: ``Linear(E, 3E) → chunk → reshape``.

    :param linear:
        An ``nn.Linear(embed_dim, 3 * embed_dim)`` holding the packed Q/K/V
        projection weights and optional bias.
    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head (``embed_dim // num_heads``).
    """

    def __init__(
        self,
        linear: nn.Linear,
        num_heads: int,
        head_dim: int,
    ) -> None:
        super().__init__()
        self.linear = linear
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        Only `query` is used; `_key` and `_value` are accepted so the signature
        matches the ``MultiheadAttention(q, k, v)`` call-site (they are
        identical for self-attention).

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
        """
        batch, seq, _ = query.shape
        packed = self.linear(query)
        # Split the packed 3E output into Q/K/V, then move the head axis in
        # front of the sequence axis: [B, S, E] -> [B, num_heads, S, head_dim].
        return tuple(
            piece.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
            for piece in packed.chunk(3, dim=-1)
        )

    def __repr__(self) -> str:
        total_dim = self.num_heads * self.head_dim
        return (
            "MHAInProjection("
            f"embed_dim={total_dim}, "
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim})"
        )
class ScaledDotProductAttention(ConvertedModule):
    """Core attention: ``softmax(Q · Kᵀ / √H) · V``.

    :param num_heads:
        Number of attention heads.
    :param head_dim:
        Dimension of each head.
    :param dropout:
        Dropout probability (applied during training only).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Compute scaled dot-product attention.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :returns:
            Output tensor ``[B, S, embed_dim]``.
        """
        # Dropout is only active while training; inference is deterministic.
        drop_p = self.dropout if self.training else 0.0
        # pylint: disable-next=not-callable
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        batch = out.shape[0]
        seq = out.shape[2]
        # Merge the head and head_dim axes back into a single embed dim.
        return out.transpose(1, 2).contiguous().view(batch, seq, -1)

    def __repr__(self) -> str:
        return (
            "ScaledDotProductAttention("
            f"num_heads={self.num_heads}, "
            f"head_dim={self.head_dim}, "
            f"dropout={self.dropout})"
        )
class FusedMHAInProjection(FusedModule):
    """Fused wrapper for :class:`MHAInProjection`.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs
    around the input projection and to attach a
    :class:`~embedl_deploy._internal.core.quantize.modules.WeightFakeQuantize`
    for the packed linear weight.

    :param in_proj:
        The :class:`MHAInProjection` from the decomposed MHA.
    """

    # Only input 0 (query) carries data; _key/_value are ignored duplicates
    # at the self-attention call-site, so only the query needs a quant stub.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, in_proj: MHAInProjection) -> None:
        super().__init__()
        self.in_proj = in_proj
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(
        self,
        query: torch.Tensor,
        _key: torch.Tensor,
        _value: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project input to per-head ``(Q, K, V)`` tensors.

        When :attr:`weight_fake_quant` is set, fake-quantizes the packed
        projection weight before the linear operation. Only `query` is
        used; `_key` and `_value` are accepted to match the call-site
        signature but ignored for self-attention.

        :param query:
            Input tensor of shape ``[B, S, E]``.
        :returns:
            Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
        """
        # Fake-quantize the packed weight, then run the functional linear so
        # the (possibly quantized) weight is used instead of the raw parameter.
        weight = self.weight_fake_quant(self.in_proj.linear.weight)
        batch, seq, _ = query.shape
        # pylint: disable-next=not-callable
        qkv = F.linear(query, weight, self.in_proj.linear.bias)
        q, k, v = qkv.chunk(3, dim=-1)
        num_heads = self.in_proj.num_heads
        head_dim = self.in_proj.head_dim
        # [B, S, E] -> [B, num_heads, S, head_dim] for each of Q, K, V.
        q = q.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        k = k.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
        return q, k, v

    def __repr__(self) -> str:
        embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
        return (
            f"FusedMHAInProjection("
            f"embed_dim={embed_dim}, "
            f"num_heads={self.in_proj.num_heads}, "
            f"head_dim={self.in_proj.head_dim})"
        )
class FusedScaledDotProductAttention(FusedModule):
    """Fused wrapper for :class:`ScaledDotProductAttention`.

    Allows the Q/DQ insertion pass to place quantize / dequantize stubs
    on each of the three inputs (Q, K, V).

    Additionally holds an internal
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` between
    the softmax output and the second batched matrix multiply (BMM2). When
    ``None`` (before Q/DQ insertion) the forward pass is numerically identical
    to the unwrapped :class:`ScaledDotProductAttention`.

    :param attention:
        The :class:`ScaledDotProductAttention` from the decomposed MHA.
    """

    # Q/K/V stubs are inserted by the Q/DQ pass itself, so none are declared.
    inputs_to_quantize: set[int] = set()

    def __init__(self, attention: ScaledDotProductAttention) -> None:
        super().__init__()
        self.attention = attention
        # Softmax output lies in [0, 1]; fix the scale to 1/127 (symmetric
        # 8-bit) so no calibration is needed for this internal stub.
        self.softmax_quant = QuantStub(
            consumers={self},
            n_bits=8,
            symmetric=True,
            fixed_calibration=(1.0 / 127, 0),
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
        """Compute scaled dot-product attention.

        Performs manual attention with an
        internal quantization step between softmax and BMM2.

        :param q:
            Query tensor ``[B, num_heads, S, head_dim]``.
        :param k:
            Key tensor ``[B, num_heads, S, head_dim]``.
        :param v:
            Value tensor ``[B, num_heads, S, head_dim]``.
        :returns:
            Output tensor ``[B, S, embed_dim]``.
        """
        # Before Q/DQ insertion, defer to the fast fused SDPA kernel.
        if not self.softmax_quant.enabled:
            return self.attention(q, k, v)
        batch, _, seq, head_dim = q.shape
        # Manual attention so the stub can be applied between softmax and BMM2.
        scale = head_dim ** (-0.5)
        attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = self.softmax_quant(attn_weight)
        if self.attention.training and self.attention.dropout > 0.0:
            attn_weight = F.dropout(
                attn_weight,
                p=self.attention.dropout,
            )
        attn = torch.matmul(attn_weight, v)
        return attn.transpose(1, 2).contiguous().view(batch, seq, -1)

    def __repr__(self) -> str:
        a = self.attention
        qdq = "yes" if self.softmax_quant.enabled else "no"
        return (
            f"FusedScaledDotProductAttention("
            f"num_heads={a.num_heads}, "
            f"head_dim={a.head_dim}, "
            f"dropout={a.dropout}, "
            f"internal_qdq={qdq})"
        )