# Source code for embedl_deploy._internal.core.quantize.config

# Copyright (C) 2026 Embedl AB

"""Quantization configuration.

:class:`~embedl_deploy._internal.core.quantize.config.QuantConfig` bundles the
settings that control how activations and weights are quantized — bit-width,
symmetric vs. asymmetric range, and per-channel granularity — separately for
each tensor class.

The config is passed to ``configure()``, which copies the relevant sub-config
onto each activation stub and weight fake-quantize module found on the fused
modules.

Example::

    config = QuantConfig(
        activation=TensorQuantConfig(n_bits=8, symmetric=True),
        weight=TensorQuantConfig(n_bits=8, symmetric=True, per_channel=True),
    )
    configure(fused_model, config)
"""

import enum
import operator
from dataclasses import dataclass, field
from typing import TypeAlias

from torch import nn
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
)

#: Union of ``torch.ao`` observer types used for activation calibration.
Observer: TypeAlias = (
    MinMaxObserver | MovingAverageMinMaxObserver | HistogramObserver
)


class CalibrationMethod(enum.Enum):
    """Strategy used to gather activation statistics during PTQ calibration.

    The value of every member is a ``torch.ao`` observer *class*. An instance
    of that class is created and attached to a
    :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` on the
    first calibration forward pass.
    """

    #: Running global min/max; the default and the cheapest option.
    MINMAX = MinMaxObserver

    #: Exponentially averaged min/max, which damps the effect of outliers.
    MOVING_AVERAGE_MINMAX = MovingAverageMinMaxObserver

    #: Histogram of observed values, searched for the best quantization range.
    HISTOGRAM = HistogramObserver
@dataclass(frozen=True)
class TensorQuantConfig:
    """Quantization settings for a single tensor class (activation or weight).

    :param n_bits: Number of bits for the quantized representation. Must be
        an integer between 2 and 16 inclusive. Integer-like values that
        implement ``__index__`` (e.g. NumPy integers) are accepted and stored
        as a plain :class:`int`.
    :param symmetric: When ``True`` the quantized range is centered on zero
        (``zero_point = 0``). When ``False`` an asymmetric range is used,
        allowing better coverage of distributions that are not centered on
        zero.
    :param per_channel: When ``True`` *and* when used for weight
        quantization, a separate scale/zero-point is computed along the
        output-channel axis (axis 0). Defaults to ``False`` for backward
        compatibility but should be set to ``True`` for production
        quantization of Conv/Linear weights.
    :param calibration_method: Algorithm used by
        :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` to
        collect activation statistics. Only relevant for activation configs;
        weight quantization computes scale/zero-point on-the-fly.
    :raises TypeError: If ``n_bits`` is not an integer (e.g. ``8.5``).
    :raises ValueError: If ``n_bits`` is outside ``[2, 16]``.
    """

    #: Number of quantization bits (2–16).
    n_bits: int = 8
    #: Symmetric (zero_point = 0) or asymmetric range.
    symmetric: bool = True
    #: Per-channel quantization (weight only).
    per_channel: bool = False
    #: Calibration algorithm used by
    #: :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` to
    #: collect activation statistics. Only relevant for activation configs;
    #: weight quantization computes scale/zero-point on-the-fly.
    calibration_method: CalibrationMethod = CalibrationMethod.MINMAX

    def __post_init__(self) -> None:
        # Reject non-integral widths up front. A value such as ``8.5`` would
        # otherwise pass the range check, and ``quant_min``/``quant_max``
        # would then return non-integer floats.
        try:
            n_bits = operator.index(self.n_bits)
        except TypeError:
            raise TypeError(
                f"n_bits must be an integer, got {self.n_bits!r}"
            ) from None
        if not 2 <= n_bits <= 16:
            raise ValueError(
                f"n_bits must be between 2 and 16, got {self.n_bits}"
            )
        # Store a plain ``int`` (e.g. a NumPy integer becomes ``int``). The
        # dataclass is frozen, so bypass its ``__setattr__`` guard.
        object.__setattr__(self, "n_bits", n_bits)

    @property
    def _half_range(self) -> int:
        """Half the quantized integer range, i.e. ``2 ** (n_bits - 1)``."""
        half_range: int = 2 ** (self.n_bits - 1)
        return half_range

    @property
    def quant_min(self) -> int:
        """Minimum representable integer value.

        ``-(2 ** (n_bits - 1))`` when symmetric (``-128`` for 8 bits),
        otherwise ``0``.
        """
        if self.symmetric:
            return -self._half_range
        return 0

    @property
    def quant_max(self) -> int:
        """Maximum representable integer value.

        ``2 ** (n_bits - 1) - 1`` when symmetric (``127`` for 8 bits),
        otherwise ``2 ** n_bits - 1`` (``255`` for 8 bits).
        """
        if self.symmetric:
            return self._half_range - 1
        return 2 * self._half_range - 1
