Source code for embedl_deploy._internal.core.plan

# Copyright (C) 2026 Embedl AB

"""Implementation of ``TransformationPlan`` functionality.

The plan-based workflow lets users inspect and edit pattern matches before
applying them to the model.
"""

import copy
import logging
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TypedDict

import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy._internal.core.patterns.main import (
    Pattern,
    PatternMatch,
    Phase,
)
from embedl_deploy._internal.core.tree.state import ReplaceSession
from embedl_deploy._internal.core.tree.utils import remove_orphaned_modules

_LOG = logging.getLogger(__name__)


@dataclass
class TransformationPlan:
    """Pattern matches discovered on a graph, open for editing before use.

    Produced by :func:`~embedl_deploy.get_transformation_plan`. Every entry in
    ``matches`` is a match object carrying an ``apply`` flag; set
    ``match.apply = False`` on the entries that should be left out, then hand
    the plan to :func:`~embedl_deploy.apply_transformation_plan`.
    """

    # The graph module the matches were found on. It is a deep copy of the
    # caller's module unless ``deep_copy=False`` was requested, and no
    # replacement has been applied to it yet.
    model: fx.GraphModule

    # Two-level mapping ``node_name -> pattern_class_name -> PatternMatch``.
    # The outer key is the name of the last node returned by the match's
    # ``tree_match.get_tree_nodes()``; the inner key is the class name of the
    # pattern that produced the match.
    matches: dict[str, dict[str, PatternMatch]] = field(default_factory=dict)
class TransformationReport(TypedDict):
    """Summary produced by one ``apply_transformation_plan`` call."""

    # ``repr`` of each match that was applied.
    applied: list[str]
    # ``repr`` of each match that was left out (its ``apply`` flag was false).
    skipped: list[str]
    # Number of entries in ``applied``.
    applied_count: int
    # Number of entries in ``skipped``.
    skipped_count: int
    # ``applied_count + skipped_count``.
    total_count: int
@dataclass
class TransformationResult:
    """Outcome of :func:`~embedl_deploy.apply_transformation_plan`.

    Bundles the rewritten model with a summary report and the list of
    matches whose replacements were actually carried out.
    """

    # The graph module after the enabled replacements were applied.
    model: fx.GraphModule
    # Descriptions and counts of the applied and skipped matches.
    report: TransformationReport
    # The :class:`~embedl_deploy._internal.core.patterns.main.PatternMatch`
    # objects that were enabled, in the order they were applied.
    matches: list[PatternMatch]
def _matches_to_dict(
    matches: list[PatternMatch],
) -> dict[str, dict[str, PatternMatch]]:
    """Group a flat list of matches as ``{node_name: {pattern_name: match}}``.

    The outer key is the name of the last node reported by each match's
    ``tree_match.get_tree_nodes()``; the inner key is the class name of the
    match's pattern.

    :param matches: Matches to group, visited in list order.
    :returns: The nested mapping.
    :raises ValueError: If two matches share both the node name and the
        pattern class name, because the nested dict can store only one.
    """
    grouped: dict[str, dict[str, PatternMatch]] = {}
    for match in matches:
        anchor = match.tree_match.get_tree_nodes()[-1].name
        by_pattern = grouped.setdefault(anchor, {})
        pattern_name = type(match.pattern).__name__
        if pattern_name in by_pattern:
            raise ValueError(
                f"Duplicate match for ({anchor!r}, {pattern_name!r}); "
                "the nested-dict representation cannot hold both"
            )
        by_pattern[pattern_name] = match
    return grouped


def _dict_to_matches(
    match_dict: dict[str, dict[str, PatternMatch]],
) -> list[PatternMatch]:
    """Flatten the nested match mapping into a list in insertion order.

    :param match_dict: Mapping shaped like the output of ``_matches_to_dict``.
    :returns: Every match, outer keys first, then inner keys, in the order
        they were inserted.
    """
    flat: list[PatternMatch] = []
    for by_pattern in match_dict.values():
        flat.extend(by_pattern.values())
    return flat


def _build_report(
    enabled: list[PatternMatch],
    skipped: list[PatternMatch],
) -> TransformationReport:
    """Summarize applied and skipped matches for a TransformationResult.

    :param enabled: Matches that were applied.
    :param skipped: Matches that were left out.
    :returns: ``repr`` strings of both groups plus their counts.
    """
    applied_descriptions = list(map(repr, enabled))
    skipped_descriptions = list(map(repr, skipped))
    return {
        "applied": applied_descriptions,
        "skipped": skipped_descriptions,
        "applied_count": len(applied_descriptions),
        "skipped_count": len(skipped_descriptions),
        "total_count": len(applied_descriptions) + len(skipped_descriptions),
    }
def _strip_export_guards(graph_module: fx.GraphModule) -> None:
    """Bypass and erase ``torch.export`` shape-guard nodes in place.

    Every ``call_module`` node whose name starts with ``_guards`` has its
    users rewired to its first argument and is then erased, because
    ShapeProp cannot evaluate these guards. The module is recompiled only
    when at least one node was removed.
    """
    guard_nodes = [
        node
        for node in graph_module.graph.nodes
        if node.op == "call_module" and node.name.startswith("_guards")
    ]
    for guard in guard_nodes:
        guard.replace_all_uses_with(next(iter(guard.args)))
        graph_module.graph.erase_node(guard)
    if guard_nodes:
        graph_module.recompile()


