# Source code for embedl_deploy._internal.tensorrt.modules.linear

# Copyright (C) 2026 Embedl AB

"""Fused ``nn.Module`` replacements for Linear and LayerNorm patterns."""

from typing import TypeAlias

import torch
import torch.nn.functional as F
from torch import nn

from embedl_deploy._internal.core.config import SmoothQuantObserver
from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
    WeightFakeQuantize,
)


class FusedLinear(FusedModule):
    """Fused wrapper for a standalone ``Linear`` layer.

    :param linear: The ``nn.Linear`` from the matched chain.
    """

    # Only the activation input (index 0) is fake-quantized.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.linear = linear
        # Weight fake-quantizer registered with this module as its owner.
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear``, fake-quantizing the weight."""
        weight = self.weight_fake_quant(self.linear.weight)
        # pylint: disable-next=not-callable
        return F.linear(x, weight, self.linear.bias)

    def __repr__(self) -> str:
        # Fix: the feature counts previously ran together with no
        # separator ("{in}{out}" rendered e.g. "768512").
        return (
            f"FusedLinear("
            f"{self.linear.in_features}→{self.linear.out_features})"
        )
class FusedLinearAct(FusedModule):
    """Fused ``Linear → Activation``.

    :param linear: The ``nn.Linear`` from the matched chain.
    :param act: The activation module from the matched chain.
    """

    # Only the activation input (index 0) is fake-quantized.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear, act: ActivationLike) -> None:
        super().__init__()
        self.linear = linear
        self.act = act
        # Weight fake-quantizer registered with this module as its owner.
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear → activation``, fake-quantizing the weight."""
        weight = self.weight_fake_quant(self.linear.weight)
        # pylint: disable-next=not-callable
        x = F.linear(x, weight, self.linear.bias)
        return self.act(x)

    def __repr__(self) -> str:
        act_name = type(self.act).__name__
        # Fix: the feature counts previously ran together with no
        # separator ("{in}{out}" rendered e.g. "768512").
        return (
            f"FusedLinearAct("
            f"{self.linear.in_features}→{self.linear.out_features}, "
            f"act={act_name})"
        )
# Either fused-linear variant; used wherever a plain or activation-fused
# linear wrapper is acceptable.
FusedLinearLike: TypeAlias = FusedLinear | FusedLinearAct
class FusedLayerNorm(FusedModule):
    """Fused wrapper for a standalone ``LayerNorm``.

    Weight quantization is disabled by default. LayerNorm's learnable
    ``weight`` is an element-wise affine scale, so quantizing it yields
    negligible savings while hurting accuracy.

    :param layer_norm: The ``nn.LayerNorm`` from the matched chain.
    """

    # Keep the input in floating point; no inputs are fake-quantized.
    prefers_fp_input: bool = True
    inputs_to_quantize: set[int] = set()

    def __init__(self, layer_norm: nn.LayerNorm) -> None:
        super().__init__()
        self.layer_norm = layer_norm
        self.smooth_quant_observer = SmoothQuantObserver({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``layer_norm`` with the weight used as-is.

        Unlike the fused Linear wrappers, the weight is NOT
        fake-quantized here (the old docstring claimed otherwise);
        weight quantization is deliberately disabled for LayerNorm.
        """
        return F.layer_norm(
            x,
            self.layer_norm.normalized_shape,
            self.layer_norm.weight,
            self.layer_norm.bias,
            self.layer_norm.eps,
        )

    def __repr__(self) -> str:
        return (
            f"FusedLayerNorm("
            f"normalized_shape={self.layer_norm.normalized_shape}, "
            f"eps={self.layer_norm.eps})"
        )