# Source code for embedl_deploy._internal.core.quantize.main
# Copyright (C) 2026 Embedl AB
"""High-level quantization entry points.
Convenience wrappers that chain configuration, smooth-quant calibration,
Q/DQ stub insertion, optimisation, and calibration into single calls.
"""
import copy
from collections.abc import Callable
from typing import Any
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp
from embedl_deploy._internal.core.config import QuantConfig
from embedl_deploy._internal.core.quantize.qdq import (
calibrate_qdq,
prepare_qdq,
)
from embedl_deploy._internal.core.quantize.smooth_quant import (
calibrate_smooth_quant,
prepare_smooth_quant,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants
def _no_skip_match(
mods: set[nn.Module],
skip_set: set[type[nn.Module] | nn.Module],
) -> bool:
"""Return ``True`` if no child of any module in `mods` matches `skip_set`."""
types = tuple(s for s in skip_set if isinstance(s, type))
instances = {s for s in skip_set if not isinstance(s, type)}
for mod in mods:
for child in mod.children():
if child in instances:
return False
if types and isinstance(child, types):
return False
return True
# [docs]  (Sphinx-rendered HTML anchor; not part of the original module)
def quantize(
    model: fx.GraphModule,
    args: tuple[Any, ...],
    config: QuantConfig | None = None,
    *,
    forward_loop: Callable[[fx.GraphModule], None],
) -> fx.GraphModule:
    """Run the full quantization pipeline on `model` in a single call.

    Chains
    :func:`~embedl_deploy._internal.core.quantize.main.configure` →
    :func:`~embedl_deploy._internal.core.quantize.smooth_quant.calibrate_smooth_quant` →
    :func:`~embedl_deploy._internal.core.quantize.qdq.calibrate_qdq`.

    :param model:
        A ``GraphModule`` produced by the fusion step.
    :param args:
        The arguments to use for shape propagation necessary to get tensor meta
        data required for calibration.
    :param config:
        Optional :class:`~embedl_deploy._internal.core.config.QuantConfig`.
        Defaults to 8-bit symmetric.
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model. The caller controls batch size, device
        placement, and iteration count.
    :returns:
        The quantized ``GraphModule`` with calibrated stubs.
    """
    # Fall back to the default (8-bit symmetric) configuration when the
    # caller did not supply one.
    if config is None:
        config = QuantConfig()
    # Shape propagation attaches tensor metadata to graph nodes; the
    # calibration passes below rely on it.
    ShapeProp(model).propagate(*args)  # type: ignore[no-untyped-call]
    configure(model, config)
    # Calibration happens in place on `model`: smooth-quant scaling first,
    # then Q/DQ stub calibration via the caller-provided forward loop.
    calibrate_smooth_quant(model, forward_loop)
    calibrate_qdq(model, forward_loop)
    return model