# Source code for embedl_deploy._internal.core.quantize.calibrate

# Copyright (C) 2026 Embedl AB

"""Calibration passes for Q/DQ stubs and SmoothQuant observers."""

from collections.abc import Callable

import torch
from torch import fx

from embedl_deploy._internal.core.quantize.utils import get_model_quants


def calibrate_smooth_quant(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Calibrate and apply SmoothQuant to a fused model in-place.

    Migrates quantization difficulty from activations to weights for
    every LayerNorm → Linear pair that has an enabled
    :class:`~embedl_deploy._internal.core.quantize.stubs.SmoothQuantObserver`
    with populated ``downstream_linears``.

    Registers a forward hook through each such observer, temporarily
    disables all stubs and weight fake-quantize modules returned by
    ``get_model_quants``, and runs `forward_loop` on the model under
    ``torch.no_grad()``.  Afterwards ``compute_smooth_scales()`` is called
    on every instrumented observer.  Hooks are always removed and the
    stubs always re-enabled, even when an error occurs.

    Must be called after :func:`~embedl_deploy.transform` (fusion) and
    :func:`~embedl_deploy._internal.core.quantize.main.configure`, and
    before
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_qdq`.

    .. note::
        The model is switched to eval mode via ``model.eval()`` and is
        left in eval mode when this function returns.

    :param model:
        A fused ``GraphModule`` whose observers have been enabled by
        ``configure``.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model.  The caller controls batch size, device
        placement, and iteration count.
    :raises Exception:
        Propagates any exception from `forward_loop` (or from hook
        registration / scale computation) after removing the hooks and
        re-enabling the stubs.
    """
    mq = get_model_quants(model)
    if not mq.smooth:
        # No SmoothQuant observers at all: leave the model untouched.
        return

    # NOTE(review): assumes get_model_quants() yields only modules that are
    # currently enabled; the ``finally`` block sets ``enabled = True`` on all
    # of them unconditionally -- confirm against get_model_quants.
    enabled_stubs = mq.stubs + mq.weight

    # Select the observers once so that exactly the instrumented observers
    # get their scales computed after the forward loop.
    active_observers = [obs for obs in mq.smooth if obs.downstream_linears]

    hooks = []
    try:
        # Everything that mutates model state lives inside the ``try`` so a
        # failure part-way through (e.g. while registering the n-th hook)
        # still removes the earlier hooks and re-enables the stubs.
        for stub in enabled_stubs:
            stub.enabled = False
        for obs in active_observers:
            hooks.append(obs.register_forward_hook())

        model.eval()
        with torch.no_grad():
            forward_loop(model)

        for obs in active_observers:
            obs.compute_smooth_scales()
    finally:
        for handle in hooks:
            handle.remove()
        for stub in enabled_stubs:
            stub.enabled = True


def calibrate_qdq(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Calibrate Q/DQ stubs by running the user's forward loop.

    Every enabled
    :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` and
    :class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
    module is switched off for the duration of the call, and each stub
    without fixed parameters is put into calibration mode.  `forward_loop`
    is then invoked once under ``torch.no_grad()`` with the model in eval
    mode.  Afterwards ``scale`` / ``zero_point`` are finalized from the
    observed min/max ranges, calibration mode is cleared and all modules
    are re-enabled.

    The model is modified **in-place**.

    :param model:
        A configured ``GraphModule`` whose fused modules have been set up
        by :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model.  The caller controls batch size, device
        placement, and iteration count.  Example::

            def forward_loop(model):
                for batch in calib_loader:
                    model(batch)

    :raises ValueError:
        If the model contains no enabled ``QuantStub`` or
        ``WeightFakeQuantize`` modules.
    :raises RuntimeError:
        If any stub did not observe finite values during the loop.
    :raises Exception:
        Propagates any exception from `forward_loop` after restoring
        model state.
    """
    quants = get_model_quants(model)
    toggled_modules = quants.stubs + quants.weight
    if not toggled_modules:
        raise ValueError("Model contains no enabled QuantStub modules.")

    def _set_enabled(flag: bool) -> None:
        # Flip the ``enabled`` switch on every stub / weight fake-quantizer.
        for module in toggled_modules:
            module.enabled = flag

    def _set_calibrating(flag: bool) -> None:
        # Only stubs whose parameters are not fixed take part in calibration.
        for act_stub in quants.stubs:
            if act_stub.fixed_parameters:
                continue
            act_stub.calibrating = flag

    _set_enabled(False)
    _set_calibrating(True)
    try:
        model.eval()
        with torch.no_grad():
            forward_loop(model)

        # Turn the collected statistics into quantization parameters.
        for act_stub in quants.stubs:
            if act_stub.fixed_parameters:
                continue
            act_stub.compute_parameters()
    finally:
        # Undo the temporary state in reverse order of how it was applied.
        _set_calibrating(False)
        _set_enabled(True)