# Copyright (C) 2026 Embedl AB
"""Fused ``nn.Module`` replacements for convolution-based patterns.
Each class represents a hardware-fusible operation that replaces a multi-op
chain found by the pattern matcher. The fused module keeps the original sub-
modules (``Conv``, ``BN``, ``ReLU``) as children so that:
* Weights are trivially transferred from the original model.
* ``forward()`` is numerically identical to the original chain.
* A later compilation step can fold the ``BN`` and emit a single kernel.
"""
import torch
import torch.nn.functional as F
from torch import nn
from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
WeightFakeQuantize,
)
def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
"""Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
TensorRT requires ``in_channels / groups`` and
``out_channels / groups`` to both be multiples of 4 for INT8.
"""
if conv.groups <= 1:
return True
in_per_group: int = conv.in_channels // conv.groups
out_per_group: int = conv.out_channels // conv.groups
return in_per_group % 4 == 0 and out_per_group % 4 == 0
def _conv_weight_forward(
    conv: nn.Conv2d,
    weight_fake_quant: WeightFakeQuantize | None,
    x: torch.Tensor,
) -> torch.Tensor:
    """Run a convolution, fake-quantizing the weight when enabled."""
    # Use the raw weight unless a fake-quant module was installed.
    if weight_fake_quant is None:
        weight = conv.weight
    else:
        weight = weight_fake_quant(conv.weight)
    # All other convolution hyper-parameters come straight from *conv*.
    # pylint: disable-next=not-callable
    return F.conv2d(
        x,
        weight,
        conv.bias,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
    )
class FusedConvBNAct(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d] → Act``.

    Keeps the original sub-modules as children so that weights transfer
    trivially and ``forward()`` stays numerically identical to the
    unfused ``conv → bn → act`` chain.
    """

    # Quantize the single tensor input (positional index 0).
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        act: ActivationLike,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Store the fused sub-modules.

        Args:
            conv: The convolution of the matched pattern.
            bn: Optional batch norm following the convolution.
            act: Activation applied last.
            bn_foldable: Whether a later compilation step may fold ``bn``
                into ``conv``.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Grouped conv incompatible with TRT INT8: no weight fake-quant,
            # and clear the input quant stubs.
            # NOTE(review): assumes an empty ``input_quant_stubs`` disables
            # input quantization in ``FusedModule`` — confirm with base class.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn] → act``."""
        # ``weight_fake_quant`` is absent for INT8-incompatible convs.
        wfq = getattr(self, 'weight_fake_quant', None)
        x = _conv_weight_forward(self.conv, wfq, x)
        if self.bn is not None:
            x = self.bn(x)
        return self.act(x)

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        return (
            f"FusedConvBNAct("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info})"
        )
class FusedConvBN(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d]`` (no activation).

    Keeps the original sub-modules as children so that weights transfer
    trivially and ``forward()`` stays numerically identical to the
    unfused ``conv → bn`` chain.
    """

    # Quantize the single tensor input (positional index 0).
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Store the fused sub-modules.

        Args:
            conv: The convolution of the matched pattern.
            bn: Optional batch norm following the convolution.
            bn_foldable: Whether a later compilation step may fold ``bn``
                into ``conv``.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Grouped conv incompatible with TRT INT8: no weight fake-quant,
            # and clear the input quant stubs.
            # NOTE(review): assumes an empty ``input_quant_stubs`` disables
            # input quantization in ``FusedModule`` — confirm with base class.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn]``."""
        # ``weight_fake_quant`` is absent for INT8-incompatible convs.
        wfq = getattr(self, 'weight_fake_quant', None)
        x = _conv_weight_forward(self.conv, wfq, x)
        if self.bn is not None:
            x = self.bn(x)
        return x

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        return (
            f"FusedConvBN("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info})"
        )
class FusedConvBNActMaxPool(FusedModule):
    """Fused ``Conv2d → [BatchNorm2d] → Activation → MaxPool2d``.

    Keeps the original sub-modules as children so that weights transfer
    trivially and ``forward()`` stays numerically identical to the
    unfused chain.
    """

    # Quantize the single tensor input (positional index 0).
    inputs_to_quantize: set[int] = {0}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d | None,
        act: ActivationLike,
        maxpool: nn.MaxPool2d,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Store the fused sub-modules.

        Args:
            conv: The convolution of the matched pattern.
            bn: Optional batch norm following the convolution.
            act: Activation applied after ``bn``.
            maxpool: Max-pooling applied last.
            bn_foldable: Whether a later compilation step may fold ``bn``
                into ``conv``.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.maxpool = maxpool
        self.bn_foldable = bn_foldable
        # NOTE(review): unlike the sibling fused modules, this class does not
        # guard on ``_is_int8_compatible_conv`` before enabling weight
        # fake-quant — confirm whether the pattern matcher guarantees only
        # INT8-compatible convs reach this class.
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → [bn] → act → maxpool``."""
        x = _conv_weight_forward(self.conv, self.weight_fake_quant, x)
        if self.bn is not None:
            x = self.bn(x)
        x = self.act(x)
        return self.maxpool(x)

    def __repr__(self) -> str:
        bn_info = ""
        if self.bn is not None:
            fold = "foldable" if self.bn_foldable else "kept"
            bn_info = f", bn={self.bn.num_features} ({fold})"
        mp = self.maxpool
        return (
            f"FusedConvBNActMaxPool("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}"
            f"{bn_info}, "
            f"pool_k={mp.kernel_size}, pool_s={mp.stride})"
        )
class FusedConvBNAddAct(FusedModule):
    """Fused ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.

    ``forward()`` accepts two inputs: the main tensor ``x`` and the
    ``residual`` tensor.
    """

    # Quantize both tensor inputs: the main input and the residual.
    inputs_to_quantize: set[int] = {0, 1}

    def __init__(
        self,
        conv: nn.Conv2d,
        bn: nn.BatchNorm2d,
        act: ActivationLike,
        *,
        bn_foldable: bool = True,
    ) -> None:
        """Store the fused sub-modules.

        Args:
            conv: The convolution of the matched pattern.
            bn: Batch norm following the convolution (required here).
            act: Activation applied after the residual add.
            bn_foldable: Whether a later compilation step may fold ``bn``
                into ``conv``.
        """
        super().__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
        self.bn_foldable = bn_foldable
        if _is_int8_compatible_conv(conv):
            self.weight_fake_quant = WeightFakeQuantize({self})
        else:
            # Grouped conv incompatible with TRT INT8: no weight fake-quant,
            # and clear the input quant stubs.
            # NOTE(review): assumes an empty ``input_quant_stubs`` disables
            # input quantization in ``FusedModule`` — confirm with base class.
            self.input_quant_stubs = {}

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Apply ``conv → bn → add(·, residual) → act``."""
        # ``weight_fake_quant`` is absent for INT8-incompatible convs.
        wfq = getattr(self, 'weight_fake_quant', None)
        x = _conv_weight_forward(self.conv, wfq, x)
        x = self.bn(x)
        return self.act(x + residual)

    def __repr__(self) -> str:
        fold = "foldable" if self.bn_foldable else "kept"
        return (
            f"FusedConvBNAddAct("
            f"{self.conv.in_channels}→{self.conv.out_channels}, "
            f"k={self.conv.kernel_size}, s={self.conv.stride}, "
            f"bn={self.bn.num_features} ({fold}))"
        )