# Source code for embedl_deploy._internal.core.quantize.qat

# Copyright (C) 2026 Embedl AB

"""Quantization-aware training (QAT) utilities."""

from torch import nn

from embedl_deploy._internal.core.quantize.modules import (
    QuantStub,
    WeightFakeQuantize,
)


def prepare_qat(model: nn.Module) -> nn.Module:
    """Prepare a quantized model for quantization-aware training.

    Sets the model to training mode so that QuantStub nodes propagate
    gradients through STE and WeightFakeQuantize nodes apply fake-quant
    to weights.

    :param model: A quantized ``nn.Module``.
    :returns: The same model, in-place, for method chaining.
    """
    model.train()
    return model
def freeze_bn_stats(model: nn.Module) -> nn.Module:
    """Freeze BatchNorm running statistics.

    Puts all ``BatchNorm*d`` layers into eval mode so ``running_mean``
    and ``running_var`` are no longer updated. Affine parameters remain
    trainable.

    :param model: The model to modify in-place.
    :returns: The same model, for method chaining.
    """
    for module in model.modules():
        if isinstance(
            module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        ):
            # eval() stops running-stat updates; weight/bias still get grads.
            module.eval()
    return model
def enable_fake_quant(model: nn.Module) -> nn.Module:
    """Enable fake quantization in all stubs and weight quantizers.

    :param model: The model to modify in-place.
    :returns: The same model, for method chaining.
    """
    for module in model.modules():
        if isinstance(module, (QuantStub, WeightFakeQuantize)):
            module.enabled = True
    return model
def disable_fake_quant(model: nn.Module) -> nn.Module:
    """Disable fake quantization throughout the model.

    :param model: The model to modify in-place.
    :returns: The same model, for method chaining.
    """
    for module in model.modules():
        if isinstance(module, (QuantStub, WeightFakeQuantize)):
            module.enabled = False
    return model