# Source code for embedl_deploy._internal.tensorrt.modules.linear
# Copyright (C) 2026 Embedl AB
"""Fused ``nn.Module`` replacements for Linear and LayerNorm patterns."""
from typing import TypeAlias
import torch
import torch.nn.functional as F
from torch import nn
from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.stubs import (
SmoothQuantObserver,
WeightFakeQuantize,
)
#: Minimum ``K * N / (K + N)`` for INT8 to outperform FP16.
INT8_LINEAR_MIN_RATIO: int = 256


def is_int8_beneficial_linear(linear: nn.Linear) -> bool:
    """Return ``True`` when INT8 quantization benefits *linear*.

    Uses ``K * N / (K + N)`` -- half the harmonic mean of the weight
    dimensions ``K = in_features`` and ``N = out_features`` -- as a proxy
    for the ratio of INT8 compute savings to Q/DQ reformat overhead.
    Below :data:`INT8_LINEAR_MIN_RATIO`, the overhead from
    quantize/dequantize boundary layers exceeds any INT8 GEMM speedup
    and the layer is better left in FP16.

    Reference: NVIDIA benchmarks show INT8 GEMM outperforms FP16 only
    when all three matrix dimensions exceed ~2048 (A100). The threshold
    of 256 conservatively separates mobile-class models (MobileViT FFN
    ratio ≤ 160) from server-class models (ViT-B/16 FFN ratio = 614)
    where INT8 is beneficial.

    A degenerate layer with ``in_features + out_features == 0`` has no
    weights to quantize; it is reported as not beneficial instead of
    raising ``ZeroDivisionError``.

    :param linear:
        The layer whose ``in_features`` and ``out_features`` are
        inspected. No other attribute is read.
    :returns:
        ``True`` if the ratio is at least
        :data:`INT8_LINEAR_MIN_RATIO`, ``False`` otherwise.
    """
    k, n = linear.in_features, linear.out_features
    denom = k + n
    if denom == 0:
        # Both dimensions are zero: the ratio is undefined and there is
        # nothing for INT8 to speed up.
        return False
    return k * n / denom >= INT8_LINEAR_MIN_RATIO
def attach_int8_weight_quant(
    mod: FusedModule,
    linear: nn.Linear,
) -> None:
    """Configure *mod* for INT8 or FP16 depending on the size of *linear*.

    If :func:`is_int8_beneficial_linear` accepts *linear*, a
    ``WeightFakeQuantize`` whose consumer set is ``{mod}`` is stored on
    ``mod.weight_fake_quant``. Otherwise INT8 wouldn't pay for its Q/DQ
    boundary cost, so ``mod.input_quant_stubs`` is replaced with an empty
    dict so the surrounding Q/DQ pass leaves the wrapped linear entirely
    in FP16.

    :param mod:
        The fused module to configure; it is mutated in place.
    :param linear:
        The ``nn.Linear`` whose dimensions drive the decision.
    """
    if not is_int8_beneficial_linear(linear):
        # Too small for INT8 to win: drop the input stubs and stop.
        mod.input_quant_stubs = {}
        return
    mod.weight_fake_quant = WeightFakeQuantize({mod})
def maybe_quantize_weight(
    mod: nn.Module,
    weight: torch.Tensor,
) -> torch.Tensor:
    """Pass *weight* through ``mod.weight_fake_quant`` when one is set.

    :param mod:
        Module that may carry a ``weight_fake_quant`` attribute.
    :param weight:
        The weight tensor to (optionally) fake-quantize.
    :returns:
        The result of calling ``mod.weight_fake_quant(weight)``, or
        *weight* itself when the attribute is missing or ``None``.
    """
    quantizer = getattr(mod, "weight_fake_quant", None)
    if quantizer is None:
        return weight
    return quantizer(weight)
class FusedLinear(FusedModule):
    """Fused replacement for a ``Linear`` layer that stands on its own.

    :param linear:
        The ``nn.Linear`` from the matched chain.
    """

    # Index 0 (the only positional input of ``forward``) is marked for
    # quantization.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.linear = linear
        # Adds ``weight_fake_quant`` or clears ``input_quant_stubs``,
        # depending on the layer size.
        attach_int8_weight_quant(self, linear)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``linear`` on *x*, fake-quantizing the weight if enabled."""
        layer = self.linear
        weight = maybe_quantize_weight(self, layer.weight)
        # pylint: disable-next=not-callable
        return F.linear(x, weight, layer.bias)

    def __repr__(self) -> str:  # pragma: no cover
        layer = self.linear
        return f"FusedLinear({layer.in_features}→{layer.out_features})"
class FusedLinearAct(FusedModule):
    """Fused replacement for the chain ``Linear → Activation``.

    :param linear:
        The ``nn.Linear`` from the matched chain.
    :param act:
        The activation module from the matched chain.
    """

    # Index 0 (the only positional input of ``forward``) is marked for
    # quantization.
    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear, act: ActivationLike) -> None:
        super().__init__()
        self.linear = linear
        self.act = act
        # Adds ``weight_fake_quant`` or clears ``input_quant_stubs``,
        # depending on the layer size.
        attach_int8_weight_quant(self, linear)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``linear`` then ``act``, fake-quantizing the weight if enabled."""
        layer = self.linear
        weight = maybe_quantize_weight(self, layer.weight)
        # pylint: disable-next=not-callable
        projected = F.linear(x, weight, layer.bias)
        return self.act(projected)

    def __repr__(self) -> str:  # pragma: no cover
        dims = f"{self.linear.in_features}→{self.linear.out_features}"
        return f"FusedLinearAct({dims}, act={type(self.act).__name__})"
#: Either of the fused wrappers in this module that hold an ``nn.Linear``.
FusedLinearLike: TypeAlias = FusedLinear | FusedLinearAct
class FusedLayerNorm(FusedModule):
    """Fused replacement for a ``LayerNorm`` that stands on its own.

    Weight quantization is disabled by default: the learnable ``weight``
    of a LayerNorm is an element-wise affine scale, so quantizing it
    yields negligible savings while hurting accuracy.

    :param layer_norm:
        The ``nn.LayerNorm`` from the matched chain.
    """

    prefers_fp_input: bool = True
    # Empty: no input of ``forward`` is marked for quantization.
    inputs_to_quantize: set[int] = set()

    def __init__(self, layer_norm: nn.LayerNorm) -> None:
        super().__init__()
        self.layer_norm = layer_norm
        # The observer is given this module as its only consumer,
        # together with the wrapped LayerNorm.
        observer = SmoothQuantObserver(
            consumers={self},
            layer_norm=layer_norm,
        )
        self.smooth_quant_observer = observer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ``layer_norm`` on *x*."""
        normalized = self.layer_norm(x)
        return normalized

    def __repr__(self) -> str:  # pragma: no cover
        norm = self.layer_norm
        return (
            "FusedLayerNorm("
            f"normalized_shape={norm.normalized_shape}, "
            f"eps={norm.eps})"
        )