# Source code for embedl_deploy._internal.core.quantize.qat
# Copyright (C) 2026 Embedl AB
"""Quantization-aware training (QAT) utilities."""
from torch import nn
from embedl_deploy._internal.core.quantize.modules import (
QuantStub,
WeightFakeQuantize,
)
def prepare_qat(model: nn.Module) -> nn.Module:
    """Prepare a quantized model for quantization-aware training.

    Sets the model to training mode so that QuantStub nodes propagate
    gradients through STE and WeightFakeQuantize nodes apply fake-quant
    to weights.

    :param model:
        A quantized ``nn.Module``.
    :returns:
        The same model, in-place, for method chaining.
    """
    # ``train()`` recurses into every submodule, flipping their training flag.
    model.train()
    return model
def freeze_bn_stats(model: nn.Module) -> nn.Module:
    """Freeze BatchNorm running statistics.

    Puts all ``BatchNorm*d`` layers into eval mode so ``running_mean`` and
    ``running_var`` are no longer updated. Affine parameters remain
    trainable (``requires_grad`` is untouched).

    :param model:
        The model to modify in-place.
    :returns:
        The same model, for method chaining.
    """
    for module in model.modules():
        # eval() only changes the training flag; it does not freeze the
        # affine weight/bias gradients.
        if isinstance(
            module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        ):
            module.eval()
    return model
def enable_fake_quant(model: nn.Module) -> nn.Module:
    """Enable fake quantization in all stubs and weight quantizers.

    Walks every submodule and sets ``enabled = True`` on each
    ``QuantStub`` and ``WeightFakeQuantize`` instance.

    :param model:
        The model to modify in-place.
    :returns:
        The same model, for method chaining.
    """
    for module in model.modules():
        if isinstance(module, (QuantStub, WeightFakeQuantize)):
            module.enabled = True
    return model
def disable_fake_quant(model: nn.Module) -> nn.Module:
    """Disable fake quantization throughout the model.

    Walks every submodule and sets ``enabled = False`` on each
    ``QuantStub`` and ``WeightFakeQuantize`` instance, so forward passes
    run without fake-quant (mirror of :func:`enable_fake_quant`).

    :param model:
        The model to modify in-place.
    :returns:
        The same model, for method chaining.
    """
    for module in model.modules():
        if isinstance(module, (QuantStub, WeightFakeQuantize)):
            module.enabled = False
    return model