# Module: embedl_deploy._internal.core.quantize.qdq
# Copyright (C) 2026 Embedl AB
"""Pattern-based Q/DQ stub insertion and calibration.
Provides :func:`calibrate_qdq` which inserts Q/DQ stubs via
:data:`~embedl_deploy._internal.tensorrt.plan.TENSORRT_QUANTIZED_PATTERNS`
and then calibrates them in a single pass.
"""
from collections.abc import Callable
import torch
from torch import fx
from embedl_deploy._internal.core.plan import (
apply_transformation_plan,
get_transformation_plan,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants
from embedl_deploy._internal.tensorrt.plan import TENSORRT_QUANTIZED_PATTERNS
def prepare_qdq(model: fx.GraphModule) -> fx.GraphModule:
    """Insert and optimise Q/DQ stubs using pattern-based rewrites.

    Iteratively applies
    :data:`~embedl_deploy._internal.tensorrt.plan.TENSORRT_QUANTIZED_PATTERNS`
    until no new matches are found. The model must already have its fused
    modules configured via
    :func:`~embedl_deploy._internal.core.quantize.main.configure`.

    :param model:
        A configured ``GraphModule`` with enabled ``input_quant_stubs``.
    :returns:
        The fully rewritten ``GraphModule``. Each
        ``apply_transformation_plan`` pass may yield a *new* module
        object, so callers must use the returned value rather than
        assume the input was rewritten in place. (The original version
        rebound only the local name and returned ``None``, silently
        discarding the rewritten model.)
    """
    # Flag the model so downstream passes can skip a redundant deep copy.
    model._deep_copy_done = True
    # Keep rewriting until the pattern matcher finds nothing new; each
    # pass may replace `model` with a transformed GraphModule.
    while True:
        plan = get_transformation_plan(model, TENSORRT_QUANTIZED_PATTERNS)
        if not plan.matches:
            break
        model = apply_transformation_plan(plan).model
    return model
def calibrate_qdq(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Run one calibration pass over the model's Q/DQ stubs.

    Puts every non-fixed
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`
    into calibration mode, executes *forward_loop* once under
    ``torch.no_grad()``, and then derives ``scale`` / ``zero_point``
    from the min/max statistics gathered during the run.

    The model is modified **in-place**.

    :param model:
        A configured ``GraphModule`` whose fused modules have been
        set up by
        :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    :param forward_loop:
        ``(model) -> None`` callable feeding representative data
        through the model; the caller decides batching, device
        placement, and how many iterations to run. Example::

            def forward_loop(model):
                for batch in calib_loader:
                    model(batch)

    :raises ValueError:
        If the model contains no enabled ``QuantStub`` modules.
    :raises RuntimeError:
        If any stub did not observe finite values during the loop.
    :raises Exception:
        Any exception raised by *forward_loop* propagates after the
        stubs have been restored to their pre-call state.
    """
    quants = get_model_quants(model)
    all_stubs = quants.stubs + quants.weight
    if not all_stubs:
        raise ValueError("Model contains no enabled QuantStub modules.")

    def _toggle_calibration(active: bool) -> None:
        # Only stubs whose parameters are not fixed take part in
        # calibration; fixed stubs keep their preset scale/zero_point.
        for candidate in quants.stubs:
            if not candidate.fixed_parameters:
                candidate.calibrating = active

    # Disable quantisation while observing, so stubs see raw activations.
    for candidate in all_stubs:
        candidate.enabled = False
    _toggle_calibration(True)
    try:
        model.eval()
        with torch.no_grad():
            forward_loop(model)
        # Finalise scale / zero_point from the observed ranges.
        for candidate in quants.stubs:
            if not candidate.fixed_parameters:
                candidate.compute_parameters()
    finally:
        # Restore stub state even if forward_loop (or finalisation) raised.
        _toggle_calibration(False)
        for candidate in all_stubs:
            candidate.enabled = True