Source code for embedl_deploy._internal.core.plan
# Copyright (C) 2026 Embedl AB
"""Implementation of ``TransformationPlan`` functionality.
The plan-based workflow lets users inspect and edit pattern matches before
applying them to the model.
"""
import copy
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TypedDict
import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy._internal.core.patterns.main import (
Pattern,
PatternMatch,
Phase,
)
from embedl_deploy._internal.core.tree.state import ReplaceSession
from embedl_deploy._internal.core.tree.utils import remove_orphaned_modules
# Module-level logger; used to report (at DEBUG level) matches that are
# disabled while building a plan because they overlap an earlier match.
_LOG = logging.getLogger(__name__)
[docs]
@dataclass
class TransformationPlan:
    """Editable transformation plan.

    Returned by :func:`~embedl_deploy.get_transformation_plan`. The
    ``matches`` dict maps ``node_name → pattern_class_name →
    PatternMatch``, where ``node_name`` is the name of the last node in the
    match's tree (``tree_match.get_tree_nodes()[-1]``). Set
    ``match.apply = False`` to skip specific matches before calling
    :func:`~embedl_deploy.apply_transformation_plan`. Enabling matches that
    share nodes makes that call raise ``ValueError``.
    """

    #: The graph the matches refer to, not yet modified by replacements.
    #: It is a deep copy of the caller's graph unless the plan was built
    #: with ``deep_copy=False``.
    model: fx.GraphModule
    #: Nested dict of discovered matches, keyed by the last matched
    #: node's name and then by the pattern class name.
    matches: dict[str, dict[str, PatternMatch]] = field(default_factory=dict)
class TransformationReport(TypedDict):
    """Summary of an ``apply_transformation_plan`` run."""

    #: ``repr`` of every match that was applied (``apply=True``).
    applied: list[str]
    #: ``repr`` of every match that was left disabled (``apply=False``).
    skipped: list[str]
    #: Number of entries in ``applied``.
    applied_count: int
    #: Number of entries in ``skipped``.
    skipped_count: int
    #: ``applied_count + skipped_count``.
    total_count: int
[docs]
@dataclass
class TransformationResult:
    """Result of applying a transformation plan.

    Returned by :func:`~embedl_deploy.apply_transformation_plan`.
    """

    #: The transformed model with fused / quantized modules.
    model: fx.GraphModule
    #: Summary of what was applied and what was skipped.
    report: TransformationReport
    #: The actual :class:`~embedl_deploy._internal.core.pattern.PatternMatch`
    #: objects that were applied (those with ``apply=True``), in the order
    #: their replacements ran.
    matches: list[PatternMatch]
def _matches_to_dict(
    matches: list[PatternMatch],
) -> dict[str, dict[str, PatternMatch]]:
    """Group a flat list of matches by node name, then by pattern class name.

    The outer key is the name of the last node in each match's tree
    (``tree_match.get_tree_nodes()[-1]``); the inner key is the class name
    of the match's pattern. Both levels keep the order in which entries
    first appear in `matches`.

    :param matches:
        Matches to group.
    :returns:
        The nested ``node_name → pattern_class_name → PatternMatch`` dict.
    :raises ValueError:
        If two matches share both the node name and the pattern class name,
        because the nested dict can hold only one of them.
    """
    nested: dict[str, dict[str, PatternMatch]] = {}
    for match in matches:
        anchor = match.tree_match.get_tree_nodes()[-1].name
        per_node = nested.setdefault(anchor, {})
        key = type(match.pattern).__name__
        if key in per_node:
            raise ValueError(
                f"Duplicate match for ({anchor!r}, {key!r}); "
                "the nested-dict representation cannot hold both"
            )
        per_node[key] = match
    return nested
def _dict_to_matches(
    match_dict: dict[str, dict[str, PatternMatch]],
) -> list[PatternMatch]:
    """Flatten the nested ``node → pattern → match`` dict into a list.

    Node entries are visited in dict insertion order, and within each node
    the pattern entries are also visited in insertion order.

    :param match_dict:
        Nested match dict, as stored in ``TransformationPlan.matches``.
    :returns:
        All matches in a single flat list.
    """
    flat: list[PatternMatch] = []
    for by_pattern in match_dict.values():
        flat.extend(by_pattern.values())
    return flat
