# Source code for embedl_deploy._internal.core.quantize.stubs

# Copyright (C) 2026 Embedl AB

"""Quantization and fake-quantization ``nn.Module`` primitives.

These modules are backend-agnostic and can be used with any operator
(convolutions, linear layers, etc.).
"""

# mypy: disable-error-code="misc"
# torch lacks type stubs, so nn.Module resolves to Any.

from dataclasses import dataclass, field

import torch
from torch import nn

from embedl_deploy._internal.core.quantize.config import (
    CalibrationMethod,
    Observer,
    SmoothQuantConfig,
    TensorQuantConfig,
)


def _make_observer(config: TensorQuantConfig) -> Observer:
    """Instantiate the ``torch.ao`` observer described by *config*.

    The observer class is ``config.calibration_method.value``. Symmetric
    configs get a signed ``torch.qint8`` observer with the
    ``per_tensor_symmetric`` scheme; asymmetric configs get an unsigned
    ``torch.quint8`` observer with the ``per_tensor_affine`` scheme. The
    integer range is forwarded from ``config.quant_min`` /
    ``config.quant_max``.
    """
    if config.symmetric:
        dtype, qscheme = torch.qint8, torch.per_tensor_symmetric
    else:
        dtype, qscheme = torch.quint8, torch.per_tensor_affine
    observer_cls: type[Observer] = config.calibration_method.value
    return observer_cls(
        dtype=dtype,
        qscheme=qscheme,
        quant_min=config.quant_min,
        quant_max=config.quant_max,
    )


