# Copyright (C) 2026 Embedl AB
"""Quantization and fake-quantization ``nn.Module`` primitives.
These modules are backend-agnostic and can be used with any operator
(convolutions, linear layers, etc.).
"""
# mypy: disable-error-code="misc"
# torch lacks type stubs, so nn.Module resolves to Any.
import torch
from torch import nn
from embedl_deploy._internal.core.config import (
CalibrationMethod,
Observer,
TensorQuantConfig,
)
def _make_observer(config: TensorQuantConfig) -> Observer:
    """Instantiate the ``torch.ao`` observer class selected by *config*.

    Symmetric configs map to a signed qint8 / symmetric qscheme; asymmetric
    configs map to unsigned quint8 / affine qscheme.
    """
    if config.symmetric:
        qscheme = torch.per_tensor_symmetric
        dtype = torch.qint8
    else:
        qscheme = torch.per_tensor_affine
        dtype = torch.quint8
    observer_cls: type[Observer] = config.calibration_method.value
    return observer_cls(
        dtype=dtype,
        qscheme=qscheme,
        quant_min=config.quant_min,
        quant_max=config.quant_max,
    )
class WeightFakeQuantize(nn.Module):
    """Apply fake quantization to a weight tensor on every forward call.

    In contrast to
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` (which
    needs a calibration pass), ``scale`` and ``zero_point`` here are derived
    from the current weight values on every call. That is the correct
    behaviour for QAT, where the weights move at each optimizer step.

    :param consumers:
        Modules that consume this module's output.
    :param n_bits:
        Bit width of the simulated quantization.
    :param symmetric:
        If ``True``, symmetric quantization (``zero_point = 0``); otherwise
        asymmetric.
    :param per_channel:
        Compute one scale per slice along `channel_axis` instead of a single
        per-tensor scale.
    :param channel_axis:
        Axis along which per-channel scales are computed (default 0, i.e.
        output channels).
    """

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        per_channel: bool = False,
        *,
        channel_axis: int = 0,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            per_channel=per_channel,
        )
        self.channel_axis = channel_axis
        # Pass-through until quantization is explicitly switched on.
        self.enabled = False

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        """Return *weight* fake-quantized with freshly computed qparams."""
        if not self.enabled:
            return weight
        handler = (
            self._forward_per_channel
            if self.config.per_channel
            else self._forward_per_tensor
        )
        return handler(weight)

    def _forward_per_tensor(self, weight: torch.Tensor) -> torch.Tensor:
        """Fake-quantize using a single scale/zero_point for the tensor."""
        symmetric, quant_min, quant_max = self.config.quant_range()
        detached = weight.detach()
        if symmetric:
            # Floor the range at 1e-8 to guard against all-zero weights.
            peak = torch.clamp(detached.abs().max(), min=1e-8)
            scale = peak / ((quant_max - quant_min) / 2)
            zero_point = torch.zeros(
                (), dtype=torch.int32, device=weight.device
            )
        else:
            zero = torch.zeros(1, device=weight.device)
            # Extend the range to bracket 0 so 0.0 is exactly representable.
            lo = torch.min(detached.min(), zero)
            hi = torch.max(detached.max(), zero)
            hi = torch.max(hi, lo + 1e-8)  # avoid a degenerate range
            scale = (hi - lo) / (quant_max - quant_min)
            zero_point = (
                torch.round(-lo / scale)
                .clamp(quant_min, quant_max)
                .to(torch.int32)
            )
        return torch.fake_quantize_per_tensor_affine(
            weight, scale, zero_point, quant_min, quant_max
        )

    def _forward_per_channel(self, weight: torch.Tensor) -> torch.Tensor:
        """Fake-quantize with one scale/zero_point per `self.channel_axis` slice."""
        symmetric, quant_min, quant_max = self.config.quant_range()
        axis = self.channel_axis
        n_channels = weight.shape[axis]
        # Flatten so each row holds all values belonging to one channel.
        rows = weight.detach().movedim(axis, 0).reshape(n_channels, -1)
        if symmetric:
            # Per-row range, floored at 1e-8 to guard all-zero channels.
            peaks = torch.clamp(rows.abs().amax(dim=1), min=1e-8)
            scales = peaks / ((quant_max - quant_min) / 2)
            zero_points = torch.zeros(
                n_channels, dtype=torch.int32, device=weight.device
            )
        else:
            zeros = torch.zeros(n_channels, device=weight.device)
            # Bracket 0 per channel, then avoid degenerate ranges.
            lows = torch.min(rows.amin(dim=1), zeros)
            highs = torch.max(rows.amax(dim=1), zeros)
            highs = torch.max(highs, lows + 1e-8)
            scales = (highs - lows) / (quant_max - quant_min)
            zero_points = (
                torch.round(-lows / scales)
                .clamp(quant_min, quant_max)
                .to(torch.int32)
            )
        return torch.fake_quantize_per_channel_affine(
            weight,
            scales,
            zero_points,
            axis,
            quant_min,
            quant_max,
        )

    def __repr__(self) -> str:
        mode = "sym" if self.config.symmetric else "asym"
        suffix = " per-ch" if self.config.per_channel else ""
        return f"WeightFakeQuantize({self.config.n_bits}bit {mode}{suffix})"
