# Copyright (C) 2026 Embedl AB
"""Quantization and fake-quantization ``nn.Module`` primitives.
These modules are backend-agnostic and can be used with any operator
(convolutions, linear layers, etc.).
"""
# mypy: disable-error-code="misc"
# torch lacks type stubs, so nn.Module resolves to Any.
from dataclasses import dataclass, field
import torch
from torch import nn
from embedl_deploy._internal.core.quantize.config import (
CalibrationMethod,
Observer,
SmoothQuantConfig,
TensorQuantConfig,
)
def _make_observer(config: TensorQuantConfig) -> Observer:
    """Build the ``torch.ao`` observer described by `config`.

    The observer class is the value of ``config.calibration_method``.
    Symmetric configs are observed as ``torch.qint8`` with a per-tensor
    symmetric scheme; asymmetric configs as ``torch.quint8`` with a
    per-tensor affine scheme.  The integer range is forwarded from
    ``config.quant_min`` / ``config.quant_max``.

    :param config:
        Quantization settings of the tensor to observe.
    :returns:
        A freshly constructed observer instance.
    """
    if config.symmetric:
        observed_dtype = torch.qint8
        scheme = torch.per_tensor_symmetric
    else:
        observed_dtype = torch.quint8
        scheme = torch.per_tensor_affine

    observer_cls: type[Observer] = config.calibration_method.value
    return observer_cls(
        dtype=observed_dtype,
        qscheme=scheme,
        quant_min=config.quant_min,
        quant_max=config.quant_max,
    )
