# Copyright (C) 2026 Embedl AB
"""High-level quantization entry points.
Convenience wrappers that chain configuration, smooth-quant calibration,
Q/DQ stub insertion, optimization, and calibration into single calls.
"""
import copy
from collections.abc import Callable
from typing import Any
import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp
from embedl_deploy._internal.core.quantize.calibrate import (
calibrate_qdq,
calibrate_smooth_quant,
)
from embedl_deploy._internal.core.quantize.config import QuantConfig
from embedl_deploy._internal.core.quantize.prepare import (
prepare_qdq,
prepare_smooth_quant,
)
from embedl_deploy._internal.core.quantize.utils import get_model_quants
def _no_skip_match(
mods: set[nn.Module],
skip_set: set[type[nn.Module] | nn.Module],
) -> bool:
"""Return ``True`` if no child of any module in `mods` matches `skip_set`."""
types = tuple(s for s in skip_set if isinstance(s, type))
instances = {s for s in skip_set if not isinstance(s, type)}
for mod in mods:
for child in mod.children():
if child in instances:
return False
if types and isinstance(child, types):
return False
return True
[docs]
def quantize(
    model: fx.GraphModule,
    args: tuple[Any, ...],
    config: QuantConfig | None = None,
    *,
    forward_loop: Callable[[fx.GraphModule], None],
    freeze_weights: bool = False,
) -> fx.GraphModule:
    """Configure, insert Q/DQ stubs, optimize, and calibrate in one call.

    Convenience wrapper that runs shape propagation on `args` and then chains
    :func:`~embedl_deploy._internal.core.quantize.main.configure` →
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_smooth_quant` →
    :func:`~embedl_deploy._internal.core.quantize.calibrate.calibrate_qdq`,
    followed by :func:`freeze_weight_quantization` when `freeze_weights`
    is ``True``.

    The object returned is `model` itself. The steps operate on it rather
    than on a copy.

    :param model:
        A ``GraphModule`` produced by the fusion step.
    :param args:
        The arguments to use for shape propagation necessary to get tensor meta
        data required for calibration.
    :param config:
        Optional :class:`~embedl_deploy._internal.core.quantize.config.QuantConfig`.
        When ``None``, a default-constructed ``QuantConfig()`` is used
        (8-bit symmetric).
    :param forward_loop:
        ``(model) -> None`` callable that runs representative data
        through the model. The caller controls batch size, device placement,
        and iteration count.
    :param freeze_weights:
        When ``True``, weight scales are computed from the current weights
        and stored as constant buffers after calibration. This is required
        for ONNX/TensorRT export (PTQ workflow). Defaults to ``False`` to
        preserve the original behavior (dynamic on-the-fly scale computation,
        suitable for QAT). Call :func:`freeze_weight_quantization` before
        export once training is complete.
    :returns:
        The quantized ``GraphModule`` with calibrated stubs (the same
        object as `model`).
    """
    # ShapeProp executes the graph on ``args`` and records tensor metadata
    # on the graph nodes, which calibration needs (see ``args`` above).
    # ShapeProp.propagate lacks type stubs in torch.
    ShapeProp(model).propagate(*args)  # type: ignore[no-untyped-call]

    # Explicit None check: do not rely on the truthiness of a config object.
    if config is None:
        config = QuantConfig()

    configure(model, config)
    calibrate_smooth_quant(model, forward_loop)
    calibrate_qdq(model, forward_loop)
    if freeze_weights:
        freeze_weight_quantization(model)
    return model
[docs]
def freeze_weight_quantization(model: fx.GraphModule) -> None:
    """Freeze all :class:`WeightFakeQuantize` scale/zero_point buffers.

    After calibration (or QAT training) the weights are fixed for export,
    so we compute the scale once and store it as a constant buffer. This
    ensures ONNX export emits a ``Constant`` node for the scale rather than
    the dynamic ``Abs → ReduceMax → Div`` arithmetic, which TensorRT
    requires for explicit-quantization Q/DQ fusion.

    :func:`quantize` calls this only when ``freeze_weights=True`` is passed
    (its default is ``False``). QAT users should call this explicitly
    before ONNX export once training is complete.

    A weight quantizer is left unfrozen when it has no consumer module, or
    when its first consumer exposes no weight tensor at ``conv.weight``,
    ``linear.weight`` or ``in_proj.linear.weight``.

    :param model:
        The quantized ``GraphModule`` whose weight quantizers are frozen
        in place.
    """
    mq = get_model_quants(model)
    for wfq in mq.weight:
        # A bare ``next(iter(...))`` would leak an uninformative
        # StopIteration for a quantizer without consumers. Skip it instead,
        # consistent with the skip below when no weight is found.
        mod = next(iter(wfq.consumers), None)
        if mod is None:
            continue
        weight = _get_quantized_weight(mod)
        if isinstance(weight, torch.Tensor):
            wfq.freeze(weight)
def _get_quantized_weight(mod: nn.Module) -> torch.Tensor | None:
"""Return the weight tensor that a ``WeightFakeQuantize`` covers.
Checks ``mod.conv.weight``, ``mod.linear.weight``, and
``mod.in_proj.linear.weight`` — the three attribute paths used
by fused-module subclasses that carry a ``WeightFakeQuantize``.
"""
for attr_chain in ("conv", "linear", "in_proj.linear"):
obj: object = mod
for part in attr_chain.split("."):
obj = getattr(obj, part, None)
if obj is None:
break
if obj is not None:
weight = getattr(obj, "weight", None)
if isinstance(weight, torch.Tensor):
return weight
return None