# Source code for embedl_deploy._internal.tensorrt.modules.conv

# Copyright (C) 2026 Embedl AB

"""Fused ``nn.Module`` replacements for convolution-based patterns.

Each class represents a hardware-fusible operation that replaces a multi-op
chain found by the pattern matcher.  The fused module keeps the original sub-
modules (``Conv``, ``BN``, ``ReLU``) as children so that:

* Weights are trivially transferred from the original model.
* ``forward()`` is numerically identical to the original chain.
* A later compilation step can fold the ``BN`` and emit a single kernel.
"""

import torch
import torch.nn.functional as F
from torch import nn

from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
    WeightFakeQuantize,
)


def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
    """Return ``True`` unless *conv* is a grouped conv violating TRT INT8.

    TensorRT requires ``in_channels / groups`` and
    ``out_channels / groups`` to both be multiples of 4 for INT8.
    """
    if conv.groups <= 1:
        return True
    in_per_group: int = conv.in_channels // conv.groups
    out_per_group: int = conv.out_channels // conv.groups
    return in_per_group % 4 == 0 and out_per_group % 4 == 0


def _conv_weight_forward(
    conv: nn.Conv2d,
    weight_fake_quant: WeightFakeQuantize | None,
    x: torch.Tensor,
) -> torch.Tensor:
    """Run a convolution, fake-quantizing the weight when enabled."""
    weight = (
        weight_fake_quant(conv.weight)
        if weight_fake_quant is not None
        else conv.weight
    )
    # pylint: disable-next=not-callable
    return F.conv2d(
        x,
        weight,
        conv.bias,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
    )


class FusedConvBNAct(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d] → Act``."""

    # Only the main tensor input (index 0) should receive input quantization.
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        act: ActivationLike,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Wrap *conv*, optional *bn*, and *act* as one fused unit.

        ``bn_foldable`` records whether a later compilation step may fold
        the BN into the conv weights.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Conv violates TRT's INT8 grouped-conv rule: skip weight
            # fake-quant and clear the input quant stubs so the module
            # stays in higher precision.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn] → act``."""
        # weight_fake_quant only exists for INT8-compatible convs.
        wfq = getattr(self, "weight_fake_quant", None)
        x = _conv_weight_forward(self.conv, wfq, x)
        if self.bn is not None:
            x = self.bn(x)
        return self.act(x)

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        # Separate in/out channels with an arrow so e.g. 16→32 is not
        # rendered as the ambiguous "1632".
        return (
            f"FusedConvBNAct("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info})"
        )
class FusedConvBN(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d]`` (no activation)."""

    # Only the main tensor input (index 0) should receive input quantization.
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Wrap *conv* and optional *bn* as one fused unit.

        ``bn_foldable`` records whether a later compilation step may fold
        the BN into the conv weights.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Conv violates TRT's INT8 grouped-conv rule: skip weight
            # fake-quant and clear the input quant stubs so the module
            # stays in higher precision.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn]``."""
        # weight_fake_quant only exists for INT8-compatible convs.
        wfq = getattr(self, "weight_fake_quant", None)
        x = _conv_weight_forward(self.conv, wfq, x)
        if self.bn is not None:
            x = self.bn(x)
        return x

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        # Separate in/out channels with an arrow so e.g. 16→32 is not
        # rendered as the ambiguous "1632".
        return (
            f"FusedConvBN("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info})"
        )
class FusedConvBNActMaxPool(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d] → Activation → MaxPool2d``."""

    # Only the main tensor input (index 0) should receive input quantization.
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        act: ActivationLike,
        maxpool: nn.MaxPool2d,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Wrap *conv*, optional *bn*, *act*, and *maxpool* as one unit.

        ``bn_foldable`` records whether a later compilation step may fold
        the BN into the conv weights.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.maxpool = maxpool
        self.bn_foldable = bn_foldable
        # Consistency fix: like the sibling fused classes, only attach a
        # weight fake-quantizer when the conv satisfies TRT's INT8
        # grouped-conv rule; otherwise disable input quantization.
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn] → act → maxpool``."""
        # weight_fake_quant only exists for INT8-compatible convs.
        wfq = getattr(self, "weight_fake_quant", None)
        x = _conv_weight_forward(self.conv, wfq, x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.act(x)
        return self.maxpool(x)

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        mp = self.maxpool
        # Separate in/out channels with an arrow so e.g. 16→32 is not
        # rendered as the ambiguous "1632".
        return (
            f"FusedConvBNActMaxPool("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info}, "
            f"pool_k={mp.kernel_size}, pool_s={mp.stride})"
        )
class FusedConvBNAddAct(FusedModule):
    """Fused ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    ``forward()`` accepts two inputs: the main tensor ``x`` and the
    ``residual`` tensor.
    """

    # Both the main input (0) and the residual input (1) are quantized.
    inputs_to_quantize: set[int] = {0, 1}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d,
        act: ActivationLike,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Wrap *conv*, *bn*, and *act* as one fused residual unit.

        ``bn_foldable`` records whether a later compilation step may fold
        the BN into the conv weights.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Conv violates TRT's INT8 grouped-conv rule: skip weight
            # fake-quant and clear the input quant stubs so the module
            # stays in higher precision.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → bn → add(·, residual) → act``."""
        # weight_fake_quant only exists for INT8-compatible convs.
        wfq = getattr(self, "weight_fake_quant", None)
        x = _conv_weight_forward(self.conv, wfq, x)
        x = self.bn(x)
        return self.act(x + residual)

    def __repr__(self) -> str:
        fold = "foldable" if self.bn_foldable else "kept"
        # Separate in/out channels with an arrow so e.g. 16→32 is not
        # rendered as the ambiguous "1632".
        return (
            f"FusedConvBNAddAct("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}, "
            f"bn={self.bn.num_features} ({fold}))"
        )