[docs]
class WeightFakeQuantize(nn.Module):
"""Fake-quantize a weight tensor during the forward pass.
Unlike :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`
(which requires a calibration pass), this module computes ``scale`` and
``zero_point`` on-the-fly from the weight tensor. Correct for QAT where
weights change each step.
:param consumers:
Set of modules that consume this module's output.
:param n_bits:
Number of quantization bits.
:param symmetric:
Symmetric (``zero_point = 0``) or asymmetric.
:param per_channel:
Use per-channel quantization along `channel_axis`.
:param channel_axis:
The axis along which per-channel scales are computed (default 0, i.e.
output channels).
"""
# Class-level annotations so static analyzers recognize these buffer
# attributes as declared in the class body rather than flagging the
# register_buffer() calls as "defined outside __init__".
scale: torch.Tensor | None
zero_point: torch.Tensor | None
def __init__(
self,
consumers: set[nn.Module],
n_bits: int = 8,
symmetric: bool = True,
per_channel: bool = False,
*,
channel_axis: int = 0,
) -> None:
super().__init__()
self.consumers: set[nn.Module] = consumers
self.config = TensorQuantConfig(
n_bits=n_bits,
symmetric=symmetric,
per_channel=per_channel,
)
self.channel_axis = channel_axis
self.enabled = False
self.frozen = False
# Populated by freeze(); None until then.
self.register_buffer("scale", None)
self.register_buffer("zero_point", None)
def _symmetric_quant_params(
self,
w_min: torch.Tensor,
w_max: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute symmetric scale and zero_point from observed min/max."""
max_abs = torch.max(w_min.abs(), w_max.abs())
max_abs = torch.clamp(max_abs, min=1e-8)
q_min, q_max = self.config.quant_min, self.config.quant_max
scale = max_abs / ((q_max - q_min) / 2)
zero_point = torch.zeros_like(scale, dtype=torch.int32)
return scale, zero_point
def _asymmetric_quant_params(
self,
w_min: torch.Tensor,
w_max: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute asymmetric scale and zero_point from observed min/max."""
zeros = torch.zeros_like(w_min)
w_min = torch.min(w_min, zeros)
w_max = torch.max(w_max, zeros)
w_max = torch.max(w_max, w_min + 1e-8)
q_min, q_max = self.config.quant_min, self.config.quant_max
scale = (w_max - w_min) / (q_max - q_min)
zero_point = (
torch.round(-w_min / scale).clamp(q_min, q_max).to(torch.int32)
)
return scale, zero_point
def _compute_quant_params(
self,
weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute scale and zero_point from `weight`."""
w = weight.detach()
if self.config.per_channel:
axis = self.channel_axis
n = w.shape[axis]
w = w.movedim(axis, 0).reshape(n, -1)
w_min = w.amin(dim=1)
w_max = w.amax(dim=1)
else:
w_min = w.min()
w_max = w.max()
if self.config.symmetric:
return self._symmetric_quant_params(w_min, w_max)
return self._asymmetric_quant_params(w_min, w_max)
[docs]
def freeze(self, weight: torch.Tensor) -> None:
"""Compute scale/zero_point from `weight` and store as constant buffers.
After calling this, :meth:`forward` uses the stored constants instead
of recomputing them each call. ONNX export will therefore emit a
``Constant`` node for the scale rather than the full
``Abs → ReduceMax → Div`` arithmetic, which TensorRT requires for
explicit-quantization fusion.
:param weight:
The weight tensor to calibrate against (typically ``conv.weight``).
"""
self.scale, self.zero_point = self._compute_quant_params(weight)
self.frozen = True
[docs]
def forward(self, weight: torch.Tensor) -> torch.Tensor:
"""Fake-quantize `weight` using stored (frozen) or on-the-fly scale."""
if not self.enabled:
return weight
if self.frozen:
scale, zero_point = self.scale, self.zero_point
assert scale is not None
assert zero_point is not None
else:
scale, zero_point = self._compute_quant_params(weight)
q_min, q_max = self.config.quant_min, self.config.quant_max
if self.config.per_channel:
return torch.fake_quantize_per_channel_affine(
weight, scale, zero_point, self.channel_axis, q_min, q_max
)
return torch.fake_quantize_per_tensor_affine(
weight, scale, zero_point, q_min, q_max
)
def __repr__(self) -> str: # pragma: no cover
sym = "sym" if self.config.symmetric else "asym"
ch = " per-ch" if self.config.per_channel else ""
frozen = " frozen" if self.frozen else ""
return f"WeightFakeQuantize({self.config.n_bits}bit {sym}{ch}{frozen})"
class QuantStub(nn.Module):
    """Quantize a floating-point tensor.

    During calibration the module delegates statistics collection to a
    ``torch.ao`` observer selected by *calibration_method*.  After calibration,
    ``scale`` and ``zero_point`` are derived from the observer and used by
    :func:`torch.fake_quantize_per_tensor_affine` in the forward pass.

    A new stub is disabled (``enabled = False``) and not calibrating; while
    disabled, :meth:`forward` returns its input unchanged.

    :param consumers:
        Set of modules that consume this stub's output.
    :param n_bits:
        Number of quantization bits (default 8).
    :param symmetric:
        Symmetric or asymmetric quantization.
    :param calibration_method:
        Algorithm used to collect activation statistics.  Defaults to
        :attr:`~embedl_deploy._internal.core.quantize.config.CalibrationMethod.MINMAX`.
    :param fixed_calibration:
        Fixed ``(scale, zero_point)`` tuple.  When provided, calibration
        will not override the values.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        calibration_method: CalibrationMethod = CalibrationMethod.MINMAX,
        *,
        fixed_calibration: tuple[float, int] | None = None,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            calibration_method=calibration_method,
        )
        # Same test as for ``fixed_parameters`` below, so the stored values
        # and the flag can never disagree.
        if fixed_calibration is not None:
            scale, zero_point = fixed_calibration
        else:
            # Identity-like placeholders until calibration provides values.
            scale, zero_point = 1.0, 0
        self.register_buffer("scale", torch.tensor(scale))
        self.register_buffer(
            "zero_point",
            torch.tensor(zero_point, dtype=torch.int32),
        )
        # Created lazily on the first calibrating forward pass.
        self._observer: Observer | None = None
        self.calibrating = False
        self.enabled = False
        self.fixed_parameters = fixed_calibration is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fake-quantize `x`, updating observer stats if calibrating.

        :param x:
            The activation tensor.
        :returns:
            `x` itself when the stub is disabled, otherwise the
            fake-quantized tensor.
        """
        if self.calibrating:
            if self._observer is None:
                self._observer = _make_observer(self.config)
            self._observer(x)
        if not self.enabled:
            return x
        return torch.fake_quantize_per_tensor_affine(
            x,
            self.scale,
            self.zero_point,
            self.config.quant_min,
            self.config.quant_max,
        )

    def compute_parameters(self) -> None:
        """Derive ``scale`` and ``zero_point`` from the observer.

        Does nothing when the stub was created with ``fixed_calibration``
        (``fixed_parameters`` is true): fixed values are never overridden.

        :raises RuntimeError:
            If the parameters are not fixed and no data was observed during
            calibration.
        """
        if self.fixed_parameters:
            # Honour the fixed (scale, zero_point) regardless of anything the
            # observer may have collected.
            return
        if self._observer is None:
            raise RuntimeError(
                "QuantStub did not observe valid data during calibration"
            )
        scale, zero_point = self._observer.calculate_qparams()
        # squeeze() so the observer's result fits the 0-dim buffers.
        self.scale.copy_(scale.squeeze())
        self.zero_point.copy_(zero_point.squeeze().to(torch.int32))

    def __repr__(self) -> str:  # pragma: no cover
        sym = "sym" if self.config.symmetric else "asym"
        method = self.config.calibration_method.name.lower()
        return (
            f"QuantStub({self.config.n_bits}bit {sym} {method}, "
            f"scale={self.scale.item():.6g}, zp={self.zero_point.item()})"
        )
@dataclass(eq=False)
class SmoothQuantObserver:
    """Per-LayerNorm smooth-quant state.

    Created in a disabled state by
    :class:`~embedl_deploy._internal.tensorrt.modules.linear.FusedLayerNorm`.
    Enabled and configured by
    :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    Downstream linears are populated by
    :class:`~embedl_deploy._internal.tensorrt.patterns.smoothings.PrepareSmoothQuantPattern`.
    Consumed by
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_smooth_quant`,
    which skips observers with empty ``downstream_linears``.
    """

    #: Set of modules that consume this observer's output.
    consumers: set[nn.Module]
    #: The ``LayerNorm`` whose affine parameters are divided by the
    #: smoothing scales.
    layer_norm: nn.LayerNorm
    #: Whether this observer is active.  Set to ``True`` by
    #: ``configure``; the calibration pass skips disabled observers.
    enabled: bool = False
    #: Configuration copied from ``QuantConfig.smooth_quant`` by
    #: ``configure``.
    config: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)
    #: Downstream linear layers whose weight columns are multiplied
    #: by the smoothing scales.
    downstream_linears: list[nn.Linear] = field(
        default_factory=list,
    )
    #: Per-channel activation maximum collected during calibration.
    #: Populated by the calibration forward loop, reset to ``None``
    #: after scales are applied.
    activation_max: torch.Tensor | None = None

    def register_forward_hook(
        self,
    ) -> torch.utils.hooks.RemovableHandle:
        """Register a forward hook that observes per-channel activation maxima.

        Hooks onto ``layer_norm`` and accumulates the element-wise
        maximum of the absolute output values along all dimensions
        except the last (channel) dimension.  Results are stored in
        ``activation_max``.

        :returns:
            A handle that can be used to remove the hook.
        """

        def hook(
            _mod: nn.Module,
            _input: tuple[torch.Tensor, ...],
            output: torch.Tensor,
        ) -> None:
            magnitude = output.detach().abs()
            leading_dims = tuple(range(magnitude.ndim - 1))
            if leading_dims:
                amax = magnitude.amax(dim=leading_dims)
            else:
                # A 1-D output has no leading axes.  amax must not be given
                # an empty ``dim`` here: PyTorch treats that as "reduce every
                # axis", which would collapse the channel axis too.
                amax = magnitude
            if self.activation_max is None:
                self.activation_max = amax.clone()
            else:
                self.activation_max = torch.max(self.activation_max, amax)

        return self.layer_norm.register_forward_hook(hook)

    def compute_smooth_scales(self) -> None:
        """Derive ``scales`` from observed per-channel activation maxima.

        Computes ``act_max ** alpha / weight_max ** (1 - alpha)`` per
        channel, where ``weight_max`` is the largest absolute weight of each
        input column across all downstream linears and ``alpha`` comes from
        ``config``.  Divides the ``layer_norm`` weight (and bias when
        present) by the scales and multiplies each downstream linear weight
        by the scales, then resets ``activation_max`` to ``None``.

        Returns without doing anything when ``downstream_linears`` is empty.

        :raises RuntimeError:
            If no data was observed during calibration.
        """
        if not self.downstream_linears:
            return
        if self.activation_max is None:
            raise RuntimeError(
                "SmoothQuantObserver did not observe valid data "
                "during calibration"
            )
        # Largest |weight| per input column, taken over every consumer.
        weight_maxes = [
            lin.weight.detach().abs().amax(dim=0)
            for lin in self.downstream_linears
        ]
        weight_max = torch.stack(weight_maxes).amax(dim=0)
        # Floor both statistics so the power/division below never sees zero.
        eps = 1e-5
        act_max = torch.clamp(self.activation_max, min=eps)
        weight_max = torch.clamp(weight_max, min=eps)
        alpha = self.config.alpha
        scales = act_max.pow(alpha) / weight_max.pow(1.0 - alpha)
        # Fold the scales into the parameters in place, outside autograd.
        with torch.no_grad():
            self.layer_norm.weight.div_(scales)
            if self.layer_norm.bias is not None:
                self.layer_norm.bias.div_(scales)
            for lin in self.downstream_linears:
                # Linear weights are (out, in); scale each input column.
                lin.weight.mul_(scales.unsqueeze(0))
        self.activation_max = None