def _build_report(
    enabled: list[PatternMatch],
    skipped: list[PatternMatch],
) -> TransformationReport:
    """Summarise applied and skipped matches as a ``TransformationReport``.

    Each match is recorded by its ``repr``; the counts are derived from the
    lengths of the two input lists.

    :param enabled:
        Matches that were applied.
    :param skipped:
        Matches that were left disabled.
    :returns:
        The report dict stored on ``TransformationResult.report``.
    """
    applied_count = len(enabled)
    skipped_count = len(skipped)
    report: TransformationReport = {
        "applied": list(map(repr, enabled)),
        "skipped": list(map(repr, skipped)),
        "applied_count": applied_count,
        "skipped_count": skipped_count,
        "total_count": applied_count + skipped_count,
    }
    return report
[docs]
def get_transformation_plan(
    graph_module: fx.GraphModule,
    patterns: Sequence[Pattern],
    *,
    deep_copy: bool = True,
) -> TransformationPlan:
    """Collect the pattern matches in `graph_module` into an editable plan.

    By default the graph is deep-copied first so the caller's model is never
    modified; with ``deep_copy=False`` all work happens on `graph_module`
    itself. ``torch.export`` shape-guard modules (``call_module`` nodes
    whose name starts with ``_guards``) are then removed from the graph,
    because ShapeProp cannot evaluate them.

    Each pattern's
    :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` is run in
    order and must return
    :class:`~embedl_deploy._internal.core.pattern.PatternMatch` objects with
    a populated ``tree_match``. Overlaps are then resolved greedily in that
    same order. A match that shares no node with an already accepted match
    gets ``apply=True``. A match that does share a node gets ``apply=False``
    and is logged at DEBUG level.

    The returned :class:`~embedl_deploy.TransformationPlan` can be inspected
    and edited before calling :func:`~embedl_deploy.apply_transformation_plan`.

    :param graph_module:
        The traced graph module to analyze.
    :param patterns:
        Patterns to search for. Order matters: patterns are matched in
        sequence, and earlier matches claim nodes first. Supply longest (most
        specific) patterns first to ensure they take priority over shorter ones
        when sub-graphs overlap.
    :param deep_copy:
        When ``True`` (the default), `graph_module` is deep-copied before
        matching so the original is never modified. Set to ``False`` when
        operating on a model that has already been copied (e.g. inside
        quantization passes).
    :returns:
        A :class:`~embedl_deploy.TransformationPlan` holding the (possibly
        copied) graph and ``matches``, a nested dict of every discovered
        match, including the disabled ones.
    :raises ValueError:
        If two matches share both their last node and their pattern class
        name, since the nested ``matches`` dict cannot hold both.

    Example:

    .. code-block:: python

        from torch import fx
        from embedl_deploy import get_transformation_plan
        from embedl_deploy.tensorrt import TENSORRT_PATTERNS

        graph_module = fx.symbolic_trace(model)
        plan = get_transformation_plan(
            graph_module, patterns=TENSORRT_PATTERNS
        )
        for node, pats in plan.matches.items():
            for name, match in pats.items():
                print(f"{node}: {name} apply={match.apply}")
    """
    target = copy.deepcopy(graph_module) if deep_copy else graph_module
    _strip_export_guards(target)
    # Every pattern is matched against the same, not-yet-rewritten graph;
    # conflicts are only resolved afterwards.
    found = [pm for pattern in patterns for pm in pattern.match(target)]
    _resolve_overlaps(found)
    return TransformationPlan(
        model=target,
        matches=_matches_to_dict(found),
    )


def _strip_export_guards(graph_module: fx.GraphModule) -> None:
    """Remove ``torch.export`` shape-guard modules from `graph_module` in place.

    Each ``call_module`` node whose name starts with ``_guards`` is bypassed
    by rewiring its users to its first argument and is then erased. The
    module is recompiled only if something was removed.
    """
    guard_nodes = [
        n
        for n in graph_module.graph.nodes
        if n.op == "call_module" and n.name.startswith("_guards")
    ]
    if not guard_nodes:
        return
    for guard in guard_nodes:
        guard.replace_all_uses_with(next(iter(guard.args)))
        graph_module.graph.erase_node(guard)
    graph_module.recompile()


