# Source code for embedl_deploy._internal.tensorrt.modules.linear
# Copyright (C) 2026 Embedl AB
"""Fused ``nn.Module`` replacements for Linear and LayerNorm patterns."""
from typing import TypeAlias
import torch
import torch.nn.functional as F
from torch import nn
from embedl_deploy._internal.core.config import SmoothQuantObserver
from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.modules import (
WeightFakeQuantize,
)
class FusedLinear(FusedModule):
    """Fused replacement for a lone ``Linear`` layer.

    Applies the linear transform with a fake-quantized copy of the
    weight so quantization effects are simulated during training/calibration.

    :param linear:
        The ``nn.Linear`` from the matched chain.
    """

    # Only the single positional input (index 0) carries a quantized activation.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.linear = linear
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``linear`` using the fake-quantized weight."""
        fq_weight = self.weight_fake_quant(self.linear.weight)
        # pylint: disable-next=not-callable
        return F.linear(x, fq_weight, self.linear.bias)

    def __repr__(self) -> str:
        in_f = self.linear.in_features
        out_f = self.linear.out_features
        return f"FusedLinear({in_f}→{out_f})"
class FusedLinearAct(FusedModule):
    """Fused ``Linear → Activation`` pair.

    The linear transform runs with a fake-quantized weight; the matched
    activation is then applied to its output.

    :param linear:
        The ``nn.Linear`` from the matched chain.
    :param act:
        The activation module from the matched chain.
    """

    # Only the single positional input (index 0) carries a quantized activation.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear, act: ActivationLike) -> None:
        super().__init__()
        self.linear = linear
        self.act = act
        self.weight_fake_quant = WeightFakeQuantize({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``linear`` with the fake-quantized weight, then the activation."""
        fq_weight = self.weight_fake_quant(self.linear.weight)
        # pylint: disable-next=not-callable
        pre_act = F.linear(x, fq_weight, self.linear.bias)
        return self.act(pre_act)

    def __repr__(self) -> str:
        in_f = self.linear.in_features
        out_f = self.linear.out_features
        act_name = type(self.act).__name__
        return f"FusedLinearAct({in_f}→{out_f}, act={act_name})"
# Union of the two Linear fusion variants, for code that accepts either.
FusedLinearLike: TypeAlias = FusedLinear | FusedLinearAct
class FusedLayerNorm(FusedModule):
    """Fused wrapper for a standalone ``LayerNorm``.

    Weight quantization is disabled by default. LayerNorm's learnable
    ``weight`` is an element-wise affine scale, so quantizing it yields
    negligible savings while hurting accuracy.

    :param layer_norm:
        The ``nn.LayerNorm`` from the matched chain.
    """

    # NOTE(review): presumably tells the surrounding machinery to feed this
    # module floating-point activations — confirm against FusedModule.
    prefers_fp_input: bool = True
    # No inputs are quantized for this module.
    inputs_to_quantize: set[int] = set()

    def __init__(self, layer_norm: nn.LayerNorm) -> None:
        super().__init__()
        self.layer_norm = layer_norm
        # Observer attached for SmoothQuant; it is not invoked in forward()
        # here, so it presumably hooks in externally — verify against callers.
        self.smooth_quant_observer = SmoothQuantObserver({self})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fixed docstring: unlike the Linear fusions, no weight fake-quant
        # happens here (see class docstring) — the old docstring claimed it did.
        """Apply ``layer_norm``; the weight is used as-is (not fake-quantized)."""
        return F.layer_norm(
            x,
            self.layer_norm.normalized_shape,
            self.layer_norm.weight,
            self.layer_norm.bias,
            self.layer_norm.eps,
        )

    def __repr__(self) -> str:
        return (
            f"FusedLayerNorm("
            f"normalized_shape={self.layer_norm.normalized_shape}, "
            f"eps={self.layer_norm.eps})"
        )