[docs] class WeightFakeQuantize(nn.Module): """Fake-quantize a weight tensor during the forward pass. Unlike :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` (which requires a calibration pass), this module computes ``scale`` and ``zero_point`` on-the-fly from the weight tensor. Correct for QAT where weights change each step. :param consumers: Set of modules that consume this module's output. :param n_bits: Number of quantization bits. :param symmetric: Symmetric (``zero_point = 0``) or asymmetric. :param per_channel: Use per-channel quantization along `channel_axis`. :param channel_axis: The axis along which per-channel scales are computed (default 0, i.e. output channels). """ # Class-level annotations so static analyzers recognize these buffer # attributes as declared in the class body rather than flagging the # register_buffer() calls as "defined outside __init__". scale: torch.Tensor | None zero_point: torch.Tensor | None def __init__( self, consumers: set[nn.Module], n_bits: int = 8, symmetric: bool = True, per_channel: bool = False, *, channel_axis: int = 0, ) -> None: super().__init__() self.consumers: set[nn.Module] = consumers self.config = TensorQuantConfig( n_bits=n_bits, symmetric=symmetric, per_channel=per_channel, ) self.channel_axis = channel_axis self.enabled = False self.frozen = False # Populated by freeze(); None until then. 
self.register_buffer("scale", None) self.register_buffer("zero_point", None) def _symmetric_quant_params( self, w_min: torch.Tensor, w_max: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute symmetric scale and zero_point from observed min/max.""" max_abs = torch.max(w_min.abs(), w_max.abs()) max_abs = torch.clamp(max_abs, min=1e-8) q_min, q_max = self.config.quant_min, self.config.quant_max scale = max_abs / ((q_max - q_min) / 2) zero_point = torch.zeros_like(scale, dtype=torch.int32) return scale, zero_point def _asymmetric_quant_params( self, w_min: torch.Tensor, w_max: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute asymmetric scale and zero_point from observed min/max.""" zeros = torch.zeros_like(w_min) w_min = torch.min(w_min, zeros) w_max = torch.max(w_max, zeros) w_max = torch.max(w_max, w_min + 1e-8) q_min, q_max = self.config.quant_min, self.config.quant_max scale = (w_max - w_min) / (q_max - q_min) zero_point = ( torch.round(-w_min / scale).clamp(q_min, q_max).to(torch.int32) ) return scale, zero_point def _compute_quant_params( self, weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute scale and zero_point from `weight`.""" w = weight.detach() if self.config.per_channel: axis = self.channel_axis n = w.shape[axis] w = w.movedim(axis, 0).reshape(n, -1) w_min = w.amin(dim=1) w_max = w.amax(dim=1) else: w_min = w.min() w_max = w.max() if self.config.symmetric: return self._symmetric_quant_params(w_min, w_max) return self._asymmetric_quant_params(w_min, w_max)
[docs] def freeze(self, weight: torch.Tensor) -> None: """Compute scale/zero_point from `weight` and store as constant buffers. After calling this, :meth:`forward` uses the stored constants instead of recomputing them each call. ONNX export will therefore emit a ``Constant`` node for the scale rather than the full ``Abs → ReduceMax → Div`` arithmetic, which TensorRT requires for explicit-quantization fusion. :param weight: The weight tensor to calibrate against (typically ``conv.weight``). """ self.scale, self.zero_point = self._compute_quant_params(weight) self.frozen = True
[docs] def forward(self, weight: torch.Tensor) -> torch.Tensor: """Fake-quantize `weight` using stored (frozen) or on-the-fly scale.""" if not self.enabled: return weight if self.frozen: scale, zero_point = self.scale, self.zero_point assert scale is not None assert zero_point is not None else: scale, zero_point = self._compute_quant_params(weight) q_min, q_max = self.config.quant_min, self.config.quant_max if self.config.per_channel: return torch.fake_quantize_per_channel_affine( weight, scale, zero_point, self.channel_axis, q_min, q_max ) return torch.fake_quantize_per_tensor_affine( weight, scale, zero_point, q_min, q_max )
def __repr__(self) -> str: # pragma: no cover sym = "sym" if self.config.symmetric else "asym" ch = " per-ch" if self.config.per_channel else "" frozen = " frozen" if self.frozen else "" return f"WeightFakeQuantize({self.config.n_bits}bit {sym}{ch}{frozen})"
class QuantStub(nn.Module):
    """Quantize a floating-point tensor.

    During calibration the module delegates statistics collection to a
    ``torch.ao`` observer selected by *calibration_method*. After
    calibration, ``scale`` and ``zero_point`` are derived from the observer
    and used by :func:`torch.fake_quantize_per_tensor_affine` in the forward
    pass.

    Statistics are only collected while ``calibrating`` is ``True``, and the
    input is returned unchanged while ``enabled`` is ``False``.

    :param consumers: Set of modules that consume this stub's output.
    :param n_bits: Number of quantization bits (default 8).
    :param symmetric: Symmetric or asymmetric quantization.
    :param calibration_method: Algorithm used to collect activation
        statistics. Defaults to
        :attr:`~embedl_deploy._internal.core.quantize.config.CalibrationMethod.MINMAX`.
    :param fixed_calibration: Fixed ``(scale, zero_point)`` tuple. When
        provided, calibration will not override the values. This is
        recorded in ``fixed_parameters``; :meth:`compute_parameters` itself
        does not check it.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        calibration_method: CalibrationMethod = CalibrationMethod.MINMAX,
        *,
        fixed_calibration: tuple[float, int] | None = None,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            calibration_method=calibration_method,
        )
        if fixed_calibration is not None:
            scale, zero_point = fixed_calibration
        else:
            scale, zero_point = 1.0, 0
        # Pin the dtypes the fake-quant kernels expect (float32 scale,
        # int32 zero_point). This keeps an int scale or a changed torch
        # default dtype from breaking forward().
        self.register_buffer(
            "scale", torch.tensor(scale, dtype=torch.float32)
        )
        self.register_buffer(
            "zero_point",
            torch.tensor(zero_point, dtype=torch.int32),
        )
        # Created lazily on the first calibrating forward pass.
        self._observer: Observer | None = None
        self.calibrating = False
        self.enabled = False
        self.fixed_parameters = fixed_calibration is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fake-quantize `x`, updating observer stats if calibrating."""
        if self.calibrating:
            if self._observer is None:
                # Place the observer on the input's device so any state it
                # keeps lives on the same device as the tensors it observes.
                self._observer = _make_observer(self.config).to(x.device)
            self._observer(x)
        if not self.enabled:
            return x
        return torch.fake_quantize_per_tensor_affine(
            x,
            self.scale,
            self.zero_point,
            self.config.quant_min,
            self.config.quant_max,
        )

    def compute_parameters(self) -> None:
        """Derive ``scale`` and ``zero_point`` from the observer.

        The observer's ``calculate_qparams()`` results are copied in place
        into the existing buffers.

        :raises RuntimeError: If no data was observed during calibration.
        """
        if self._observer is None:
            raise RuntimeError(
                "QuantStub did not observe valid data during calibration"
            )
        scale, zero_point = self._observer.calculate_qparams()
        self.scale.copy_(scale.squeeze())
        self.zero_point.copy_(zero_point.squeeze().to(torch.int32))

    def __repr__(self) -> str:  # pragma: no cover
        sym = "sym" if self.config.symmetric else "asym"
        method = self.config.calibration_method.name.lower()
        return (
            f"QuantStub({self.config.n_bits}bit {sym} {method}, "
            f"scale={self.scale.item():.6g}, zp={self.zero_point.item()})"
        )


@dataclass(eq=False)
class SmoothQuantObserver:
    """Smooth-quant bookkeeping attached to a single LayerNorm.

    Lifecycle:

    * constructed disabled by
      :class:`~embedl_deploy._internal.tensorrt.modules.linear.FusedLayerNorm`;
    * switched on and configured by
      :func:`~embedl_deploy._internal.core.quantize.main.configure`;
    * given its downstream linears by
      :class:`~embedl_deploy._internal.tensorrt.patterns.smoothings.PrepareSmoothQuantPattern`;
    * consumed by
      :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_smooth_quant`,
      which ignores observers whose ``downstream_linears`` is empty.
    """

    #: Modules that consume this observer's output.
    consumers: set[nn.Module]
    #: ``LayerNorm`` whose affine weight (and bias, when present) is
    #: divided by the smoothing scales.
    layer_norm: nn.LayerNorm
    #: Activation flag. ``configure`` flips it to ``True``; the calibration
    #: pass skips observers that stay disabled.
    enabled: bool = False
    #: Settings copied from ``QuantConfig.smooth_quant`` by ``configure``.
    config: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)
    #: Linear layers whose weight columns are multiplied by the smoothing
    #: scales.
    downstream_linears: list[nn.Linear] = field(default_factory=list)
    #: Running per-channel maximum of ``|layer_norm output|``. Filled by the
    #: forward hook; cleared to ``None`` once the scales are applied.
    activation_max: torch.Tensor | None = None

    def register_forward_hook(
        self,
    ) -> torch.utils.hooks.RemovableHandle:
        """Attach an activation-statistics hook to ``layer_norm``.

        Each time ``layer_norm`` runs, the hook takes the absolute value of
        its output and reduces it with ``amax`` over every dimension except
        the last (channel) one. The result is folded into
        ``activation_max`` with an element-wise maximum.

        :returns: The handle returned by ``layer_norm.register_forward_hook``;
            call ``.remove()`` on it to detach the hook.
        """

        def _track_channel_peak(
            _module: nn.Module,
            _inputs: tuple[torch.Tensor, ...],
            output: torch.Tensor,
        ) -> None:
            leading_dims = tuple(range(output.ndim - 1))
            peak = output.detach().abs().amax(dim=leading_dims)
            previous = self.activation_max
            if previous is not None:
                self.activation_max = torch.max(previous, peak)
            else:
                self.activation_max = peak.clone()

        return self.layer_norm.register_forward_hook(_track_channel_peak)

    def compute_smooth_scales(self) -> None:
        """Fold smoothing scales into ``layer_norm`` and the downstream linears.

        The per-channel scales are computed as
        ``act_max ** alpha / weight_max ** (1 - alpha)``. ``act_max`` is the
        observed activation maximum and ``weight_max`` is the column-wise
        ``|W|`` maximum across all downstream linears; both are clamped to
        at least ``1e-5``. The LayerNorm weight (and bias, if any) is divided
        by the scales and every downstream linear weight is multiplied by
        them column-wise. ``activation_max`` is then reset to ``None``.

        Does nothing when there are no downstream linears.

        :raises RuntimeError: If no data was observed during calibration.
        """
        if not self.downstream_linears:
            return
        observed = self.activation_max
        if observed is None:
            raise RuntimeError(
                "SmoothQuantObserver did not observe valid data "
                "during calibration"
            )

        per_linear = torch.stack(
            [
                linear.weight.detach().abs().amax(dim=0)
                for linear in self.downstream_linears
            ]
        )
        floor = torch.tensor(1e-5, device=observed.device)
        a_peak = torch.clamp(observed, min=floor)
        w_peak = torch.clamp(per_linear.amax(dim=0), min=floor)

        alpha = self.config.alpha
        smoothing = a_peak.pow(alpha) / w_peak.pow(1.0 - alpha)

        with torch.no_grad():
            norm = self.layer_norm
            norm.weight.div_(smoothing)
            if norm.bias is not None:
                norm.bias.div_(smoothing)
            column_scale = smoothing.unsqueeze(0)
            for linear in self.downstream_linears:
                linear.weight.mul_(column_scale)

        self.activation_max = None