# Source code for embedl_deploy._internal.core.quantize.main

# Copyright (C) 2026 Embedl AB

"""High-level quantization entry points.

Convenience wrappers that chain configuration, smooth-quant calibration,
Q/DQ stub insertion, optimisation, and calibration into single calls.
"""

import copy
from collections.abc import Callable
from typing import Any

from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy._internal.core.config import QuantConfig
from embedl_deploy._internal.core.quantize.qdq import (
    calibrate_qdq,
    prepare_qdq,
)
from embedl_deploy._internal.core.quantize.smooth_quant import (
    calibrate_smooth_quant,
    prepare_smooth_quant,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants


def _no_skip_match(
    mods: set[nn.Module],
    skip_set: set[type[nn.Module] | nn.Module],
) -> bool:
    """Return ``True`` if no child of any module in `mods` matches `skip_set`."""
    types = tuple(s for s in skip_set if isinstance(s, type))
    instances = {s for s in skip_set if not isinstance(s, type)}
    for mod in mods:
        for child in mod.children():
            if child in instances:
                return False
            if types and isinstance(child, types):
                return False
    return True


def configure(
    model: fx.GraphModule,
    config: QuantConfig,
) -> None:
    """Apply *config* to every quantization hook on *model*, in-place.

    Collects all quantizer objects (enabled and disabled) from the fused
    graph and, for each group:

    * Activation stubs are enabled unless a consumer matches
      ``config.skip.stub``; stubs without ``fixed_parameters`` receive a
      copy of ``config.activation``.
    * Weight fake-quantizers are enabled unless skipped via
      ``config.skip.weight`` and receive a copy of ``config.weight``.
    * Smooth-quant observers are enabled unless skipped via
      ``config.skip.smooth`` and receive a copy of ``config.smooth_quant``.

    Finally the smooth-quant and Q/DQ preparation passes are run.

    :param model: A ``GraphModule`` produced by the fusion step. Modified
        in-place.
    :param config: A :class:`~embedl_deploy._internal.core.config.QuantConfig`
        controlling activation bits, weight bits, and which module types
        to skip.
    """
    quants = get_model_quants(model, include_enabled=True, include_disabled=True)

    for stub in quants.stubs:
        enabled = _no_skip_match(stub.consumers, config.skip.stub)
        stub.enabled = enabled
        # Stubs with fixed parameters keep whatever config they carry.
        if enabled and not stub.fixed_parameters:
            stub.config = copy.copy(config.activation)

    for weight_fq in quants.weight:
        enabled = _no_skip_match(weight_fq.consumers, config.skip.weight)
        weight_fq.enabled = enabled
        if enabled:
            weight_fq.config = copy.copy(config.weight)

    for observer in quants.smooth:
        enabled = _no_skip_match(observer.consumers, config.skip.smooth)
        observer.enabled = enabled
        if enabled:
            observer.config = copy.copy(config.smooth_quant)

    prepare_smooth_quant(model)
    prepare_qdq(model)
def quantize(
    model: fx.GraphModule,
    args: tuple[Any, ...],
    config: QuantConfig | None = None,
    *,
    forward_loop: Callable[[fx.GraphModule], None],
) -> fx.GraphModule:
    """Run the full quantization pipeline on *model* in a single call.

    Chains shape propagation →
    :func:`~embedl_deploy._internal.core.quantize.main.configure` →
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant` →
    :func:`~embedl_deploy._internal.core.quantize.qdq.calibrate_qdq`.

    :param model: A ``GraphModule`` produced by the fusion step.
    :param args: Arguments used for shape propagation, which attaches the
        tensor metadata required by calibration.
    :param config: Optional
        :class:`~embedl_deploy._internal.core.config.QuantConfig`.
        Defaults to 8-bit symmetric.
    :param forward_loop: ``(model) -> None`` callable that pushes
        representative data through the model; the caller controls batch
        size, device placement, and iteration count.
    :returns: The quantized ``GraphModule`` with calibrated stubs.
    """
    # Fall back to the default (8-bit symmetric) configuration.
    cfg = config or QuantConfig()
    # Shape propagation records tensor metadata on the graph nodes.
    ShapeProp(model).propagate(*args)  # type: ignore[no-untyped-call]
    configure(model, cfg)
    calibrate_smooth_quant(model, forward_loop)
    calibrate_qdq(model, forward_loop)
    return model