def _resolve_overlaps(pattern_matches: list[PatternMatch]) -> None:
    """Set ``apply`` on each match so enabled matches never share nodes.

    Matches are visited in list order. The first match to claim a node
    wins, and every later match touching a claimed node is disabled.
    """
    owner: dict[fx.Node, str] = {}
    for pm in pattern_matches:
        pattern_name = type(pm.pattern).__name__
        nodes = pm.tree_match.get_tree_nodes()
        clash = set(nodes) & owner.keys()
        if not clash:
            pm.apply = True
            for node in nodes:
                owner[node] = pattern_name
            continue
        pm.apply = False
        blockers = {owner[node] for node in clash}
        _LOG.debug(
            "Skipping %s match at %s: %d node(s) already consumed by %s",
            pattern_name,
            nodes[-1].name,
            len(clash),
            ", ".join(sorted(blockers)),
        )
def _propagate_shapes(graph_module: fx.GraphModule) -> None:
    """Re-propagate tensor shapes after graph surgery.

    Builds fake inputs from each placeholder's ``tensor_meta`` and runs
    :class:`~torch.fx.passes.shape_prop.ShapeProp` on them. The fake inputs
    and the module are all placed on one device. That device is the one
    holding the graph's first parameter; if there are no parameters, the
    one holding its first buffer; otherwise CPU. This keeps
    ``torch.export``'d graphs with device-dispatched ops (e.g. SDPA) from
    crashing on cross-device tensors.

    Nothing is propagated when the graph has no placeholders or when any
    placeholder lacks a ``tensor_meta`` with a ``shape``.

    :param graph_module:
        The graph module whose shapes should be refreshed. It is moved to
        the chosen device in place when that device is not CPU.
    """
    reference = next(graph_module.parameters(), None)
    if reference is None:
        # Parameter-free graphs can still hold device-resident buffers.
        # Without this fallback the fake inputs would be created on CPU
        # next to those buffers, and ShapeProp would fail on a mixed-device
        # forward.
        reference = next(graph_module.buffers(), None)
    device = (
        reference.device if reference is not None else torch.device("cpu")
    )
    # Patterns may register new submodules (QuantStub, FusedX) whose
    # buffers default to CPU. Sync to the graph's device so ShapeProp
    # doesn't hit a mixed-device forward.
    if device.type != "cpu":
        graph_module.to(device)
    fake_args: list[torch.Tensor] = []
    for node in graph_module.graph.nodes:
        if node.op != "placeholder":
            continue
        meta = node.meta.get("tensor_meta")
        if meta is None or not hasattr(meta, "shape"):
            # A call is only possible with a shape for every input.
            return
        dtype = getattr(meta, "dtype", torch.float32)
        # Random values for floating dtypes, zeros for every other dtype.
        if dtype.is_floating_point:
            fake_args.append(
                torch.randn(meta.shape, dtype=dtype, device=device)
            )
        else:
            fake_args.append(
                torch.zeros(meta.shape, dtype=dtype, device=device)
            )
    if fake_args:
        # `no_grad` keeps ShapeProp from materialising an autograd tape
        # for the whole forward pass — for large transformer graphs
        # (SAM3, ViT-L, …) the activation tape can blow GPU memory.
        with torch.no_grad():
            # ShapeProp.propagate lacks type stubs in torch
            ShapeProp(graph_module).propagate(*fake_args)  # type: ignore[no-untyped-call]
