# Source code for embedl_deploy._internal.core.config
# Copyright (C) 2026 Embedl AB
"""Quantization configuration.
A :class:`~embedl_deploy._internal.core.config.QuantConfig` bundles the
settings that control how activations and weights are quantized — bit-width,
symmetric vs. asymmetric range, and per-channel granularity — separately for
each tensor class.
The config is passed to
:func:`~embedl_deploy._internal.core.quantize.qdq.insert_qdq` which uses it to
parameterize the
:class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` (activations)
and :class:`~embedl_deploy._internal.core.quantize.modules.WeightFakeQuantize`
(weights) nodes it creates.
Example::

    config = QuantConfig(
        activation=TensorQuantConfig(n_bits=8, symmetric=True),
        weight=TensorQuantConfig(n_bits=8, symmetric=True, per_channel=True),
    )
    qdq_model = insert_qdq(fused_model, matches, config=config)
"""
import enum
from dataclasses import dataclass, field
from typing import TypeAlias
from torch import fx, nn
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
)
#: Union of ``torch.ao`` observer types used for activation calibration.
Observer: TypeAlias = (
MinMaxObserver | MovingAverageMinMaxObserver | HistogramObserver
)
class CalibrationMethod(enum.Enum):
    """Algorithm for collecting activation statistics during PTQ calibration.

    Each member's value is the ``torch.ao`` observer *class* (not an
    instance) that backs
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`
    during calibration.
    """

    #: Track the global minimum and maximum (default, fastest).
    MINMAX = MinMaxObserver
    #: Exponential moving average of min/max — less sensitive to outliers.
    MOVING_AVERAGE_MINMAX = MovingAverageMinMaxObserver
    #: Build a histogram and search for the optimal quantisation range.
    HISTOGRAM = HistogramObserver
@dataclass(frozen=True)
class TensorQuantConfig:
    """Quantization settings for a single tensor class (activation or weight).

    :param n_bits:
        Number of bits for the quantized representation. Must be between 2 and
        16 inclusive; validated in ``__post_init__``.
    :param symmetric:
        When ``True`` the quantized range is centred on zero (``zero_point =
        0``). When ``False`` an asymmetric range is used, allowing better
        coverage of distributions that are not centred on zero.
    :param per_channel:
        When ``True`` *and* when used for weight quantization, a separate
        scale/zero-point is computed along the output-channel axis (axis 0).
        Defaults to ``False`` for backward compatibility but should be set to
        ``True`` for production quantization of Conv/Linear weights.
    """

    #: Number of quantization bits (2–16).
    n_bits: int = 8
    #: Symmetric (zero_point = 0) or asymmetric range.
    symmetric: bool = True
    #: Per-channel quantization (weight only).
    per_channel: bool = False
    #: Calibration algorithm used by
    #: :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub` to
    #: collect activation statistics. Only relevant for activation configs;
    #: weight quantization computes scale/zero-point on-the-fly.
    calibration_method: CalibrationMethod = CalibrationMethod.MINMAX

    def __post_init__(self) -> None:
        # Reject bit-widths outside the supported 2–16 window up front so a
        # bad config fails at construction time, not during quantization.
        if self.n_bits < 2 or self.n_bits > 16:
            raise ValueError(
                f"n_bits must be between 2 and 16, got {self.n_bits}"
            )

    @property
    def _half_range(self) -> int:
        """Half the quantized integer range, i.e. ``2 ** (n_bits - 1)``."""
        return 1 << (self.n_bits - 1)

    @property
    def quant_min(self) -> int:
        """Minimum representable integer value."""
        # Symmetric ranges span [-2^(n-1), 2^(n-1) - 1]; asymmetric start at 0.
        return -self._half_range if self.symmetric else 0

    @property
    def quant_max(self) -> int:
        """Maximum representable integer value."""
        upper = self._half_range
        # Asymmetric ranges use the full unsigned span [0, 2^n - 1].
        return upper - 1 if self.symmetric else 2 * upper - 1

    def quant_range(self) -> tuple[bool, int, int]:
        """Return ``(symmetric, quant_min, quant_max)``."""
        return self.symmetric, self.quant_min, self.quant_max