def _claim_nodes(pattern_matches: list[PatternMatch]) -> None:
    """Set each match's ``apply`` flag so that no node is claimed twice.

    Matches are visited in list order. A match that shares a node with an
    earlier enabled match gets ``apply=False`` (and a debug log entry naming
    the pattern(s) that already own those nodes); every other match gets
    ``apply=True`` and claims its nodes.
    """
    owner: dict[fx.Node, str] = {}
    for match in pattern_matches:
        pattern_name = type(match.pattern).__name__
        tree_nodes = match.tree_match.get_tree_nodes()
        already_claimed = owner.keys() & set(tree_nodes)
        if not already_claimed:
            match.apply = True
            owner.update(dict.fromkeys(tree_nodes, pattern_name))
            continue
        match.apply = False
        _LOG.debug(
            "Skipping %s match at %s: %d node(s) already consumed by %s",
            pattern_name,
            tree_nodes[-1].name,
            len(already_claimed),
            ", ".join(sorted({owner[n] for n in already_claimed})),
        )


def get_transformation_plan(
    graph_module: fx.GraphModule,
    patterns: Sequence[Pattern],
    *,
    deep_copy: bool = True,
) -> TransformationPlan:
    """Collect every match of `patterns` in `graph_module` into an editable plan.

    Unless ``deep_copy=False`` is passed, `graph_module` is deep-copied first
    and all work happens on the copy, so the caller's module is left
    untouched. ``call_module`` nodes whose names start with ``_guards``
    (``torch.export`` shape guards that ShapeProp cannot evaluate) are then
    removed from the module being analysed.

    Each pattern's
    :meth:`~embedl_deploy._internal.core.patterns.main.Pattern.match` must
    return
    :class:`~embedl_deploy._internal.core.patterns.main.PatternMatch`
    objects with a populated ``tree_match``. All matches are kept in the
    plan. A match that shares a node with an earlier match is disabled with
    ``apply=False``; every other match is enabled with ``apply=True``.

    The returned :class:`~embedl_deploy.TransformationPlan` can be inspected
    and edited before it is passed to
    :func:`~embedl_deploy.apply_transformation_plan`.

    :param graph_module: The traced graph module to search.
    :param patterns: Patterns to search for, in priority order. Matches from
        earlier patterns claim their nodes first, so list the longest (most
        specific) patterns before shorter ones that could overlap them.
    :param deep_copy: ``True`` (the default) deep-copies `graph_module`
        before matching. Pass ``False`` when the caller already owns a copy
        (e.g. inside quantization passes); the module is then modified in
        place by the guard removal.
    :returns: A :class:`~embedl_deploy.TransformationPlan` holding the
        analysed module and a nested dict of all discovered matches.
    :raises ValueError: If two matches share the same last node and the
        same pattern class name, since the nested dict holds only one.

    Example:
        .. code-block:: python

            from torch import fx

            from embedl_deploy import get_transformation_plan
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            graph_module = fx.symbolic_trace(model)
            plan = get_transformation_plan(
                graph_module, patterns=TENSORRT_PATTERNS
            )
            for node, pats in plan.matches.items():
                for name, match in pats.items():
                    print(f"{node}: {name} apply={match.apply}")
    """
    if deep_copy:
        graph_module = copy.deepcopy(graph_module)

    _strip_export_guards(graph_module)

    # Run every pattern before resolving overlaps so that earlier patterns
    # win ties purely by their position in `patterns`.
    pattern_matches: list[PatternMatch] = [
        match for pattern in patterns for match in pattern.match(graph_module)
    ]
    _claim_nodes(pattern_matches)

    return TransformationPlan(
        model=graph_module,
        matches=_matches_to_dict(pattern_matches),
    )
