Source code for embedl_deploy._internal.core.plan
# Copyright (C) 2026 Embedl AB
"""Implementation of ``TransformationPlan`` functionality.
The plan-based workflow lets users inspect and edit pattern matches before
applying them to the model.
"""
import copy
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TypedDict
import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp
from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy._internal.core.pattern import Pattern, PatternMatch
[docs]
@dataclass
class TransformationPlan:
    """Editable transformation plan.

    Returned by :func:`~embedl_deploy.get_transformation_plan`. The
    ``matches`` dict maps ``node_name → pattern_class_name → PatternMatch``,
    where ``node_name`` is the name of the last node returned by the match's
    ``tree_match.get_tree_nodes()``. Toggle ``match.apply = False`` to skip
    specific matches before calling
    :func:`~embedl_deploy.apply_transformation_plan`.
    """

    #: The graph module the matches were found in (normally a deep copy of
    #: the caller's module). It is unmodified until the plan is applied:
    #: :func:`~embedl_deploy.apply_transformation_plan` rewrites it in place.
    model: fx.GraphModule
    #: Nested dict of discovered matches, keyed first by the name of the last
    #: node in each match's ``tree_match.get_tree_nodes()`` and then by the
    #: pattern class name.
    matches: dict[str, dict[str, PatternMatch]] = field(default_factory=dict)
class TransformationReport(TypedDict):
    """Summary of a :func:`~embedl_deploy.apply_transformation_plan` run."""

    #: ``repr()`` of every match that was applied.
    applied: list[str]
    #: ``repr()`` of every match that was skipped (``apply=False``).
    skipped: list[str]
    #: Number of applied matches (``len(applied)``).
    applied_count: int
    #: Number of skipped matches (``len(skipped)``).
    skipped_count: int
    #: ``applied_count + skipped_count``.
    total_count: int
[docs]
@dataclass
class TransformationResult:
    """Result of applying a transformation plan.

    Returned by :func:`~embedl_deploy.apply_transformation_plan`, and
    therefore also by :func:`~embedl_deploy.transform`, which returns the
    result of its final fusion pass.
    """

    #: The transformed model with fused / quantized modules. This is the
    #: applied plan's own ``model`` object, modified in place.
    model: fx.GraphModule
    #: Summary of what was applied and what was skipped.
    report: TransformationReport
    #: The actual :class:`PatternMatch` objects that were applied, i.e. the
    #: plan's matches that had ``apply=True``.
    matches: list[PatternMatch]
def _matches_to_dict(
    matches: list[PatternMatch],
) -> dict[str, dict[str, PatternMatch]]:
    """Group a flat match list into ``{node_name: {pattern_name: match}}``.

    The outer key is the name of the last node in the match's
    ``tree_match.get_tree_nodes()``; the inner key is the class name of the
    match's pattern. Groups and entries follow the order of `matches`.

    :raises ValueError:
        If two matches share both the node name and the pattern class name,
        since the nested dict can only hold one of them.
    """
    nested: dict[str, dict[str, PatternMatch]] = {}
    for match in matches:
        key_node = match.tree_match.get_tree_nodes()[-1].name
        pattern_key = type(match.pattern).__name__
        bucket = nested.setdefault(key_node, {})
        if pattern_key in bucket:
            raise ValueError(
                f"Duplicate match for ({key_node!r}, {pattern_key!r}); "
                "the nested-dict representation cannot hold both"
            )
        bucket[pattern_key] = match
    return nested
def _dict_to_matches(
    match_dict: dict[str, dict[str, PatternMatch]],
) -> list[PatternMatch]:
    """Flatten ``{node_name: {pattern_name: match}}`` back into a list.

    Matches are emitted one node group at a time, following the dicts'
    insertion order. Matches sharing a node key therefore end up adjacent,
    which is not necessarily the order in which they were discovered.
    """
    flat: list[PatternMatch] = []
    for per_pattern in match_dict.values():
        flat.extend(per_pattern.values())
    return flat
def _build_report(
    enabled: list[PatternMatch],
    skipped: list[PatternMatch],
) -> TransformationReport:
    """Summarize applied and skipped matches for a TransformationResult.

    Each match is recorded by its ``repr()``; the counts are the lengths of
    the two input lists and their sum.
    """
    applied_reprs = list(map(repr, enabled))
    skipped_reprs = list(map(repr, skipped))
    n_applied = len(enabled)
    n_skipped = len(skipped)
    return TransformationReport(
        applied=applied_reprs,
        skipped=skipped_reprs,
        applied_count=n_applied,
        skipped_count=n_skipped,
        total_count=n_applied + n_skipped,
    )