@dataclass(frozen=True)
class SmoothQuantConfig:
    """Settings for SmoothQuant difficulty migration.

    Governs how the per-channel rescaling performed by
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_smooth_quant`
    splits quantization difficulty between weights and activations.

    :param alpha: Migration strength, restricted to the closed interval
        ``[0, 1]``. At ``0`` the activations keep all of the difficulty; at
        ``1`` all of it is moved onto the weights.
    :raises ValueError: If ``alpha`` lies outside ``[0, 1]``. NaN is rejected
        too, because every comparison involving NaN is false.
    """

    #: Migration strength; must lie within ``[0, 1]``.
    alpha: float = 0.5

    def __post_init__(self) -> None:
        strength = self.alpha
        # The chained comparison is also False for NaN, so NaN is refused.
        within_bounds = 0.0 <= strength <= 1.0
        if within_bounds:
            return
        raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
[docs] @dataclass(frozen=True) class ModulesToSkip: """Specifies which modules or module types to leave disabled during configure.""" #: Modules to leave stub quantization disabled for. stub: set[type[nn.Module] | nn.Module] = field(default_factory=set) #: Modules to leave weight quantization disabled for. weight: set[type[nn.Module] | nn.Module] = field(default_factory=set) #: Modules to leave smooth quantization disabled for. smooth: set[type[nn.Module] | nn.Module] = field(default_factory=set)
@dataclass(frozen=True)
class QuantConfig:
    """Top-level quantization configuration.

    Bundles separate
    :class:`~embedl_deploy._internal.core.quantize.config.TensorQuantConfig`
    instances for activations and weights, so that each can be configured
    independently. It also holds the SmoothQuant settings and the set of
    modules to skip. Every field defaults to a freshly constructed instance
    of its config class.

    :param activation: Settings for activation (inter-layer) quantization.
    :param weight: Settings for weight quantization (Conv kernels only; bias
        is always left in floating-point).
    :param smooth_quant: SmoothQuant settings applied during calibration. See
        :class:`~embedl_deploy._internal.core.quantize.config.SmoothQuantConfig`.
    :param skip: Specifies which modules or module types to leave disabled
        during configure. See
        :class:`~embedl_deploy._internal.core.quantize.config.ModulesToSkip`.
    """

    # NOTE: although frozen, instances are not hashable in practice. The
    # generated ``__hash__`` hashes every field, and ``skip`` holds ``set``
    # values, so ``hash(QuantConfig())`` raises ``TypeError``.

    #: Settings for activation quantization.
    activation: TensorQuantConfig = field(default_factory=TensorQuantConfig)
    #: Settings for weight quantization.
    weight: TensorQuantConfig = field(default_factory=TensorQuantConfig)
    #: SmoothQuant settings applied during calibration.
    smooth_quant: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)
    #: Modules or module types to leave disabled during configure.
    skip: ModulesToSkip = field(default_factory=ModulesToSkip)