# Source code for embedl_deploy._internal.tensorrt.patterns.fusions.linear
# Copyright (C) 2026 Embedl AB
"""Fusion ``Pattern`` subclasses for Linear and LayerNorm."""
from typing import cast
from torch import fx, nn
from embedl_deploy._internal.core.match import match_tree
from embedl_deploy._internal.core.modules import ActivationLike
from embedl_deploy._internal.core.pattern import (
Pattern,
PatternMatch,
QDQPoint,
Tree,
resolve_module,
)
from embedl_deploy._internal.core.replace import (
replace_tree,
)
from embedl_deploy._internal.tensorrt.modules.linear import (
FusedLayerNorm,
FusedLinear,
FusedLinearAct,
)
class LinearActPattern(Pattern):
    """Fuse a ``Linear → Activation`` pair into one module.

    Every activation covered by
    :data:`~embedl_deploy._internal.core.modules.ActivationLike` is eligible.
    """

    # Two-node chain: a Linear feeding any supported activation.
    tree: Tree = (nn.Linear, ActivationLike)
    # Quantize/dequantize stubs are placed on the input only.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return all ``Linear → Activation`` chains found in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Replace a matched chain with a single :class:`FusedLinearAct`."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        linear_mod = resolve_module(matched.get_node(0), nn.Linear)
        # resolve_module is called without an expected type here because
        # ActivationLike is a type alias, not a single nn.Module subclass.
        act_mod = cast(ActivationLike, resolve_module(matched.get_node(1)))
        return replace_tree(pattern_match, [FusedLinearAct(linear_mod, act_mod)])
class LinearPattern(Pattern):
    """Wrap a standalone ``Linear`` layer in its fused counterpart."""

    # Single-node pattern: any nn.Linear.
    tree: Tree = (nn.Linear,)
    # Q/DQ stubs go on the input only.
    qdq_points = frozenset({QDQPoint.INPUT})

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return every ``Linear`` occurrence found in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap the matched ``Linear`` for a :class:`FusedLinear`."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        linear_mod = resolve_module(matched.get_node(0), nn.Linear)
        return replace_tree(pattern_match, [FusedLinear(linear_mod)])
class LayerNormPattern(Pattern):
    """Wrap a ``LayerNorm`` in its fused counterpart.

    ``qdq_points`` is empty, so no Q/DQ stubs surround the LayerNorm.
    """

    # Single-node pattern: any nn.LayerNorm.
    tree: Tree = (nn.LayerNorm,)
    # Deliberately empty — LayerNorm gets no quantization stubs.
    qdq_points = frozenset()

    def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
        """Return every ``LayerNorm`` occurrence found in *graph_module*."""
        return match_tree(graph_module, pattern=self)

    def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
        """Swap the matched ``LayerNorm`` for a :class:`FusedLayerNorm`."""
        assert pattern_match.pattern is self
        matched = pattern_match.tree_match
        norm_mod = resolve_module(matched.get_node(0), nn.LayerNorm)
        return replace_tree(pattern_match, [FusedLayerNorm(norm_mod)])