# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.linear

# Copyright (C) 2026 Embedl AB

"""Fusion ``Pattern`` subclasses for Linear and LayerNorm."""

from typing import cast

from torch import fx, nn

from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.pattern import (
    Pattern,
    PatternMatch,
    QDQPoint,
    Tree,
    resolve_module,
)
from embedl_deploy._internal.core.replace import (
    replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.linear import (
    FusedLayerNorm,
    FusedLinear,
    FusedLinearAct,
)


class LinearActPattern(Pattern):
    """Fuse a ``Linear`` followed by an activation into one module.

    Every activation type listed in
    :data:`~embedl_deploy._internal.core.modules.ActivationLike` is accepted
    as the second node of the pattern.
    """

    # Two-node pattern: a Linear feeding any supported activation.
    tree: Tree = (nn.Linear, ActivationLike)
    # A Q/DQ stub is placed only on the fused module's input.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of ``Linear → Activation`` in the graph."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap a matched ``Linear → Activation`` pair for a ``FusedLinearAct``."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        linear_mod = resolve_module(matched.get_node(0), nn.Linear)
        activation = cast(ActivationLike, resolve_module(matched.get_node(1)))
        return replace_tree(pattern_match, [FusedLinearAct(linear_mod, activation)])
class LinearPattern(Pattern):
    """Wrap a standalone ``Linear`` layer in its fused counterpart."""

    # Single-node pattern: any Linear not already consumed by a larger fusion.
    tree: Tree = (nn.Linear,)
    # A Q/DQ stub is placed only on the fused module's input.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of a lone ``Linear`` in the graph."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap a matched ``Linear`` for a ``FusedLinear``."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        linear_mod = resolve_module(matched.get_node(0), nn.Linear)
        return replace_tree(pattern_match, [FusedLinear(linear_mod)])
class LayerNormPattern(Pattern):
    """Wrap a ``LayerNorm`` layer in its fused counterpart.

    ``qdq_points`` is empty, so no Q/DQ stubs are inserted around the
    fused LayerNorm.
    """

    # Single-node pattern: any LayerNorm module.
    tree: Tree = (nn.LayerNorm,)
    # Deliberately empty — LayerNorm receives no quantization stubs.
    qdq_points = frozenset()

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all occurrences of ``LayerNorm`` in the graph."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap a matched ``LayerNorm`` for a ``FusedLayerNorm``."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        norm_mod = resolve_module(matched.get_node(0), nn.LayerNorm)
        return replace_tree(pattern_match, [FusedLayerNorm(norm_mod)])