embedl_deploy.quantize package#
Module contents:
Public quantization API.
Users import from here:
from embedl_deploy.quantize import insert_qdq, optimize_qdq, calibrate
- class embedl_deploy.quantize.CalibrationMethod(value)[source]#
Bases: Enum
Algorithm for collecting activation statistics during PTQ calibration.
Each member’s value is the torch.ao observer class instantiated and attached to a QuantStub on the first calibration forward pass.
- HISTOGRAM = <class 'torch.ao.quantization.observer.HistogramObserver'>#
Build a histogram and search for the optimal quantization range.
- MINMAX = <class 'torch.ao.quantization.observer.MinMaxObserver'>#
Track the global minimum and maximum (default, fastest).
- MOVING_AVERAGE_MINMAX = <class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>#
Exponential moving average of min/max, which is less sensitive to outliers.
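To make the difference between MINMAX and MOVING_AVERAGE_MINMAX concrete, the following pure-Python sketch tracks a range over a few batches in both styles. The class names and the averaging constant `c = 0.1` are illustrative assumptions; this is not the code of the torch.ao observers themselves.

```python
class GlobalMinMax:
    """Tracks the global min/max over every batch seen (MINMAX-style)."""

    def __init__(self):
        self.lo = None
        self.hi = None

    def observe(self, batch):
        b_lo, b_hi = min(batch), max(batch)
        self.lo = b_lo if self.lo is None else min(self.lo, b_lo)
        self.hi = b_hi if self.hi is None else max(self.hi, b_hi)


class MovingAverageMinMax:
    """Exponential moving average of per-batch min/max.

    `c` is a hypothetical averaging constant chosen for illustration.
    """

    def __init__(self, c=0.1):
        self.c = c
        self.lo = None
        self.hi = None

    def observe(self, batch):
        b_lo, b_hi = min(batch), max(batch)
        if self.lo is None:
            # The first batch initializes the range directly.
            self.lo, self.hi = b_lo, b_hi
        else:
            self.lo += self.c * (b_lo - self.lo)
            self.hi += self.c * (b_hi - self.hi)


# The last batch contains a single large outlier.
batches = [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 50.0]]
global_obs, ema_obs = GlobalMinMax(), MovingAverageMinMax()
for batch in batches:
    global_obs.observe(batch)
    ema_obs.observe(batch)
```

In this toy run the global tracker's upper bound jumps to the outlier, while the moving average only moves a fraction of the way toward it.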
- class embedl_deploy.quantize.ModulesToSkip(stub: set[type[Module] | Module] = <factory>, weight: set[type[Module] | Module] = <factory>, smooth: set[type[Module] | Module] = <factory>)[source]#
Bases: object
Specifies which modules or module types to leave disabled during configure.
- smooth: set[type[Module] | Module]#
Modules to leave smooth quantization disabled for.
- stub: set[type[Module] | Module]#
Modules to leave stub quantization disabled for.
- weight: set[type[Module] | Module]#
Modules to leave weight quantization disabled for.
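Because each set accepts both module types and module instances, a skip entry can target a whole class of layers or one specific layer. The sketch below shows one plausible way such matching could work. `is_skipped` and the stand-in classes are hypothetical and are not part of embedl_deploy; the library's actual matching rule may differ.

```python
class Module:
    """Stand-in for torch.nn.Module so the sketch runs without torch."""


class Conv(Module):
    pass


class Linear(Module):
    pass


def is_skipped(module, skip):
    """Return True if `module` is listed by instance or by type.

    Hypothetical helper, written only to illustrate the idea.
    """
    for entry in skip:
        if isinstance(entry, type):
            # A type entry matches every instance of that type.
            if isinstance(module, entry):
                return True
        elif entry is module:
            # An instance entry matches only that exact object.
            return True
    return False


first_conv, second_conv, head = Conv(), Conv(), Linear()
skip_weight = {Linear, first_conv}  # one type and one specific instance

results = [is_skipped(m, skip_weight) for m in (first_conv, second_conv, head)]
```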
- class embedl_deploy.quantize.QuantConfig(activation: TensorQuantConfig = <factory>, weight: TensorQuantConfig = <factory>, smooth_quant: SmoothQuantConfig = <factory>, skip: ModulesToSkip = <factory>)[source]#
Bases: object
Top-level quantization configuration.
Bundles separate TensorQuantConfig instances for activations and weights so that each can be configured independently.
- Parameters:
activation – Settings for activation (inter-layer) quantization.
weight – Settings for weight quantization (Conv kernels only; bias is always left in floating-point).
smooth_quant – SmoothQuant settings applied during calibration. See SmoothQuantConfig.
skip – Specifies which modules or module types to leave disabled during configure. See ModulesToSkip.
- activation: TensorQuantConfig#
Settings for activation quantization.
- skip: ModulesToSkip#
Modules or module types to leave disabled during configure.
- smooth_quant: SmoothQuantConfig#
SmoothQuant settings applied during calibration.
- weight: TensorQuantConfig#
Settings for weight quantization.
- class embedl_deploy.quantize.QuantStub(consumers: set[Module], n_bits: int = 8, symmetric: bool = True, calibration_method: CalibrationMethod = CalibrationMethod.MINMAX, *, fixed_calibration: tuple[float, int] | None = None)[source]#
Bases: Module
Quantize a floating-point tensor.
During calibration the module delegates statistics collection to a torch.ao observer selected by calibration_method. After calibration, scale and zero_point are derived from the observer and used by torch.fake_quantize_per_tensor_affine() in the forward pass.
- Parameters:
consumers – Set of modules that consume this stub’s output.
n_bits – Number of quantization bits (default 8).
symmetric – Symmetric or asymmetric quantization.
calibration_method – Algorithm used to collect activation statistics. Defaults to MINMAX.
fixed_calibration – Fixed (scale, zero_point) tuple. When provided, calibration will not override the values.
- compute_parameters() None[source]#
Derive scale and zero_point from the observer.
- Raises:
RuntimeError – If no data was observed during calibration.
- scale: Tensor#
- zero_point: Tensor#
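As a rough illustration of what deriving scale and zero_point from an observed range and then fake-quantizing involves, here is a scalar, pure-Python sketch. It follows a common PTQ convention (the range is widened to include zero, then mapped onto the integer grid) and is an assumption about the general technique, not the library's or the observers' exact arithmetic. The function names and the `eps` guard are illustrative.

```python
def affine_params(min_val, max_val, quant_min, quant_max, symmetric, eps=1e-8):
    """Derive (scale, zero_point) from an observed float range."""
    # Widen the range so that 0.0 is always representable.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    if symmetric:
        scale = max(-min_val, max_val) / ((quant_max - quant_min) / 2)
        return max(scale, eps), 0
    scale = max((max_val - min_val) / (quant_max - quant_min), eps)
    zero_point = quant_min - round(min_val / scale)
    return scale, max(quant_min, min(quant_max, zero_point))


def fake_quantize(x, scale, zero_point, quant_min, quant_max):
    """Quantize one float to the integer grid, then dequantize it again."""
    q = round(x / scale) + zero_point
    q = max(quant_min, min(quant_max, q))  # clamp to the integer range
    return (q - zero_point) * scale


# Asymmetric 8-bit example over the observed range [-1, 3].
scale, zero_point = affine_params(-1.0, 3.0, 0, 255, symmetric=False)
inside = fake_quantize(0.0, scale, zero_point, 0, 255)
clipped = fake_quantize(10.0, scale, zero_point, 0, 255)

# Symmetric 8-bit example: zero_point is pinned to 0.
sym_scale, sym_zero_point = affine_params(-4.0, 2.0, -128, 127, symmetric=True)
```

Values outside the observed range, such as `10.0` here, saturate at the largest representable value rather than wrapping around.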
- class embedl_deploy.quantize.SmoothQuantConfig(alpha: float = 0.5)[source]#
Bases: object
SmoothQuant migration settings.
Controls the per-channel weight/activation redistribution applied by calibrate_smooth_quant().
- Parameters:
alpha – Migration strength in [0, 1]. 0 keeps all difficulty on activations; 1 pushes it entirely to weights.
- alpha: float = 0.5#
Migration strength in [0, 1].
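The effect of alpha can be seen with the per-channel smoothing factor used by the SmoothQuant method, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha). The sketch below assumes the library follows that formula; `smooth_scales` and the sample numbers are illustrative and not part of embedl_deploy.

```python
def smooth_scales(act_absmax, weight_absmax, alpha=0.5):
    """Per-channel factors s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)."""
    return [
        (a ** alpha) / (w ** (1.0 - alpha))
        for a, w in zip(act_absmax, weight_absmax)
    ]


# Per-input-channel absolute maxima of activations and weights.
act_absmax = [16.0, 4.0]
weight_absmax = [1.0, 4.0]

scales = smooth_scales(act_absmax, weight_absmax, alpha=0.5)

# Activations are divided by s and weights multiplied by s, so the
# product X @ W is unchanged while the ranges are rebalanced.
smoothed_act = [a / s for a, s in zip(act_absmax, scales)]
smoothed_weight = [w * s for w, s in zip(weight_absmax, scales)]
```

With `alpha=0.5` the activation and weight maxima of each channel end up equal; moving alpha toward `1` shrinks the activation range further at the cost of a larger weight range.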
- class embedl_deploy.quantize.TensorQuantConfig(n_bits: int = 8, symmetric: bool = True, per_channel: bool = False, calibration_method: CalibrationMethod = CalibrationMethod.MINMAX)[source]#
Bases: object
Quantization settings for a single tensor class (activation or weight).
- Parameters:
n_bits – Number of bits for the quantized representation. Must be between 2 and 16 inclusive.
symmetric – When True the quantized range is centered on zero (zero_point = 0). When False an asymmetric range is used, allowing better coverage of distributions that are not centered on zero.
per_channel – When True and when used for weight quantization, a separate scale/zero-point is computed along the output-channel axis (axis 0). Defaults to False for backward compatibility but should be set to True for production quantization of Conv/Linear weights.
- calibration_method: CalibrationMethod = <class 'torch.ao.quantization.observer.MinMaxObserver'>#
Calibration algorithm used by QuantStub to collect activation statistics. Only relevant for activation configs; weight quantization computes scale/zero-point on-the-fly.
- n_bits: int = 8#
Number of quantization bits (2–16).
- per_channel: bool = False#
Per-channel quantization (weight only).
- property quant_max: int#
Maximum representable integer value.
- property quant_min: int#
Minimum representable integer value.
- symmetric: bool = True#
Symmetric (zero_point = 0) or asymmetric range.
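For orientation, the integer limits that follow from a bit width can be computed as below. This is a minimal sketch of the usual signed and unsigned conventions; which of the two quant_min and quant_max actually use for a given symmetric setting is not stated here, so the `signed` flag and the function name are assumptions.

```python
def integer_range(n_bits, signed=True):
    """Return (min, max) of an n-bit integer grid."""
    if not 2 <= n_bits <= 16:
        raise ValueError("n_bits must be between 2 and 16 inclusive")
    if signed:
        # Two's-complement style range, e.g. [-128, 127] for 8 bits.
        return -(1 << (n_bits - 1)), (1 << (n_bits - 1)) - 1
    # Unsigned range, e.g. [0, 255] for 8 bits.
    return 0, (1 << n_bits) - 1


int8_range = integer_range(8)
uint8_range = integer_range(8, signed=False)
int4_range = integer_range(4)
```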
- class embedl_deploy.quantize.WeightFakeQuantize(consumers: set[Module], n_bits: int = 8, symmetric: bool = True, per_channel: bool = False, *, channel_axis: int = 0)[source]#
Bases: Module
Fake-quantize a weight tensor during the forward pass.
Unlike QuantStub (which requires a calibration pass), this module computes scale and zero_point on-the-fly from the weight tensor. This is the correct behavior for QAT, where weights change at every step.
- Parameters:
consumers – Set of modules that consume this module’s output.
n_bits – Number of quantization bits.
symmetric – Symmetric (zero_point = 0) or asymmetric.
per_channel – Use per-channel quantization along channel_axis.
channel_axis – The axis along which per-channel scales are computed (default 0, i.e. output channels).
- consumers: set[Module]#
- forward(weight: Tensor) Tensor[source]#
Fake-quantize weight using stored (frozen) or on-the-fly scale.
- freeze(weight: Tensor) None[source]#
Compute scale/zero_point from weight and store as constant buffers.
After calling this, forward() uses the stored constants instead of recomputing them each call. ONNX export will therefore emit a Constant node for the scale, which TensorRT requires for explicit-quantization fusion, rather than the full Abs → ReduceMax → Div arithmetic.
- Parameters:
weight – The weight tensor to calibrate against (typically conv.weight).
- scale: Tensor | None#
- zero_point: Tensor | None#
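The difference between on-the-fly and frozen scales can be shown with a small pure-Python stand-in. `TinyWeightQuantizer` is hypothetical: it handles only symmetric, per-output-channel quantization of a nested list and is meant to mirror the behavior described above, not the module's real implementation.

```python
class TinyWeightQuantizer:
    """Hypothetical stand-in illustrating on-the-fly vs. frozen scales."""

    def __init__(self, n_bits=8):
        self.quant_max = (1 << (n_bits - 1)) - 1
        self.scale = None  # None means "recompute on every call"

    def _compute_scale(self, weight):
        # Abs -> ReduceMax -> Div, one scale per output channel (axis 0).
        return [
            max(max(abs(v) for v in row) / self.quant_max, 1e-12)
            for row in weight
        ]

    def freeze(self, weight):
        """Store the scales so later calls reuse them unchanged."""
        self.scale = self._compute_scale(weight)

    def forward(self, weight):
        scale = self.scale if self.scale is not None else self._compute_scale(weight)
        qmax = self.quant_max
        return [
            [max(-qmax, min(qmax, round(v / s))) * s for v in row]
            for row, s in zip(weight, scale)
        ]


weight = [[0.5, -1.27], [12.7, 0.0]]
frozen = TinyWeightQuantizer()
frozen.freeze(weight)

# Pretend a training step changed one weight after freezing.
updated = [[0.5, -2.54], [12.7, 0.0]]
frozen_out = frozen.forward(updated)                  # uses the stored scales
dynamic_out = TinyWeightQuantizer().forward(updated)  # recomputes the scales
```

Here the frozen quantizer clips the updated weight to the old range, while the dynamic one adapts, which is why freezing is deferred until the weights no longer change.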
- embedl_deploy.quantize.calibrate_qdq(model: GraphModule, forward_loop: Callable[[GraphModule], None]) None[source]#
Calibrate Q/DQ stubs by running the user’s forward loop.
Temporarily disables all enabled QuantStub and WeightFakeQuantize modules, switches non-fixed stubs into calibration mode, invokes forward_loop once, then finalizes scale/zero_point from the observed min/max ranges and re-enables all modules.
The model is modified in-place.
- Parameters:
model – A configured GraphModule whose fused modules have been set up by configure().
forward_loop – (model) -> None callable that runs representative data through the model. The caller controls batch size, device placement, and iteration count. Example:

    def forward_loop(model):
        for batch in calib_loader:
            model(batch)
- Raises:
ValueError – If the model contains no enabled QuantStub or WeightFakeQuantize modules.
RuntimeError – If any stub did not observe finite values during the loop.
Exception – Propagates any exception from forward_loop after restoring model state.
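The "disable, run the loop, restore even on failure" behavior described above is a standard try/finally pattern. The following sketch shows that pattern in isolation; `Toggle`, `temporarily_disabled`, and the `enabled` flag are hypothetical stand-ins and do not reflect how the library stores module state.

```python
from contextlib import contextmanager


class Toggle:
    """Stand-in for a module with an on/off flag (hypothetical)."""

    def __init__(self, enabled=True):
        self.enabled = enabled


@contextmanager
def temporarily_disabled(modules):
    """Disable every currently enabled module, then restore on exit."""
    previously_enabled = [m for m in modules if m.enabled]
    for m in previously_enabled:
        m.enabled = False
    try:
        yield
    finally:
        # Runs even if the forward loop raises, so state is restored
        # before the exception reaches the caller.
        for m in previously_enabled:
            m.enabled = True


def failing_forward_loop():
    raise RuntimeError("bad batch")


stubs = [Toggle(True), Toggle(False)]
error_message = None
try:
    with temporarily_disabled(stubs):
        failing_forward_loop()
except RuntimeError as exc:
    error_message = str(exc)

restored_flags = [s.enabled for s in stubs]
```

Only modules that were enabled beforehand are re-enabled, so a module that was already disabled stays disabled.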
- embedl_deploy.quantize.calibrate_smooth_quant(model: GraphModule, forward_loop: Callable[[GraphModule], None]) None[source]#
Calibrate and apply SmoothQuant to a fused model in-place.
Migrates quantization difficulty from activations to weights for every LayerNorm → Linear pair that has an enabled SmoothQuantObserver with populated downstream_linears. Instruments each observer’s layer_norm with forward hooks, temporarily disables all enabled stubs and weight fake-quantize modules, and runs forward_loop on the model.
Must be called after transform() (fusion) and configure(), and before calibrate_qdq().
- Parameters:
model – A fused GraphModule whose observers have been enabled by configure.
forward_loop – (model) -> None callable that runs representative data through the model. The caller controls batch size, device placement, and iteration count.
- Raises:
Exception – Propagates any exception from forward_loop after restoring model state.
- embedl_deploy.quantize.configure(model: GraphModule, config: QuantConfig) None[source]#
Configure quantization settings on all fused modules in-place.
Walks every FusedModule and:
  - Configures and enables each QuantStub. Stubs without fixed_parameters receive config.activation; stubs not excluded by config.skip are enabled.
  - Configures and enables weight_fake_quant (respecting skip).
  - Enables smooth_quant_observer where present and copies the smooth_quant config from config.
After configuration, prepares the graph for smooth-quant observation and inserts Q/DQ stub nodes.
- Parameters:
model – A GraphModule produced by the fusion step. Modified in-place.
config – A QuantConfig controlling activation bits, weight bits, and which module types to skip.
- embedl_deploy.quantize.disable_fake_quant(model: Module) Module[source]#
Disable fake quantization throughout the model.
- Parameters:
model – The model to modify in-place.
- Returns:
The same model, for method chaining.
- embedl_deploy.quantize.enable_fake_quant(model: Module) Module[source]#
Enable fake quantization in all stubs and weight quantizers.
- Parameters:
model – The model to modify in-place.
- Returns:
The same model, for method chaining.
- embedl_deploy.quantize.freeze_bn_stats(model: Module) Module[source]#
Freeze BatchNorm running statistics.
Puts all BatchNorm*d layers into eval mode so running_mean and running_var are no longer updated. Affine parameters remain trainable.
- Parameters:
model – The model to modify in-place.
- Returns:
The same model, for method chaining.
- embedl_deploy.quantize.freeze_weight_quantization(model: GraphModule) None[source]#
Freeze all WeightFakeQuantize scale/zero_point buffers.
After calibration (or QAT training) the weights are fixed for export, so the scale is computed once and stored as a constant buffer. This ensures ONNX export emits a Constant node for the scale, which TensorRT requires for explicit-quantization Q/DQ fusion, rather than the dynamic Abs → ReduceMax → Div arithmetic.
quantize() calls this automatically only when freeze_weights=True is passed (the default is False). QAT users should call this explicitly before ONNX export once training is complete.
- embedl_deploy.quantize.prepare_qat(model: Module) Module[source]#
Prepare a quantized model for quantization-aware training.
Sets the model to training mode so that QuantStub nodes propagate gradients through the straight-through estimator (STE) and WeightFakeQuantize nodes apply fake-quant to weights.
- Parameters:
model – A quantized nn.Module.
- Returns:
The same model, in-place, for method chaining.
- embedl_deploy.quantize.quantize(model: GraphModule, args: tuple[Any, ...], config: QuantConfig | None = None, *, forward_loop: Callable[[GraphModule], None], freeze_weights: bool = False) GraphModule[source]#
Configure, insert Q/DQ stubs, optimize, and calibrate in one call.
Convenience wrapper that chains configure() → calibrate_smooth_quant() → calibrate_qdq().
- Parameters:
model – A GraphModule produced by the fusion step.
args – Example inputs used for shape propagation, which provides the tensor metadata required for calibration.
config – Optional QuantConfig. Defaults to 8-bit symmetric.
forward_loop – (model) -> None callable that runs representative data through the model. The caller controls batch size, device placement, and iteration count.
freeze_weights – When True, weight scales are computed from the current weights and stored as constant buffers after calibration. This is required for ONNX/TensorRT export (PTQ workflow). Defaults to False to preserve the original behavior (dynamic on-the-fly scale computation, suitable for QAT). Call freeze_weight_quantization() before export once training is complete.
- Returns:
The quantized GraphModule with calibrated stubs.
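Putting the pieces together, a PTQ call might look like the sketch below. It uses only the signatures documented on this page. `run_ptq`, `fused_model`, `example_args`, and `calib_batches` are hypothetical names, and the fused GraphModule is assumed to come from the fusion step, which lives outside this module. The import is done inside the function so the snippet can be loaded without the package installed.

```python
def run_ptq(fused_model, example_args, calib_batches):
    """One-call PTQ sketch: configure, calibrate, and freeze weight scales.

    fused_model   -- GraphModule from the fusion step (assumed to exist)
    example_args  -- tuple of example inputs used for shape propagation
    calib_batches -- iterable of representative input batches
    """
    # Lazy import: only needed when the function is actually called.
    from embedl_deploy.quantize import QuantConfig, TensorQuantConfig, quantize

    def forward_loop(model):
        # The caller controls batching, device placement, and iterations.
        for batch in calib_batches:
            model(batch)

    config = QuantConfig(
        activation=TensorQuantConfig(n_bits=8, symmetric=True),
        # Per-channel weights, as recommended for Conv/Linear layers.
        weight=TensorQuantConfig(n_bits=8, symmetric=True, per_channel=True),
    )
    return quantize(
        fused_model,
        example_args,
        config,
        forward_loop=forward_loop,
        freeze_weights=True,  # store constant scales for ONNX/TensorRT export
    )
```

For a QAT workflow the same call would leave `freeze_weights` at its default and call `freeze_weight_quantization()` once training is complete.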