# Source code for embedl_deploy._internal.tensorrt.modules.linear

# Copyright (C) 2026 Embedl AB

"""Fused ``nn.Module`` replacements for Linear and LayerNorm patterns."""

from typing import TypeAlias

import torch
import torch.nn.functional as F
from torch import nn

from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
from embedl_deploy._internal.core.quantize.stubs import (
    SmoothQuantObserver,
    WeightFakeQuantize,
)

#: Minimum ``K * N / (K + N)`` for INT8 to outperform FP16.
INT8_LINEAR_MIN_RATIO: int = 256


def is_int8_beneficial_linear(linear: nn.Linear) -> bool:
    """Return ``True`` when INT8 quantization benefits *linear*.

    Uses ``K * N / (K + N)`` -- half the harmonic mean of the weight
    dimensions ``K = in_features`` and ``N = out_features`` -- as a proxy
    for the ratio of INT8 compute savings to Q/DQ reformat overhead.
    Below :data:`INT8_LINEAR_MIN_RATIO`, the overhead from
    quantize/dequantize boundary layers exceeds any INT8 GEMM speedup
    and the layer is better left in FP16.

    Reference: NVIDIA benchmarks show INT8 GEMM outperforms FP16 only
    when all three matrix dimensions exceed ~2048 (A100).  The threshold
    of 256 conservatively separates mobile-class models (MobileViT FFN
    ratio ≤ 160) from server-class models (ViT-B/16 FFN ratio = 614)
    where INT8 is beneficial.

    The test is evaluated as ``K * N >= INT8_LINEAR_MIN_RATIO * (K + N)``,
    which is the same inequality for positive sizes but needs no
    division.  A degenerate layer with ``K + N == 0`` is reported as not
    beneficial instead of raising ``ZeroDivisionError``.

    :param linear: Layer whose ``in_features`` / ``out_features`` are
        inspected; its weights are not read.
    :returns: ``True`` when the ratio reaches
        :data:`INT8_LINEAR_MIN_RATIO`, ``False`` otherwise.
    """
    k, n = linear.in_features, linear.out_features
    total = k + n
    if total <= 0:
        # Zero-width layer: there is no GEMM to speed up, and the ratio
        # itself would be 0 / 0.
        return False
    # Cross-multiplied form of ``k * n / total >= INT8_LINEAR_MIN_RATIO``.
    return k * n >= INT8_LINEAR_MIN_RATIO * total


def attach_int8_weight_quant(
    mod: FusedModule,
    linear: nn.Linear,
) -> None:
    """Configure *mod* for INT8 or FP16 depending on the size of *linear*.

    The decision is delegated to :func:`is_int8_beneficial_linear`:

    * beneficial -- a ``WeightFakeQuantize`` constructed with ``{mod}``
      is stored as ``mod.weight_fake_quant``;
    * not beneficial -- ``mod.input_quant_stubs`` is reset to an empty
      mapping so the surrounding Q/DQ pass leaves the wrapped linear
      entirely in FP16, and no weight quantizer is attached.

    :param mod: Fused module to configure; mutated in place.
    :param linear: The ``nn.Linear`` whose dimensions drive the decision.
    """
    if not is_int8_beneficial_linear(linear):
        # INT8 would not pay for its Q/DQ boundary cost: drop the input
        # quant stubs and leave the weight untouched.
        mod.input_quant_stubs = {}
        return

    mod.weight_fake_quant = WeightFakeQuantize({mod})


def maybe_quantize_weight(
    mod: nn.Module,
    weight: torch.Tensor,
) -> torch.Tensor:
    """Run *weight* through ``mod.weight_fake_quant`` when one is set.

    :param mod: Module that may carry a ``weight_fake_quant`` attribute.
    :param weight: Tensor to pass through the fake-quantizer.
    :returns: The result of calling ``mod.weight_fake_quant(weight)``, or
        *weight* itself when the attribute is missing or ``None``.
    """
    fake_quant = getattr(mod, "weight_fake_quant", None)
    if fake_quant is None:
        # Nothing attached: hand the tensor back unchanged.
        return weight
    return fake_quant(weight)