[docs]
def get_transformation_plan(
    graph_module: fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationPlan:
    """Find all non-overlapping tree-based pattern matches.

    Unless it was already copied by this function, `graph_module` is
    deep-copied first so the caller's module is not modified. The copy is
    tagged with a private ``_deep_copy_done`` attribute. A module that already
    carries this tag is used as-is rather than copied again. Such modules
    include the ``model`` of a plan returned here or of a
    :class:`~embedl_deploy.TransformationResult`. Applying the returned plan
    therefore modifies that module in place.

    Each pattern's :meth:`~embedl_deploy._internal.core.pattern.Pattern.match`
    must return :class:`~embedl_deploy._internal.core.pattern.PatternMatch`
    objects with a populated ``tree_match``. Overlaps are resolved greedily:
    a match is enabled (``apply=True``) unless it shares a node with an
    earlier enabled match, in which case it is marked ``apply=False``.

    Returns a :class:`~embedl_deploy.TransformationPlan` that can be inspected
    and edited before calling :func:`~embedl_deploy.apply_transformation_plan`.

    :param graph_module:
        The traced graph module to analyse.
    :param patterns:
        Patterns to search for. Order matters: patterns are matched in
        sequence, and earlier matches claim nodes first. Supply longest (most
        specific) patterns first to ensure they take priority over shorter ones
        when sub-graphs overlap.
    :return:
        A :class:`~embedl_deploy.TransformationPlan` containing the (copied)
        graph module and ``matches`` (a nested dict of discovered pattern
        matches).
    :raises ValueError:
        If two matches end on the same node and come from the same pattern
        class, since the nested ``matches`` dict cannot hold both.

    Example:
        .. code-block:: python

            from torch import fx
            from embedl_deploy import get_transformation_plan
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            graph_module = fx.symbolic_trace(model)
            plan = get_transformation_plan(
                graph_module, patterns=TENSORRT_PATTERNS
            )
            for node, pats in plan.matches.items():
                for name, match in pats.items():
                    print(f"{node}: {name} apply={match.apply}")
    """
    # Copy at most once per pipeline: modules this function already copied
    # are reused, so repeated plan/apply rounds (as in ``transform``) do not
    # pay for a fresh deep copy each time.
    if not getattr(graph_module, "_deep_copy_done", False):
        graph_module = copy.deepcopy(graph_module)
        setattr(graph_module, "_deep_copy_done", True)

    # Every pattern is matched against the same, still unmodified, graph.
    pattern_matches: list[PatternMatch] = []
    for pattern in patterns:
        pattern_matches.extend(pattern.match(graph_module))

    # Greedy overlap resolution in discovery order: the first match to touch
    # a node claims it, and later matches touching a claimed node are off.
    consumed: set[fx.Node] = set()
    for pm in pattern_matches:
        tree_nodes = set(pm.tree_match.get_tree_nodes())
        if not consumed.isdisjoint(tree_nodes):
            pm.apply = False
            continue
        pm.apply = True
        consumed.update(tree_nodes)

    return TransformationPlan(
        model=graph_module,
        matches=_matches_to_dict(pattern_matches),
    )
[docs]
def _sample_inputs_from_meta(
    graph_module: fx.GraphModule,
) -> list[torch.Tensor]:
    """Build synthetic forward inputs from the placeholders' ``tensor_meta``.

    Placeholders are visited in graph order. Each one whose
    ``node.meta["tensor_meta"]`` has a ``shape`` contributes one tensor of
    that shape and of the recorded ``dtype``. Floating-point and complex
    dtypes get random normal values. Every other dtype gets zeros, so integer
    inputs such as indices stay in range. Tensors are created on the device of
    the module's first parameter (or first buffer), or on the default device
    if it has neither. Collection stops at the first placeholder without
    usable metadata; later placeholders receive no synthetic input.

    :param graph_module:
        The module whose graph placeholders are inspected.
    :return:
        The synthetic inputs in placeholder order (possibly empty).
    """
    reference = next(graph_module.parameters(), None)
    if reference is None:
        reference = next(graph_module.buffers(), None)
    device = None if reference is None else reference.device

    samples: list[torch.Tensor] = []
    for node in graph_module.graph.nodes:
        if node.op != "placeholder":
            continue
        meta = node.meta.get("tensor_meta")
        if meta is None or not hasattr(meta, "shape"):
            break
        dtype = getattr(meta, "dtype", None)
        if dtype is None:
            dtype = torch.get_default_dtype()
        make = (
            torch.randn
            if dtype.is_floating_point or dtype.is_complex
            else torch.zeros
        )
        samples.append(make(meta.shape, dtype=dtype, device=device))
    return samples


def apply_transformation_plan(
    plan: TransformationPlan,
) -> TransformationResult:
    """Apply the enabled matches from `plan`.

    Only matches with ``apply=True`` are applied via
    :meth:`~embedl_deploy._internal.core.pattern.Pattern.replace`. The plan's
    model is modified in place and returned as the result's ``model``. It is
    normally the deep copy created by
    :func:`~embedl_deploy.get_transformation_plan`. Afterwards the model is
    switched to eval mode. If its placeholders carry ``tensor_meta``, shape
    metadata is re-propagated with synthetic inputs of the recorded shapes and
    dtypes, one per input, on the module's device.

    :param plan:
        The plan to apply (from
        :func:`~embedl_deploy.get_transformation_plan`).
    :return:
        A :class:`~embedl_deploy.TransformationResult` containing ``model``
        (transformed), ``report`` (summary), and ``matches`` (applied matches).
    :raises ValueError:
        If any nodes are included in more than one enabled pattern.

    Example:
        .. code-block:: python

            result = apply_transformation_plan(plan)
            print(result.report)
            torch.onnx.export(result.model, x, "deployed.onnx")
    """
    graph_module = plan.model
    pattern_matches = _dict_to_matches(plan.matches)
    enabled = [pm for pm in pattern_matches if pm.apply]
    skipped = [pm for pm in pattern_matches if not pm.apply]

    # Users may have toggled ``apply`` by hand since the plan was built, so
    # re-check that no node is claimed twice before touching the graph.
    consumed: set[fx.Node] = set()
    for pm in enabled:
        tree_nodes = set(pm.tree_match.get_tree_nodes())
        if inter := consumed.intersection(tree_nodes):
            msg = f"Nodes: {inter} included in more than one enabled pattern."
            raise ValueError(msg)
        consumed.update(tree_nodes)

    for pm in enabled:
        pm.pattern.replace(pm)

    if enabled:
        graph_module.graph.lint()  # type: ignore[no-untyped-call]
        graph_module.recompile()

    # Eval mode first: the shape-propagation forward pass below must not
    # change state (e.g. BatchNorm running statistics in train mode).
    graph_module.eval()
    # Refresh ``tensor_meta`` (nodes created by replacements have none). One
    # synthetic input per placeholder, with the recorded dtype, so that
    # multi-input and integer-input models propagate correctly.
    sample_inputs = _sample_inputs_from_meta(graph_module)
    if sample_inputs:
        ShapeProp(graph_module).propagate(*sample_inputs)  # type: ignore[no-untyped-call]

    report = _build_report(enabled, skipped)
    return TransformationResult(
        model=graph_module,
        report=report,
        matches=enabled,
    )
[docs]
def transform(
    model: nn.Module | fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationResult:
    """Apply pattern transformations to `model` in one step.

    Conversion patterns (``is_conversion=True``) are applied iteratively until
    no new matches are found. Fusion patterns are then matched and applied in
    a single pass. The original model is never modified: `model` is traced if
    necessary and deep-copied once up front, and every pass operates on that
    copy. This also holds when `model` is the ``model`` of an earlier
    :class:`~embedl_deploy.TransformationResult`.

    :param model:
        The model to transform. If model is a ``nn.Module`` it will be traced
        with :func:`~torch.fx.symbolic_trace`.
    :param patterns:
        Patterns to match and apply. Order matters: patterns are matched in
        sequence, and earlier matches claim nodes first. Supply longest (most
        specific) patterns first to ensure they take priority over shorter ones
        when sub-graphs overlap.
    :return:
        A :class:`~embedl_deploy.TransformationResult` containing ``model``
        (transformed), ``report`` (summary), and ``matches`` (applied matches)
        of the final fusion pass.

    Example:
        .. code-block:: python

            from embedl_deploy import transform
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
    """
    traced = (
        model if isinstance(model, fx.GraphModule) else symbolic_trace(model)
    )
    # Take our own copy up front. get_transformation_plan skips copying
    # modules tagged ``_deep_copy_done`` (e.g. the model of an earlier
    # result), which would otherwise be modified in place here. Tagging the
    # copy lets every plan/apply round below reuse it without re-copying.
    graph_module = copy.deepcopy(traced)
    setattr(graph_module, "_deep_copy_done", True)

    conversions = [p for p in patterns if p.is_conversion]
    fusions = [p for p in patterns if not p.is_conversion]

    # NOTE(review): terminates only if applying the conversions eventually
    # leaves nothing for them to match.
    while conversions:
        conv_plan = get_transformation_plan(graph_module, conversions)
        if not conv_plan.matches:
            break
        graph_module = apply_transformation_plan(conv_plan).model

    fuse_plan = get_transformation_plan(graph_module, fusions)
    return apply_transformation_plan(fuse_plan)