# Source code for embedl_deploy._internal.core.quantize.modules

# Copyright (C) 2026 Embedl AB

"""Quantization and fake-quantization ``nn.Module`` primitives.

These modules are backend-agnostic and can be used with any operator
(convolutions, linear layers, etc.).
"""

# mypy: disable-error-code="misc"
# torch lacks type stubs, so nn.Module resolves to Any.

import torch
from torch import nn

from embedl_deploy._internal.core.config import (
    CalibrationMethod,
    Observer,
    TensorQuantConfig,
)


def _make_observer(config: TensorQuantConfig) -> Observer:
    """Instantiate the ``torch.ao`` observer described by *config*.

    Symmetric configs map to a signed-int8 per-tensor-symmetric scheme;
    asymmetric configs map to unsigned-int8 per-tensor-affine.
    """
    if config.symmetric:
        qscheme, dtype = torch.per_tensor_symmetric, torch.qint8
    else:
        qscheme, dtype = torch.per_tensor_affine, torch.quint8
    # The CalibrationMethod enum value is the observer class itself.
    observer_cls: type[Observer] = config.calibration_method.value
    return observer_cls(
        dtype=dtype,
        qscheme=qscheme,
        quant_min=config.quant_min,
        quant_max=config.quant_max,
    )


class WeightFakeQuantize(nn.Module):
    """Fake-quantize a weight tensor during the forward pass.

    Unlike
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`
    (which requires a calibration pass), this module computes ``scale``
    and ``zero_point`` on-the-fly from the weight tensor. Correct for
    QAT where weights change each step.

    :param consumers: Set of modules that consume this module's output.
    :param n_bits: Number of quantization bits.
    :param symmetric: Symmetric (``zero_point = 0``) or asymmetric.
    :param per_channel: Use per-channel quantization along
        `channel_axis`.
    :param channel_axis: The axis along which per-channel scales are
        computed (default 0, i.e. output channels).
    """

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        per_channel: bool = False,
        *,
        channel_axis: int = 0,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            per_channel=per_channel,
        )
        self.channel_axis = channel_axis
        # Fake quantization is a no-op until explicitly enabled.
        self.enabled = False

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        """Fake-quantize `weight` using on-the-fly scale/zero_point."""
        if not self.enabled:
            # Pass-through while quantization is disabled.
            return weight
        if self.config.per_channel:
            return self._forward_per_channel(weight)
        return self._forward_per_tensor(weight)

    def _forward_per_tensor(self, weight: torch.Tensor) -> torch.Tensor:
        """Per-tensor fake quantization."""
        symmetric, quant_min, quant_max = self.config.quant_range()
        if symmetric:
            # Scale from the largest magnitude; detach() keeps the
            # range statistics out of the autograd graph.
            max_abs = weight.detach().abs().max()
            # Floor at 1e-8 so an all-zero weight cannot yield scale 0.
            max_abs = torch.clamp(max_abs, min=1e-8)
            scale = max_abs / ((quant_max - quant_min) / 2)
            zero_point = torch.zeros(
                (), dtype=torch.int32, device=weight.device
            )
        else:
            # Extend the observed range to include 0 so that zero is
            # exactly representable after quantization.
            w_min = torch.min(
                weight.detach().min(),
                torch.zeros(1, device=weight.device),
            )
            w_max = torch.max(
                weight.detach().max(),
                torch.zeros(1, device=weight.device),
            )
            # Guard against a degenerate (zero-width) range.
            w_max = torch.max(w_max, w_min + 1e-8)
            scale = (w_max - w_min) / (quant_max - quant_min)
            zero_point = (
                torch.round(-w_min / scale)
                .clamp(quant_min, quant_max)
                .to(torch.int32)
            )
        return torch.fake_quantize_per_tensor_affine(
            weight, scale, zero_point, quant_min, quant_max
        )

    def _forward_per_channel(self, weight: torch.Tensor) -> torch.Tensor:
        """Per-channel fake quantization along *self.channel_axis*."""
        symmetric, quant_min, quant_max = self.config.quant_range()
        axis = self.channel_axis
        n_channels = weight.shape[axis]
        # Collapse to (n_channels, -1) so each row holds one channel's
        # weights; range statistics are then per-row reductions.
        w = weight.detach().movedim(axis, 0).reshape(n_channels, -1)
        if symmetric:
            max_abs = w.abs().amax(dim=1)
            # Floor at 1e-8 so an all-zero channel cannot yield scale 0.
            max_abs = torch.clamp(max_abs, min=1e-8)
            scales = max_abs / ((quant_max - quant_min) / 2)
            zero_points = torch.zeros(
                n_channels, dtype=torch.int32, device=weight.device
            )
        else:
            # Per-channel ranges, extended to include 0 (see the
            # per-tensor branch for the rationale).
            w_min = torch.min(
                w.amin(dim=1),
                torch.zeros(n_channels, device=weight.device),
            )
            w_max = torch.max(
                w.amax(dim=1),
                torch.zeros(n_channels, device=weight.device),
            )
            # Guard against degenerate (zero-width) channel ranges.
            w_max = torch.max(w_max, w_min + 1e-8)
            scales = (w_max - w_min) / (quant_max - quant_min)
            zero_points = (
                torch.round(-w_min / scales)
                .clamp(quant_min, quant_max)
                .to(torch.int32)
            )
        return torch.fake_quantize_per_channel_affine(
            weight,
            scales,
            zero_points,
            axis,
            quant_min,
            quant_max,
        )

    def __repr__(self) -> str:
        sym = "sym" if self.config.symmetric else "asym"
        ch = " per-ch" if self.config.per_channel else ""
        return f"WeightFakeQuantize({self.config.n_bits}bit {sym}{ch})"