class QuantStub(nn.Module):
    """Quantize a floating-point tensor.

    During calibration the module delegates statistics collection to a
    ``torch.ao`` observer selected by *calibration_method*. After calibration,
    ``scale`` and ``zero_point`` are derived from the observer and used by
    :func:`torch.fake_quantize_per_tensor_affine` in the forward pass.

    :param consumers:
        Set of modules that consume this stub's output.
    :param n_bits:
        Number of quantization bits (default 8).
    :param symmetric:
        Symmetric or asymmetric quantization.
    :param calibration_method:
        Algorithm used to collect activation statistics. Defaults to
        :attr:`~embedl_deploy._internal.core.config.CalibrationMethod.MINMAX`.
    :param fixed_calibration:
        Fixed ``(scale, zero_point)`` tuple. When provided, calibration
        will not override the values.
    """

    # Registered as buffers in __init__ so they follow the module across
    # devices (``.to()``) and are included in the state dict.
    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        consumers: set[nn.Module],
        n_bits: int = 8,
        symmetric: bool = True,
        calibration_method: CalibrationMethod = CalibrationMethod.MINMAX,
        *,
        fixed_calibration: tuple[float, int] | None = None,
    ) -> None:
        super().__init__()
        self.consumers: set[nn.Module] = consumers
        self.config = TensorQuantConfig(
            n_bits=n_bits,
            symmetric=symmetric,
            calibration_method=calibration_method,
        )
        # Explicit ``is not None`` check: consistent with the
        # ``fixed_parameters`` assignment below (the previous truthiness
        # test would misread an empty tuple as "not provided").
        if fixed_calibration is not None:
            scale, zero_point = fixed_calibration
        else:
            scale, zero_point = 1.0, 0
        self.register_buffer("scale", torch.tensor(scale))
        self.register_buffer(
            "zero_point",
            torch.tensor(zero_point, dtype=torch.int32),
        )
        self._observer: Observer | None = None
        self.calibrating = False
        self.enabled = False
        self.fixed_parameters = fixed_calibration is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fake-quantize `x`, updating observer stats if calibrating."""
        if self.calibrating:
            # Lazily create the observer on the first calibration batch so
            # the module stays lightweight when calibration never runs.
            if self._observer is None:
                self._observer = _make_observer(self.config)
            self._observer(x)
        if not self.enabled:
            return x
        return torch.fake_quantize_per_tensor_affine(
            x,
            self.scale,
            self.zero_point,
            self.config.quant_min,
            self.config.quant_max,
        )

    def compute_parameters(self) -> None:
        """Derive ``scale`` and ``zero_point`` from the observer.

        No-op when the stub was constructed with *fixed_calibration*: the
        documented contract is that calibration never overrides fixed
        values (previously this method would clobber them — or raise if no
        observer had been created).

        :raises RuntimeError:
            If no data was observed during calibration.
        """
        if self.fixed_parameters:
            # Fixed (scale, zero_point) take precedence over calibration.
            return
        if self._observer is None:
            raise RuntimeError(
                "QuantStub did not observe valid data during calibration"
            )
        scale, zero_point = self._observer.calculate_qparams()
        # Observers return 1-element tensors; squeeze to match the 0-dim
        # buffers registered in __init__.
        self.scale.copy_(scale.squeeze())
        self.zero_point.copy_(zero_point.squeeze().to(torch.int32))

    def __repr__(self) -> str:
        sym = "sym" if self.config.symmetric else "asym"
        method = self.config.calibration_method.name.lower()
        return (
            f"QuantStub({self.config.n_bits}bit {sym} {method}, "
            f"scale={self.scale.item():.6g}, zp={self.zero_point.item()})"
        )