def _propagate_shapes(graph_module: fx.GraphModule) -> None: """Re-propagate tensor shapes after graph surgery. Builds fake inputs from placeholder ``tensor_meta`` and runs :class:`~torch.fx.passes.shape_prop.ShapeProp`. Pins tensors to the graph's parameter device so ``torch.export``'d graphs with device-dispatched ops (e.g. SDPA) don't crash on cross-device tensors. :param graph_module: The graph module whose shapes should be refreshed. """ try: device = next(graph_module.parameters()).device except StopIteration: device = torch.device("cpu") # Patterns may register new submodules (QuantStub, FusedX) whose # buffers default to CPU. Sync to the graph's device so ShapeProp # doesn't hit a mixed-device forward. if device.type != "cpu": graph_module.to(device) fake_args: list[torch.Tensor] = [] for n in graph_module.graph.nodes: if n.op != "placeholder": continue meta = n.meta.get("tensor_meta") if meta is None or not hasattr(meta, "shape"): fake_args.clear() break dtype = getattr(meta, "dtype", torch.float32) if dtype.is_floating_point: fake_args.append( torch.randn(meta.shape, dtype=dtype, device=device) ) else: fake_args.append( torch.zeros(meta.shape, dtype=dtype, device=device) ) if fake_args: # `no_grad` keeps ShapeProp from materialising an autograd tape # for the whole forward pass — for large transformer graphs # (SAM3, ViT-L, …) the activation tape can blow GPU memory. with torch.no_grad(): # ShapeProp.propagate lacks type stubs in torch ShapeProp(graph_module).propagate(*fake_args) # type: ignore[no-untyped-call]
def apply_transformation_plan(
    plan: TransformationPlan,
) -> TransformationResult:
    """Carry out the replacements for every enabled match in `plan`.

    Each match with ``apply=True`` is passed, in plan order, to its pattern's
    :meth:`~embedl_deploy._internal.core.patterns.main.Pattern.replace`.
    All replacements run inside one ``ReplaceSession``. ``plan.model`` is
    modified in place; it is normally the deep copy made by
    :func:`~embedl_deploy.get_transformation_plan`.

    Before any replacement, the enabled matches are checked for shared nodes.
    If at least one match is enabled, the following happens afterwards:

    * dead code and orphaned submodules are removed;
    * the graph is linted and recompiled.

    In every case the model is then switched to eval mode, and its shape
    metadata is re-propagated when the placeholders carry it.

    :param plan: The plan to apply, as produced by
        :func:`~embedl_deploy.get_transformation_plan`.
    :returns: A :class:`~embedl_deploy.TransformationResult` with the
        transformed ``model``, a summary ``report``, and the applied
        ``matches``.
    :raises ValueError: If a node belongs to more than one enabled match.

    Example:
        .. code-block:: python

            result = apply_transformation_plan(plan)
            print(result.report)
            torch.onnx.export(result.model, x, "deployed.onnx")
    """
    graph_module = plan.model

    enabled: list[PatternMatch] = []
    skipped: list[PatternMatch] = []
    for match in _dict_to_matches(plan.matches):
        (enabled if match.apply else skipped).append(match)

    # Validate up front so no replacement runs on an inconsistent plan.
    seen: set[fx.Node] = set()
    for match in enabled:
        nodes = set(match.tree_match.get_tree_nodes())
        shared = seen.intersection(nodes)
        if shared:
            msg = f"Nodes: {shared} included in more than one enabled pattern."
            raise ValueError(msg)
        seen |= nodes

    with ReplaceSession():
        for match in enabled:
            match.pattern.replace(match)

    if enabled:
        graph_module.graph.eliminate_dead_code()
        remove_orphaned_modules(graph_module)
        # Graph.lint lacks type stubs in torch
        graph_module.graph.lint()  # type: ignore[no-untyped-call]
        graph_module.recompile()

    graph_module.eval()
    _propagate_shapes(graph_module)

    return TransformationResult(
        model=graph_module,
        report=_build_report(enabled, skipped),
        matches=enabled,
    )
def transform(
    model: nn.Module | fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationResult:
    """Apply pattern transformations to `model` in one step.

    Recomposition and conversion patterns are applied repeatedly until a pass
    finds no new matches. Fusion patterns are then matched and applied in a
    single pass. The original model is deep-copied by the first planning
    call, and every later step works on that copy, so the caller's model is
    never modified.

    :param model: The model to transform. An ``nn.Module`` that is not
        already a ``fx.GraphModule`` is traced first.
    :param patterns: Patterns to match and apply. Order matters: patterns
        are matched in sequence, and earlier matches claim nodes first.
        Supply longest (most specific) patterns first so they take priority
        over shorter ones when sub-graphs overlap.
    :returns: A :class:`~embedl_deploy.TransformationResult` containing
        ``model`` (transformed), ``report`` (summary), and ``matches``
        (applied matches).

    Example:
        .. code-block:: python

            from embedl_deploy import transform
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
    """
    graph_module = (
        model if isinstance(model, fx.GraphModule) else symbolic_trace(model)
    )
    iterative_phases = (Phase.RECOMPOSITION, Phase.CONVERSION)
    conversions = [p for p in patterns if p.phase in iterative_phases]
    fusions = [p for p in patterns if p.phase == Phase.FUSION]

    # Only the first planning call deep-copies; every later step operates on
    # the copy that call produced.
    model_is_original = True
    if conversions:
        # NOTE(review): this loop ends only when a pass finds no matches; a
        # conversion whose replacement re-creates its own match would loop
        # forever.
        while True:
            conv_plan = get_transformation_plan(
                graph_module,
                conversions,
                deep_copy=model_is_original,
            )
            model_is_original = False
            # Switch to the plan's (copied) model *before* the emptiness
            # check. Otherwise a first pass with no matches would leave
            # `graph_module` pointing at the caller's model while
            # `model_is_original` is already False, and the fusion pass
            # below would modify the caller's model in place.
            graph_module = conv_plan.model
            if not conv_plan.matches:
                break
            graph_module = apply_transformation_plan(conv_plan).model

    fuse_plan = get_transformation_plan(
        graph_module,
        fusions,
        deep_copy=model_is_original,
    )
    return apply_transformation_plan(fuse_plan)