@dataclass(frozen=True)
class SmoothQuantConfig:
    """SmoothQuant migration settings.

    Controls the per-channel weight/activation redistribution applied by
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant`.

    :param alpha:
        Migration strength in ``[0, 1]``. ``0`` keeps all difficulty on
        activations; ``1`` pushes it entirely to weights.
    """

    #: Migration strength in ``[0, 1]``.
    alpha: float = 0.5

    def __post_init__(self) -> None:
        # Validate at construction so an out-of-range strength never reaches
        # the calibration pass.
        if self.alpha < 0.0 or self.alpha > 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {self.alpha}")
@dataclass(frozen=True)
class ModulesToSkip:
    """Specifies which modules or module types to leave disabled during configure.

    Each set may hold module *classes* (skip every instance of that type) or
    specific module *instances* (skip just that module), as the field
    annotations allow both.
    """

    #: Modules to leave stub quantization disabled for.
    stub: set[type[nn.Module] | nn.Module] = field(default_factory=set)
    #: Modules to leave weight quantization disabled for.
    weight: set[type[nn.Module] | nn.Module] = field(default_factory=set)
    #: Modules to leave smooth quantization disabled for.
    smooth: set[type[nn.Module] | nn.Module] = field(default_factory=set)
@dataclass(frozen=True)
class QuantConfig:
    """Top-level quantization configuration.

    Bundles separate
    :class:`~embedl_deploy._internal.core.config.TensorQuantConfig` instances
    for activations and weights so that each can be configured independently.

    :param activation:
        Settings for activation (inter-layer) quantization.
    :param weight:
        Settings for weight quantization (Conv kernels only; bias is always
        left in floating-point).
    :param smooth_quant:
        SmoothQuant migration settings applied during calibration. See
        :class:`~embedl_deploy._internal.core.config.SmoothQuantConfig`.
    :param skip:
        Specifies which modules or module types to leave disabled during
        configure. See
        :class:`~embedl_deploy._internal.core.config.ModulesToSkip`.
    """

    #: Settings for activation quantization.
    activation: TensorQuantConfig = field(default_factory=TensorQuantConfig)
    #: Settings for weight quantization.
    weight: TensorQuantConfig = field(default_factory=TensorQuantConfig)
    #: SmoothQuant settings applied during calibration.
    smooth_quant: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)
    #: Modules or module types to leave disabled during configure.
    skip: ModulesToSkip = field(default_factory=ModulesToSkip)
@dataclass
class SmoothQuantObserver:
    """Per-LayerNorm smooth-quant state.

    Deliberately *not* frozen: ``enabled``, ``config`` and
    ``downstream_nodes`` are mutated in place across the pipeline stages
    listed below.

    Created in a disabled state by
    :class:`~embedl_deploy._internal.tensorrt.modules.linear.FusedLayerNorm`.
    Enabled and configured by
    :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    Downstream nodes are populated by
    :class:`~embedl_deploy._internal.tensorrt.patterns.smoothings.PrepareSmoothQuantPattern`.
    Consumed by
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant`,
    which skips observers with empty ``downstream_nodes``.
    """

    #: Set of modules that consume this observer's output.
    consumers: set[nn.Module]
    #: Whether this observer is active. Set to ``True`` by
    #: ``configure``; the calibration pass skips disabled observers.
    enabled: bool = False
    #: Configuration copied from ``QuantConfig.smooth_quant`` by
    #: ``configure``.
    config: SmoothQuantConfig = field(default_factory=SmoothQuantConfig)
    #: ``fx.Node`` references to downstream ``FusedLinear`` /
    #: ``FusedLinearAct`` nodes. Must be resolved through the graph
    #: module's ``_replaced_nodes`` mapping before use.
    downstream_nodes: list[fx.Node] = field(default_factory=list)