class QuantStub(nn.Module):
    """Quantize a floating-point tensor.

    During calibration the module delegates statistics collection to a
    ``torch.ao`` observer selected by *calibration_method*. After
    calibration, ``scale`` and ``zero_point`` are derived from the
    observer and used by :func:`torch.fake_quantize_per_tensor_affine`
    in the forward pass.

    :param consumers: Set of modules that consume this stub's output.
    :param n_bits: Number of quantization bits (default 8).
    :param symmetric: Symmetric or asymmetric quantization.
    :param calibration_method: Algorithm used to collect activation
        statistics. Defaults to
        :attr:`~embedl_deploy._internal.core.config.CalibrationMethod.MINMAX`.
    :param fixed_calibration: Fixed ``(scale, zero_point)`` tuple. When
        provided, calibration will not override the values.
    """

    # Registered as buffers in __init__ so they move with the module
    # (``.to()``/``.cuda()``) and appear in the state dict.
    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        calibration_method: CalibrationMethod = CalibrationMethod.MINMAX,
        *,
        fixed_calibration: tuple[float, int] | None = None,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            calibration_method=calibration_method,
        )
        if fixed_calibration:
            scale, zero_point = fixed_calibration
        else:
            # Identity placeholder until calibration runs.
            scale, zero_point = 1.0, 0
        self.register_buffer("scale", torch.tensor(scale))
        self.register_buffer(
            "zero_point",
            torch.tensor(zero_point, dtype=torch.int32),
        )
        # Observer is created lazily on the first calibration forward.
        self._observer: Observer | None = None
        self.calibrating = False
        self.enabled = False
        # When True, compute_parameters must not overwrite scale/zp.
        self.fixed_parameters = fixed_calibration is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fake-quantize `x`, updating observer stats if calibrating."""
        if self.calibrating:
            if self._observer is None:
                self._observer = _make_observer(self.config)
            # Observers update their running statistics when called.
            self._observer(x)
        if not self.enabled:
            return x
        return torch.fake_quantize_per_tensor_affine(
            x,
            self.scale,
            self.zero_point,
            self.config.quant_min,
            self.config.quant_max,
        )

    def compute_parameters(self) -> None:
        """Derive ``scale`` and ``zero_point`` from the observer.

        :raises RuntimeError: If no data was observed during
            calibration.
        """
        if self._observer is None:
            raise RuntimeError(
                "QuantStub did not observe valid data during calibration"
            )
        scale, zero_point = self._observer.calculate_qparams()
        # copy_ into the registered buffers preserves device/dtype;
        # squeeze() drops the observers' 1-element dimensions.
        self.scale.copy_(scale.squeeze())
        self.zero_point.copy_(zero_point.squeeze().to(torch.int32))

    def __repr__(self) -> str:
        sym = "sym" if self.config.symmetric else "asym"
        method = self.config.calibration_method.name.lower()
        return (
            f"QuantStub({self.config.n_bits}bit {sym} {method}, "
            f"scale={self.scale.item():.6g}, zp={self.zero_point.item()})"
        )