# Source code for embedl_deploy._internal.core.quantize.qdq

# Copyright (C) 2026 Embedl AB

"""Pattern-based Q/DQ stub insertion and calibration.

Provides :func:`calibrate_qdq` which inserts Q/DQ stubs via
:data:`~embedl_deploy._internal.tensorrt.plan.TENSORRT_QUANTIZED_PATTERNS`
and then calibrates them in a single pass.
"""

from collections.abc import Callable

import torch
from torch import fx

from embedl_deploy._internal.core.plan import (
    apply_transformation_plan,
    get_transformation_plan,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants
from embedl_deploy._internal.tensorrt.plan import TENSORRT_QUANTIZED_PATTERNS


def prepare_qdq(model: fx.GraphModule) -> fx.GraphModule:
    """Insert and optimise Q/DQ stubs using pattern-based rewrites.

    Iteratively applies
    :data:`~embedl_deploy._internal.tensorrt.plan.TENSORRT_QUANTIZED_PATTERNS`
    until no new matches are found.  The model must already have its fused
    modules configured via
    :func:`~embedl_deploy._internal.core.quantize.main.configure`.

    :param model:
        A configured ``GraphModule`` with enabled ``input_quant_stubs``.
    :return:
        The transformed ``GraphModule``.  Previously the rewritten module
        produced by the final :func:`apply_transformation_plan` call was
        discarded; it is now returned so callers can use it even when the
        transformation does not mutate *model* in place.
    """
    # Mark the module so downstream passes skip a redundant deep copy.
    model._deep_copy_done = True
    while True:
        plan = get_transformation_plan(model, TENSORRT_QUANTIZED_PATTERNS)
        # Fixed point reached: no pattern matched anywhere in the graph.
        if not plan.matches:
            break
        # Rebind to the rewritten module so the next iteration (and the
        # caller, via the return value) sees the updated graph.
        model = apply_transformation_plan(plan).model
    return model


def calibrate_qdq(
    model: fx.GraphModule,
    forward_loop: Callable[[fx.GraphModule], None],
) -> None:
    """Calibrate Q/DQ stubs by running the user's forward loop.

    Switches every
    :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`
    into calibration mode, invokes *forward_loop* once, then finalises
    ``scale`` / ``zero_point`` from the observed min/max ranges.  The
    model is modified **in-place**.

    :param model:
        A configured ``GraphModule`` whose fused modules have been set up
        by :func:`~embedl_deploy._internal.core.quantize.main.configure`.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data through
        the model.  The caller controls batch size, device placement, and
        iteration count.  Example::

            def forward_loop(model):
                for batch in calib_loader:
                    model(batch)

    :raises ValueError:
        If the model contains no enabled ``QuantStub`` modules.
    :raises RuntimeError:
        If any stub did not observe finite values during the loop.
    :raises Exception:
        Re-raises any exception from *forward_loop* after restoring model
        state.
    """
    mq = get_model_quants(model)
    enabled_stubs = mq.stubs + mq.weight
    if not enabled_stubs:
        raise ValueError("Model contains no enabled QuantStub modules.")

    # Disable quantization during calibration so stubs observe the raw
    # (un-quantized) activations flowing through the model.
    for stub in enabled_stubs:
        stub.enabled = False
    # Only stubs without user-fixed parameters need calibrating.
    for stub in mq.stubs:
        if not stub.fixed_parameters:
            stub.calibrating = True
    try:
        model.eval()
        with torch.no_grad():
            forward_loop(model)
        # Derive scale / zero_point from the min/max ranges each stub
        # recorded during the forward loop.
        for stub in mq.stubs:
            if not stub.fixed_parameters:
                stub.compute_parameters()
    finally:
        # Restore stub state even if forward_loop (or compute_parameters)
        # raised, so the model is left usable.
        for stub in mq.stubs:
            if not stub.fixed_parameters:
                stub.calibrating = False
        for stub in enabled_stubs:
            stub.enabled = True