# Source code for embedl_deploy._internal.core.quantize.calibrate
# Copyright (C) 2026 Embedl AB
"""Calibration passes for Q/DQ stubs and SmoothQuant observers."""
from collections.abc import Callable
import torch
from torch import fx
from embedl_deploy._internal.core.quantize.utils import get_model_quants
[docs]
def calibrate_smooth_quant(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Calibrate and apply SmoothQuant to a fused model in-place.

    Migrates quantization difficulty from activations to weights for
    every LayerNorm → Linear pair that has an enabled
    :class:`~embedl_deploy._internal.core.quantize.stubs.SmoothQuantObserver`
    with populated ``downstream_linears``. Each such observer gets a
    forward hook registered, all enabled stubs and weight fake-quantize
    modules are temporarily disabled, and `forward_loop` is run on the
    model under ``torch.no_grad()``. Afterwards every observer with
    ``downstream_linears`` has ``compute_smooth_scales()`` called.

    If the model has no SmoothQuant observers the function returns
    immediately without touching the model.

    Must be called after :func:`~embedl_deploy.transform` (fusion) and
    :func:`~embedl_deploy._internal.core.quantize.main.configure`,
    and before
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_qdq`.

    .. note::
        The model is switched to eval mode via ``model.eval()`` and is
        left in eval mode on return.

    :param model:
        A fused ``GraphModule`` whose observers have been enabled by
        ``configure``.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model. The caller controls batch size, device
        placement, and iteration count.
    :raises Exception:
        Propagates any exception from hook registration, `forward_loop`
        or scale computation after removing the registered hooks and
        re-enabling the stubs.
    """
    mq = get_model_quants(model)
    if not mq.smooth:
        return

    enabled_stubs = mq.stubs + mq.weight
    # Filled incrementally so that a failure half-way through registration
    # still lets the ``finally`` clause remove the hooks attached so far.
    hooks = []
    try:
        for stub in enabled_stubs:
            stub.enabled = False
        for obs in mq.smooth:
            if obs.downstream_linears:
                hooks.append(obs.register_forward_hook())

        model.eval()
        with torch.no_grad():
            forward_loop(model)

        # NOTE(review): finalization is placed after the no_grad block;
        # confirm this matches the upstream nesting.
        for obs in mq.smooth:
            if obs.downstream_linears:
                obs.compute_smooth_scales()
    finally:
        for h in hooks:
            h.remove()
        for stub in enabled_stubs:
            stub.enabled = True
[docs]
def calibrate_qdq(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Calibrate Q/DQ stubs by running the user's forward loop.

    All enabled
    :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` and
    :class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
    modules are switched off for the duration of the call. Stubs whose
    ``fixed_parameters`` flag is falsy are put into calibration mode,
    `forward_loop` is invoked once under ``torch.no_grad()``, and then
    ``compute_parameters()`` is called on each of those stubs to
    finalize ``scale`` / ``zero_point``. Calibration mode is cleared and
    every module is re-enabled on the way out, whether or not an error
    occurred.

    The model is modified **in-place** and is left in eval mode.

    :param model:
        A configured ``GraphModule`` whose fused modules have been
        set up by
        :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model. The caller controls batch size, device
        placement, and iteration count. Example::

            def forward_loop(model):
                for batch in calib_loader:
                    model(batch)

    :raises ValueError:
        If the model contains no enabled ``QuantStub`` or
        ``WeightFakeQuantize`` modules.
    :raises RuntimeError:
        If any stub did not observe finite values during the loop
        (not raised directly here; presumably comes from
        ``compute_parameters`` — confirm against the stub class).
    :raises Exception:
        Propagates any exception from `forward_loop` after restoring
        the ``calibrating`` and ``enabled`` flags.
    """
    quants = get_model_quants(model)
    toggled = quants.stubs + quants.weight
    if not toggled:
        # NOTE(review): the message names only QuantStub although the check
        # also covers weight fake-quantize modules.
        raise ValueError("Model contains no enabled QuantStub modules.")

    for module in toggled:
        module.enabled = False
    for stub in quants.stubs:
        if stub.fixed_parameters:
            continue
        stub.calibrating = True

    try:
        model.eval()
        with torch.no_grad():
            forward_loop(model)

        # NOTE(review): finalization is placed after the no_grad block;
        # confirm this matches the upstream nesting.
        for stub in quants.stubs:
            if stub.fixed_parameters:
                continue
            stub.compute_parameters()
    finally:
        for stub in quants.stubs:
            if stub.fixed_parameters:
                continue
            stub.calibrating = False
        for module in toggled:
            module.enabled = True