# Source code for embedl_deploy._internal.core.config

# Copyright (C) 2026 Embedl AB

"""Quantization configuration.

A :class:`~embedl_deploy._internal.core.config.QuantConfig` bundles the
settings that control how activations and weights are quantized — bit-width,
symmetric vs. asymmetric range, and per-channel granularity — separately for
each tensor class.

The config is passed to
:func:`~embedl_deploy._internal.core.quantize.qdq.insert_qdq` which uses it to
parameterize the
:class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` (activations)
and :class:`~embedl_deploy._internal.core.quantize.modules.WeightFakeQuantize`
(weights) nodes it creates.

Example::

    config = QuantConfig(
        activation=TensorQuantConfig(n_bits=8, symmetric=True),
        weight=TensorQuantConfig(n_bits=8, symmetric=True, per_channel=True),
    )
    qdq_model = insert_qdq(fused_model, matches, config=config)
"""

import enum
from dataclasses import dataclass, field
from typing import TypeAlias

from torch import fx, nn
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    MovingAverageMinMaxObserver,
)

#: Union of ``torch.ao`` observer types used for activation calibration.
Observer: TypeAlias = (
    MinMaxObserver | MovingAverageMinMaxObserver | HistogramObserver
)


class CalibrationMethod(enum.Enum):
    """Strategy for gathering activation statistics during PTQ calibration.

    The value of each member is the ``torch.ao`` observer class that backs
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` while
    calibration runs.
    """

    #: Track the running global minimum and maximum (default, fastest).
    MINMAX = MinMaxObserver

    #: Exponential moving average of min/max; dampens the effect of outliers.
    MOVING_AVERAGE_MINMAX = MovingAverageMinMaxObserver

    #: Accumulate a histogram and search it for the optimal quantisation range.
    HISTOGRAM = HistogramObserver
@dataclass(frozen=True)
class TensorQuantConfig:
    """Quantization settings for a single tensor class (activation or weight).

    :param n_bits: Width of the quantized representation in bits; must lie
        in the inclusive range 2–16.
    :param symmetric: ``True`` centres the quantized range on zero
        (``zero_point = 0``); ``False`` selects an asymmetric range, which
        covers distributions that are not centred on zero more tightly.
    :param per_channel: When ``True`` *and* the config is used for weight
        quantization, a separate scale/zero-point is computed along the
        output-channel axis (axis 0). Defaults to ``False`` for backward
        compatibility but should be set to ``True`` for production
        quantization of Conv/Linear weights.
    """

    #: Number of quantization bits (2–16).
    n_bits: int = 8

    #: Symmetric (zero_point = 0) or asymmetric range.
    symmetric: bool = True

    #: Per-channel quantization (weight only).
    per_channel: bool = False

    #: Calibration algorithm used by
    #: :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` to
    #: collect activation statistics. Only relevant for activation configs;
    #: weight quantization computes scale/zero-point on-the-fly.
    calibration_method: CalibrationMethod = CalibrationMethod.MINMAX

    def __post_init__(self) -> None:
        # Reject bit-widths outside the supported 2–16 window up front.
        if self.n_bits < 2 or self.n_bits > 16:
            raise ValueError(
                f"n_bits must be between 2 and 16, got {self.n_bits}"
            )

    @property
    def _half_range(self) -> int:
        """Half the quantized integer range."""
        return 2 ** (self.n_bits - 1)

    @property
    def quant_min(self) -> int:
        """Minimum representable integer value."""
        return -self._half_range if self.symmetric else 0

    @property
    def quant_max(self) -> int:
        """Maximum representable integer value."""
        half = self._half_range
        return half - 1 if self.symmetric else 2 * half - 1

    def quant_range(
        self,
    ) -> tuple[bool, int, int]:
        """Return ``(symmetric, quant_min, quant_max)``."""
        return (self.symmetric, self.quant_min, self.quant_max)
@dataclass(frozen=True)
class SmoothQuantConfig:
    """SmoothQuant migration settings.

    Controls the per-channel weight/activation redistribution applied by
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant`.

    :param alpha: Migration strength in ``[0, 1]``. ``0`` keeps all
        difficulty on activations; ``1`` pushes it entirely to weights.
    """

    #: Migration strength in ``[0, 1]``.
    alpha: float = 0.5

    def __post_init__(self) -> None:
        # Anything outside the closed unit interval is a configuration error.
        if self.alpha < 0.0 or self.alpha > 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
[docs] @dataclass(frozen=True) class ModulesToSkip: """Specifies which modules or module types to leave disabled during configure.""" #: Modules to leave stub quantization disabled for. stub: set[type[nn.Module] | nn.Module] = field(default_factory=set) #: Modules to leave weight quantization disabled for. weight: set[type[nn.Module] | nn.Module] = field(default_factory=set) #: Modules to leave smooth quantization disabled for. smooth: set[type[nn.Module] | nn.Module] = field(default_factory=set)
@dataclass(frozen=True)
class QuantConfig:
    """Top-level quantization configuration.

    Groups independent
    :class:`~embedl_deploy._internal.core.config.TensorQuantConfig`
    instances for activations and weights so each tensor class can be tuned
    on its own, alongside SmoothQuant and skip settings.

    :param activation: Settings for activation (inter-layer) quantization.
    :param weight: Settings for weight quantization (Conv kernels only;
        bias is always left in floating-point).
    :param skip: Specifies which modules or module types to leave disabled
        during configure. See
        :class:`~embedl_deploy._internal.core.config.ModulesToSkip`.
    """

    #: Settings for activation quantization.
    activation: TensorQuantConfig = field(default_factory=TensorQuantConfig)

    #: Settings for weight quantization.
    weight: TensorQuantConfig = field(default_factory=TensorQuantConfig)

    #: SmoothQuant settings applied during calibration.
    smooth_quant: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)

    #: Modules or module types to leave disabled during configure.
    skip: ModulesToSkip = field(default_factory=ModulesToSkip)
@dataclass
class SmoothQuantObserver:
    """Per-LayerNorm smooth-quant state.

    Lifecycle: created in a disabled state by
    :class:`~embedl_deploy._internal.tensorrt.modules.linear.FusedLayerNorm`;
    enabled and configured by
    :func:`~embedl_deploy._internal.core.quantize.main.configure`;
    ``downstream_nodes`` populated by
    :class:`~embedl_deploy._internal.tensorrt.patterns.smoothings.PrepareSmoothQuantPattern`;
    finally consumed by
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant`,
    which skips observers with empty ``downstream_nodes``.
    """

    #: Modules that consume this observer's output.
    consumers: set[nn.Module]

    #: Whether this observer is active. ``configure`` flips this to
    #: ``True``; the calibration pass skips disabled observers.
    enabled: bool = False

    #: Configuration copied from ``QuantConfig.smooth_quant`` by
    #: ``configure``.
    config: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)

    #: ``fx.Node`` references to downstream ``FusedLinear`` /
    #: ``FusedLinearAct`` nodes. Resolve them through the graph module's
    #: ``_replaced_nodes`` mapping before use.
    downstream_nodes: list[fx.Node] = field(default_factory=list)