# Source code for embedl_deploy._internal.core.plan

# Copyright (C) 2026 Embedl AB

"""Implementation of ``TransformationPlan`` functionality.

The plan-based workflow lets users inspect and edit pattern matches before
applying them to the model.
"""

import copy
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TypedDict

import torch
from torch import fx, nn
from torch.fx.passes.shape_prop import ShapeProp

from embedl_deploy._internal.core.modules import symbolic_trace
from embedl_deploy._internal.core.pattern import Pattern, PatternMatch


@dataclass
class TransformationPlan:
    """Editable transformation plan.

    Produced by :func:`~embedl_deploy.get_transformation_plan` and consumed by
    :func:`~embedl_deploy.apply_transformation_plan`.

    ``matches`` is a two-level mapping::

        node_name -> pattern_class_name -> PatternMatch

    where ``node_name`` is the name of the last node reported by the match's
    ``tree_match.get_tree_nodes()``. Set ``match.apply = False`` on any entry
    to have it skipped when the plan is applied.
    """

    #: Graph module the matches refer to; no replacements applied to it yet.
    model: fx.GraphModule
    #: Discovered matches, keyed first by node name, then by pattern class
    #: name.
    matches: dict[str, dict[str, PatternMatch]] = field(default_factory=dict)
class TransformationReport(TypedDict):
    """Summary of a single :func:`apply_transformation_plan` run.

    The ``applied`` and ``skipped`` lists hold the ``repr`` of each match.
    """

    #: ``repr`` of every match that was applied.
    applied: list[str]
    #: ``repr`` of every match that was left out (``apply=False``).
    skipped: list[str]
    #: ``len(applied)``.
    applied_count: int
    #: ``len(skipped)``.
    skipped_count: int
    #: ``applied_count + skipped_count``.
    total_count: int
@dataclass
class TransformationResult:
    """Outcome of applying a transformation plan.

    Produced by :func:`~embedl_deploy.apply_transformation_plan`.
    """

    #: Transformed graph module (fused / quantized modules swapped in).
    model: fx.GraphModule
    #: Counts and ``repr`` listings of applied vs. skipped matches.
    report: TransformationReport
    #: The :class:`PatternMatch` objects that were actually applied.
    matches: list[PatternMatch]
def _matches_to_dict(
    matches: list[PatternMatch],
) -> dict[str, dict[str, PatternMatch]]:
    """Group a flat match list as ``node_name -> pattern_name -> match``.

    ``node_name`` is the name of the last node of each match's tree and
    ``pattern_name`` is the class name of the match's pattern.

    :raises ValueError: If two matches share both keys, since the nested dict
        can hold only one of them.
    """
    grouped: dict[str, dict[str, PatternMatch]] = {}
    for match in matches:
        node_name = match.tree_match.get_tree_nodes()[-1].name
        per_pattern = grouped.setdefault(node_name, {})
        pattern_name = type(match.pattern).__name__
        if pattern_name in per_pattern:
            raise ValueError(
                f"Duplicate match for ({node_name!r}, {pattern_name!r}); "
                "the nested-dict representation cannot hold both"
            )
        per_pattern[pattern_name] = match
    return grouped


def _dict_to_matches(
    match_dict: dict[str, dict[str, PatternMatch]],
) -> list[PatternMatch]:
    """Flatten ``node_name -> pattern_name -> match`` back into a list.

    Order follows dict insertion order: grouped per node, nodes in the order
    they were first inserted.
    """
    flat: list[PatternMatch] = []
    for per_pattern in match_dict.values():
        flat.extend(per_pattern.values())
    return flat


def _build_report(
    enabled: list[PatternMatch],
    skipped: list[PatternMatch],
) -> TransformationReport:
    """Summarise applied vs. skipped matches for a TransformationResult."""
    applied_reprs = list(map(repr, enabled))
    skipped_reprs = list(map(repr, skipped))
    return {
        "applied": applied_reprs,
        "skipped": skipped_reprs,
        "applied_count": len(applied_reprs),
        "skipped_count": len(skipped_reprs),
        "total_count": len(applied_reprs) + len(skipped_reprs),
    }