[docs]
def apply_transformation_plan(
    plan: TransformationPlan,
) -> TransformationResult:
    """Apply the enabled matches from `plan`.

    Only matches with ``apply=True`` are passed to
    :meth:`~embedl_deploy._internal.core.pattern.Pattern.replace`, all
    inside a single ``ReplaceSession``. ``plan.model`` is modified in place.
    It is a deep copy of the caller's graph only if the plan was built with
    ``deep_copy=True``, the default of
    :func:`~embedl_deploy.get_transformation_plan`.

    If at least one match was applied, dead code and orphaned submodules are
    removed and the graph is linted and recompiled. In every case the model
    is then put into eval mode and shape metadata is re-propagated when
    placeholder shapes are available.

    :param plan:
        The plan to apply (from
        :func:`~embedl_deploy.get_transformation_plan`).
    :returns:
        A :class:`~embedl_deploy.TransformationResult` containing ``model``
        (transformed), ``report`` (summary), and ``matches`` (applied matches).
    :raises ValueError:
        If any node belongs to more than one enabled match. This is checked
        before any replacement runs, so the model is untouched when it fires.

    Example:

    .. code-block:: python

        result = apply_transformation_plan(plan)
        print(result.report)
        torch.onnx.export(result.model, x, "deployed.onnx")
    """
    graph_module = plan.model
    enabled: list[PatternMatch] = []
    skipped: list[PatternMatch] = []
    for pm in _dict_to_matches(plan.matches):
        (enabled if pm.apply else skipped).append(pm)

    # Reject ambiguous plans up front, before the graph is touched.
    claimed: set[fx.Node] = set()
    for pm in enabled:
        members = set(pm.tree_match.get_tree_nodes())
        shared = claimed.intersection(members)
        if shared:
            raise ValueError(
                f"Nodes: {shared} included in more than one enabled pattern."
            )
        claimed.update(members)

    with ReplaceSession():
        for pm in enabled:
            pm.pattern.replace(pm)

    if enabled:
        # Clean up whatever the replacements left behind.
        graph_module.graph.eliminate_dead_code()
        remove_orphaned_modules(graph_module)
        # Graph.lint lacks type stubs in torch
        graph_module.graph.lint()  # type: ignore[no-untyped-call]
        graph_module.recompile()

    graph_module.eval()
    _propagate_shapes(graph_module)
    return TransformationResult(
        model=graph_module,
        report=_build_report(enabled, skipped),
        matches=enabled,
    )
[docs]
def transform(
    model: nn.Module | fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationResult:
    """Apply pattern transformations to `model` in one step.

    Recomposition and conversion patterns are applied iteratively until no new
    matches are found, then fusion patterns are matched and applied in a single
    pass. Patterns of any other phase are ignored. The first
    :func:`~embedl_deploy.get_transformation_plan` call deep-copies the graph,
    and every later pass works on that private copy rather than on `model`.

    :param model:
        The model to transform. If it is not already an ``fx.GraphModule`` it
        is traced first with this package's ``symbolic_trace`` helper.
    :param patterns:
        Patterns to match and apply. Order matters: patterns are matched in
        sequence, and earlier matches claim nodes first. Supply longest (most
        specific) patterns first to ensure they take priority over shorter ones
        when sub-graphs overlap.
    :returns:
        The :class:`~embedl_deploy.TransformationResult` of the final fusion
        pass. Its ``model`` is the fully transformed graph. Its ``report`` and
        ``matches`` describe only the fusion-phase matches.

    Example:

    .. code-block:: python

        from embedl_deploy import transform
        from embedl_deploy.tensorrt import TENSORRT_PATTERNS

        deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
    """
    graph_module = (
        model if isinstance(model, fx.GraphModule) else symbolic_trace(model)
    )
    iterative_phases = (Phase.RECOMPOSITION, Phase.CONVERSION)
    conversions = [p for p in patterns if p.phase in iterative_phases]
    fusions = [p for p in patterns if p.phase == Phase.FUSION]
    # True until the first get_transformation_plan call has produced the
    # private deep copy; afterwards every pass edits that copy in place.
    model_is_original = True
    if conversions:
        # NOTE(review): there is no iteration cap — a conversion whose
        # replacement re-creates a match would keep this loop running.
        while True:
            conv_plan = get_transformation_plan(
                graph_module,
                conversions,
                deep_copy=model_is_original,
            )
            # Switch to the plan's model *before* the early exit. On the
            # first pass it is the fresh deep copy. Keeping the old
            # reference would make the fusion pass below (deep_copy=False)
            # rewrite the caller's graph whenever no conversion matched.
            graph_module = conv_plan.model
            model_is_original = False
            if not conv_plan.matches:
                break
            graph_module = apply_transformation_plan(conv_plan).model
    fuse_plan = get_transformation_plan(
        graph_module,
        fusions,
        deep_copy=model_is_original,
    )
    return apply_transformation_plan(fuse_plan)