Source code for embedl_deploy._internal.core.quantize.main

# Copyright (C) 2026 Embedl AB

"""High-level quantization entry points.

Convenience wrappers that chain configuration, smooth-quant calibration,
Q/DQ stub insertion, optimization, and calibration into single calls.
"""

import copy
from collections.abc import Callable
from typing import Any

import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy._internal.core.quantize.calibrate import (
    calibrate_qdq,
    calibrate_smooth_quant,
)
from embedl_deploy._internal.core.quantize.config import QuantConfig
from embedl_deploy._internal.core.quantize.prepare import (
    prepare_qdq,
    prepare_smooth_quant,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants


def _no_skip_match(
    mods: set[nn.Module],
    skip_set: set[type[nn.Module] | nn.Module],
) -> bool:
    """Return ``True`` if no child of any module in `mods` matches `skip_set`."""
    types = tuple(s for s in skip_set if isinstance(s, type))
    instances = {s for s in skip_set if not isinstance(s, type)}
    for mod in mods:
        for child in mod.children():
            if child in instances:
                return False
            if types and isinstance(child, types):
                return False
    return True


def configure(
    model: fx.GraphModule,
    config: QuantConfig,
) -> None:
    """Configure quantization settings on all fused modules in-place.

    Walks every :class:`~embedl_deploy._internal.core.modules.FusedModule`
    and:

    * Configures and enables each
      :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`.
      Stubs without ``fixed_parameters`` receive ``config.activation``;
      stubs not excluded by ``config.skip`` are enabled.
    * Configures and enables ``weight_fake_quant`` (respecting
      :attr:`~embedl_deploy._internal.core.quantize.config.QuantConfig.skip`).
    * Enables ``smooth_quant_observer`` where present and copies the
      ``smooth_quant`` config from `config`.

    After configuration, prepares the graph for smooth-quant observation
    and inserts Q/DQ stub nodes. If the model lives on a non-CPU device
    (detected from its first parameter, or its first buffer when it has no
    parameters), the whole model is moved back to that device afterwards.

    :param model:
        A ``GraphModule`` produced by the fusion step. Modified in-place.
    :param config:
        A :class:`~embedl_deploy._internal.core.quantize.config.QuantConfig`
        controlling activation bits, weight bits, and which module types
        to skip.
    """
    mq = get_model_quants(model, include_enabled=True, include_disabled=True)

    for stub in mq.stubs:
        stub.enabled = _no_skip_match(stub.consumers, config.skip.stub)
        # Stubs with fixed parameters keep whatever config they already carry.
        if stub.enabled and not stub.fixed_parameters:
            # Shallow copy: each stub gets its own top-level config object.
            stub.config = copy.copy(config.activation)

    for wfq in mq.weight:
        wfq.enabled = _no_skip_match(wfq.consumers, config.skip.weight)
        if wfq.enabled:
            wfq.config = copy.copy(config.weight)

    for obs in mq.smooth:
        obs.enabled = _no_skip_match(obs.consumers, config.skip.smooth)
        if obs.enabled:
            obs.config = copy.copy(config.smooth_quant)

    # Snapshot the model device before insertion so newly-created
    # QuantStub / WeightFakeQuantize / observer buffers can be moved
    # to the same device as the rest of the graph after the
    # smooth-quant / Q-DQ insertion passes run. Without this, models
    # already on CUDA hit "scale on cpu vs other tensors on cuda" in
    # ``fake_quantize_per_tensor_affine_cachemask_tensor_qparams``.
    anchor = next(model.parameters(), None)
    if anchor is None:
        # Parameter-free models may still hold device-resident buffers.
        anchor = next(model.buffers(), None)
    device = anchor.device if anchor is not None else torch.device("cpu")

    prepare_smooth_quant(model)
    prepare_qdq(model)

    if device.type != "cpu":
        model.to(device)


def quantize(
    model: fx.GraphModule,
    args: tuple[Any, ...],
    config: QuantConfig | None = None,
    *,
    forward_loop: Callable[[fx.GraphModule], None],
    freeze_weights: bool = False,
) -> fx.GraphModule:
    """Configure, insert Q/DQ stubs, and calibrate in one call.

    Convenience wrapper that runs shape propagation and then chains
    :func:`~embedl_deploy._internal.core.quantize.main.configure` →
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_smooth_quant` →
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_qdq`,
    optionally followed by :func:`freeze_weight_quantization`.

    :param model:
        A ``GraphModule`` produced by the fusion step. Modified in-place
        and also returned.
    :param args:
        The arguments to use for shape propagation necessary to get
        tensor meta data required for calibration.
    :param config:
        Optional
        :class:`~embedl_deploy._internal.core.quantize.config.QuantConfig`.
        When ``None``, a default-constructed ``QuantConfig()`` is used.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data through
        the model. The caller controls batch size, device placement, and
        iteration count.
    :param freeze_weights:
        When ``True``, weight scales are computed from the current weights
        and stored as constant buffers after calibration. This is required
        for ONNX/TensorRT export (PTQ workflow). Defaults to ``False``
        (dynamic on-the-fly scale computation, suitable for QAT). Call
        :func:`freeze_weight_quantization` before export once training is
        complete.
    :returns: The quantized ``GraphModule`` (the same object as `model`)
        with calibrated stubs.
    """
    # ShapeProp.propagate lacks type stubs in torch
    ShapeProp(model).propagate(*args)  # type: ignore[no-untyped-call]
    # Explicit None check: do not depend on the truthiness of a config object.
    if config is None:
        config = QuantConfig()
    configure(model, config)
    calibrate_smooth_quant(model, forward_loop)
    calibrate_qdq(model, forward_loop)
    if freeze_weights:
        freeze_weight_quantization(model)
    return model


def freeze_weight_quantization(model: fx.GraphModule) -> None:
    """Freeze all :class:`WeightFakeQuantize` scale/zero_point buffers.

    After calibration (or QAT training) the weights are fixed for export,
    so we compute the scale once and store it as a constant buffer. This
    ensures ONNX export emits a ``Constant`` node for the scale rather
    than the dynamic ``Abs → ReduceMax → Div`` arithmetic, which TensorRT
    requires for explicit-quantization Q/DQ fusion.

    :func:`quantize` calls this automatically only when
    ``freeze_weights=True`` is passed (its default is ``False``). QAT users
    should call this explicitly before ONNX export once training is
    complete.

    A weight fake-quantizer is left unfrozen when it has no consumer
    module, or when its consumer exposes no weight tensor under ``conv``,
    ``linear`` or ``in_proj.linear``.

    :param model: The quantized ``GraphModule``. Modified in-place.
    """
    mq = get_model_quants(model)
    for wfq in mq.weight:
        # NOTE(review): assumes each WeightFakeQuantize has a single consumer;
        # with several, an arbitrary one (set iteration order) is used.
        mod = next(iter(wfq.consumers), None)
        if mod is None:
            # No consumer to read a weight from; previously this leaked a
            # bare StopIteration out of the function.
            continue
        weight = _get_quantized_weight(mod)
        if isinstance(weight, torch.Tensor):
            wfq.freeze(weight)


def _get_quantized_weight(mod: nn.Module) -> torch.Tensor | None: """Return the weight tensor that a ``WeightFakeQuantize`` covers. Checks ``mod.conv.weight``, ``mod.linear.weight``, and ``mod.in_proj.linear.weight`` — the three attribute paths used by fused-module subclasses that carry a ``WeightFakeQuantize``. """ for attr_chain in ("conv", "linear", "in_proj.linear"): obj: object = mod for part in attr_chain.split("."): obj = getattr(obj, part, None) if obj is None: break if obj is not None: weight = getattr(obj, "weight", None) if isinstance(weight, torch.Tensor): return weight return None