class FusedLinear(FusedModule):
    """Fused wrapper for a standalone ``Linear`` layer.

    On construction the wrapper is passed to
    :func:`attach_int8_weight_quant`, which either attaches a
    ``weight_fake_quant`` or clears ``input_quant_stubs`` depending on
    the size of *linear*.

    :param linear: The ``nn.Linear`` from the matched chain.
    """

    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        self.linear = linear
        attach_int8_weight_quant(self, linear)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear``, fake-quantizing the weight.

        :param x: Input passed to ``F.linear``.
        :returns: ``F.linear(x, weight, bias)`` where ``weight`` has gone
            through :func:`maybe_quantize_weight`.
        """
        weight = maybe_quantize_weight(self, self.linear.weight)
        # pylint: disable-next=not-callable
        return F.linear(x, weight, self.linear.bias)

    def __repr__(self) -> str:  # pragma: no cover
        # Label both sizes: printing them back to back (e.g. "7683072")
        # makes it impossible to tell where one number ends.
        return (
            f"FusedLinear("
            f"in_features={self.linear.in_features}, "
            f"out_features={self.linear.out_features})"
        )
class FusedLinearAct(FusedModule):
    """Fused ``Linear → Activation``.

    On construction the wrapper is passed to
    :func:`attach_int8_weight_quant`, which either attaches a
    ``weight_fake_quant`` or clears ``input_quant_stubs`` depending on
    the size of *linear*.

    :param linear: The ``nn.Linear`` from the matched chain.
    :param act: The activation module from the matched chain.
    """

    inputs_to_quantize: set[int] = {0}

    def __init__(self, linear: nn.Linear, act: ActivationLike) -> None:
        super().__init__()
        self.linear = linear
        self.act = act
        attach_int8_weight_quant(self, linear)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``linear → activation``, fake-quantizing the weight.

        :param x: Input passed to ``F.linear``.
        :returns: ``act(F.linear(x, weight, bias))`` where ``weight`` has
            gone through :func:`maybe_quantize_weight`.
        """
        weight = maybe_quantize_weight(self, self.linear.weight)
        # pylint: disable-next=not-callable
        x = F.linear(x, weight, self.linear.bias)
        return self.act(x)

    def __repr__(self) -> str:  # pragma: no cover
        act_name = type(self.act).__name__
        # Label both sizes: printing them back to back (e.g. "7683072")
        # makes it impossible to tell where one number ends.
        return (
            f"FusedLinearAct("
            f"in_features={self.linear.in_features}, "
            f"out_features={self.linear.out_features}, "
            f"act={act_name})"
        )
#: Either of the fused wrappers in this module that hold an ``nn.Linear``.
FusedLinearLike: TypeAlias = FusedLinear | FusedLinearAct
class FusedLayerNorm(FusedModule):
    """Fused wrapper for a standalone ``LayerNorm``.

    Weight quantization is disabled by default: the constructor attaches
    no weight fake-quantizer and ``inputs_to_quantize`` is empty.
    LayerNorm's learnable ``weight`` is an element-wise affine scale, so
    quantizing it yields negligible savings while hurting accuracy.

    A ``SmoothQuantObserver`` is created with this module as its only
    consumer and stored as ``smooth_quant_observer``.

    :param layer_norm: The ``nn.LayerNorm`` from the matched chain.
    """

    prefers_fp_input: bool = True
    inputs_to_quantize: set[int] = set()

    def __init__(self, layer_norm: nn.LayerNorm) -> None:
        super().__init__()
        self.layer_norm = layer_norm
        observer = SmoothQuantObserver(
            consumers={self},
            layer_norm=layer_norm,
        )
        self.smooth_quant_observer = observer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``layer_norm``."""
        normalized = self.layer_norm(x)
        return normalized

    def __repr__(self) -> str:  # pragma: no cover
        norm = self.layer_norm
        return (
            f"FusedLayerNorm("
            f"normalized_shape={norm.normalized_shape}, "
            f"eps={norm.eps})"
        )