def get_transformation_plan(
    graph_module: fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationPlan:
    """Collect non-overlapping tree-based pattern matches into a plan.

    Each pattern's
    :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` is called on
    the working module and must return
    :class:`~embedl_deploy._internal.core.pattern.PatternMatch` objects with a
    populated ``tree_match``. Matches are then visited in discovery order: a
    match whose nodes are all still unclaimed gets ``apply=True`` and claims
    them, while a match overlapping an earlier claim gets ``apply=False``.

    The working module is a deep copy of `graph_module`, so the caller's module
    is left untouched, unless `graph_module` already has a truthy
    ``_deep_copy_done`` attribute, in which case it is used directly. Copies
    made here are given that attribute.

    The returned :class:`~embedl_deploy.TransformationPlan` can be inspected
    and edited before calling :func:`~embedl_deploy.apply_transformation_plan`.

    :param graph_module: The traced graph module to analyse.
    :param patterns: Patterns to search for, in priority order. Matches from
        earlier patterns claim nodes first, so list the longest (most
        specific) patterns first when sub-graphs can overlap.
    :return: A :class:`~embedl_deploy.TransformationPlan` whose ``model`` is
        the working module and whose ``matches`` is the nested dict of
        discovered pattern matches.

    Example:
        .. code-block:: python

            from torch import fx
            from embedl_deploy import get_transformation_plan
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            graph_module = fx.symbolic_trace(model)
            plan = get_transformation_plan(
                graph_module, patterns=TENSORRT_PATTERNS
            )
            for node, pats in plan.matches.items():
                for name, match in pats.items():
                    print(f"{node}: {name} apply={match.apply}")
    """
    # NOTE(review): modules handed back by this function (and thus by
    # apply_transformation_plan) keep the flag, so passing them in again
    # operates on them in place rather than on a copy.
    already_private = getattr(graph_module, "_deep_copy_done", False)
    if not already_private:
        graph_module = copy.deepcopy(graph_module)
        setattr(graph_module, "_deep_copy_done", True)

    # Patterns are queried in sequence, so earlier patterns' matches come
    # first and win any overlap resolved below.
    candidates = [
        found for pattern in patterns for found in pattern.match(graph_module)
    ]

    claimed: set[fx.Node] = set()
    for candidate in candidates:
        nodes = set(candidate.tree_match.get_tree_nodes())
        is_free = claimed.isdisjoint(nodes)
        candidate.apply = is_free
        if is_free:
            claimed.update(nodes)

    return TransformationPlan(
        model=graph_module,
        matches=_matches_to_dict(candidates),
    )
def _example_inputs(graph_module: fx.GraphModule) -> list[torch.Tensor]:
    """Synthesize positional inputs for ``ShapeProp`` from placeholder metadata.

    Placeholders are visited in graph order. One carrying ``tensor_meta`` (as
    left by an earlier shape propagation) yields a tensor of the recorded
    shape and dtype, on the device of the module's first parameter (default
    device if the module has no parameters). The first placeholder without
    metadata but with a default value ends the list, so the interpreter falls
    back to default values from there on. A placeholder with neither cannot
    be fed, and an empty list is returned so propagation is skipped.
    """
    first_param = next(iter(graph_module.parameters()), None)
    device = None if first_param is None else first_param.device

    inputs: list[torch.Tensor] = []
    for node in graph_module.graph.nodes:
        if node.op != "placeholder":
            continue
        meta = node.meta.get("tensor_meta")
        if meta is None or not hasattr(meta, "shape"):
            if node.args:
                # Placeholder has a default value; let the interpreter use it.
                break
            # Required input we cannot synthesize: skip shape propagation.
            return []
        dtype = getattr(meta, "dtype", None)
        if dtype is None:
            dtype = torch.get_default_dtype()
        if dtype.is_floating_point or dtype.is_complex:
            inputs.append(torch.randn(meta.shape, dtype=dtype, device=device))
        else:
            # Integer / bool inputs are often indices; zeros are valid for any
            # non-empty lookup table, whereas random values may not be.
            inputs.append(torch.zeros(meta.shape, dtype=dtype, device=device))
    return inputs


def apply_transformation_plan(
    plan: TransformationPlan,
) -> TransformationResult:
    """Apply the enabled matches from `plan`.

    Only matches with ``apply=True`` are applied, each via its pattern's
    :meth:`~embedl_deploy._internal.core.pattern.Pattern.replace`. The plan's
    model is modified in place (it is normally the working copy created by
    :func:`~embedl_deploy.get_transformation_plan`).

    If at least one match was applied, the graph is linted and recompiled.
    The model is then switched to eval mode. If the placeholders carry
    ``tensor_meta`` from an earlier shape propagation, the metadata is
    refreshed by running :class:`~torch.fx.passes.shape_prop.ShapeProp` on
    synthetic inputs. One input is built per placeholder, using the recorded
    shape and dtype.

    :param plan: The plan to apply (from
        :func:`~embedl_deploy.get_transformation_plan`).
    :return: A :class:`~embedl_deploy.TransformationResult` containing
        ``model`` (transformed), ``report`` (summary), and ``matches``
        (applied matches).
    :raises ValueError: If any nodes are included in more than one enabled
        pattern.

    Example:
        .. code-block:: python

            result = apply_transformation_plan(plan)
            print(result.report)
            torch.onnx.export(result.model, x, "deployed.onnx")
    """
    graph_module = plan.model
    pattern_matches = _dict_to_matches(plan.matches)
    enabled = [pm for pm in pattern_matches if pm.apply]
    skipped = [pm for pm in pattern_matches if not pm.apply]

    # Validate before mutating anything so a bad plan leaves the graph intact.
    consumed: set[fx.Node] = set()
    for pm in enabled:
        tree_nodes = set(pm.tree_match.get_tree_nodes())
        if inter := consumed.intersection(tree_nodes):
            msg = f"Nodes: {inter} included in more than one enabled pattern."
            raise ValueError(msg)
        consumed.update(tree_nodes)

    for pm in enabled:
        pm.pattern.replace(pm)

    if enabled:
        graph_module.graph.lint()  # type: ignore[no-untyped-call]
        graph_module.recompile()

    graph_module.eval()

    # Feed every placeholder (not just the first node), otherwise multi-input
    # models make the interpreter fail on the missing positional arguments.
    example_inputs = _example_inputs(graph_module)
    if example_inputs:
        ShapeProp(graph_module).propagate(*example_inputs)  # type: ignore[no-untyped-call]

    report = _build_report(enabled, skipped)
    return TransformationResult(
        model=graph_module,
        report=report,
        matches=enabled,
    )
def transform(
    model: nn.Module | fx.GraphModule,
    patterns: Sequence[Pattern],
) -> TransformationResult:
    """Apply pattern transformations to `model` in one step.

    Conversion patterns (``is_conversion=True``) are applied round after round
    until a round finds no matches. The remaining (fusion) patterns are then
    matched and applied in a single pass.

    `model` itself is never modified. A ``fx.GraphModule`` is deep-copied up
    front. Any other ``nn.Module`` is traced, and the trace is deep-copied by
    :func:`~embedl_deploy.get_transformation_plan`.

    :param model: The model to transform. If it is not already a
        ``fx.GraphModule`` it is traced with this package's
        ``symbolic_trace`` helper.
    :param patterns: Patterns to match and apply. Order matters: patterns are
        matched in sequence, and earlier matches claim nodes first. Supply
        longest (most specific) patterns first to ensure they take priority
        over shorter ones when sub-graphs overlap.
    :return: The :class:`~embedl_deploy.TransformationResult` of the fusion
        pass. ``model`` is the fully transformed module (conversions
        included), while ``report`` and ``matches`` describe only the fusion
        pass.

    Example:
        .. code-block:: python

            from embedl_deploy import transform
            from embedl_deploy.tensorrt import TENSORRT_PATTERNS

            deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
    """
    if isinstance(model, fx.GraphModule):
        # Copy explicitly: a module returned by an earlier transform() or
        # apply_transformation_plan() carries ``_deep_copy_done``, which would
        # make get_transformation_plan() skip its own copy and mutate the
        # caller's module in place.
        graph_module = copy.deepcopy(model)
        # Flag the private copy so the planning passes below reuse it
        # instead of copying it again.
        setattr(graph_module, "_deep_copy_done", True)
    else:
        # Not flagged, so the first get_transformation_plan() call below
        # deep-copies the trace before anything is replaced.
        graph_module = symbolic_trace(model)

    conversions = [p for p in patterns if p.is_conversion]
    fusions = [p for p in patterns if not p.is_conversion]

    if conversions:
        while True:
            conv_plan = get_transformation_plan(graph_module, conversions)
            if not conv_plan.matches:
                break
            conv_result = apply_transformation_plan(conv_plan)
            graph_module = conv_result.model

    fuse_plan = get_transformation_plan(graph_module, fusions)
    return apply_transformation_plan(fuse_plan)