# Source code for embedl_deploy._internal.lattice.patterns.conversions.general

# Copyright (C) 2026 Embedl AB

"""General-purpose Lattice conversion patterns.

These patterns rewrite standard Torch operations into the limited
subset accepted by Lattice hardware.
"""

import logging

import torch
from torch import fx, nn

from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.types import (
    Graft,
    NodeInserter,
    Tree,
    TreeMatch,
)
from embedl_deploy._internal.core.tree.utils import (
    get_input_shape,
    get_module,
    resolve_module,
)
from embedl_deploy._internal.lattice.modules.activation import (
    LatticeLeakyReLU,
)
from embedl_deploy._internal.lattice.modules.conv import LatticeConv2d
from embedl_deploy._internal.lattice.modules.pool import (
    LatticeAdaptiveAvgPool2d,
    LatticeMaxPool2d,
)

_LOG = logging.getLogger(__name__)


def _is_nonconforming_conv2d(node: fx.Node) -> bool:
    """Tell whether `node` holds a ``Conv2d`` that must be rewritten.

    The node qualifies when its module is an ``nn.Conv2d`` that is not
    already a ``LatticeConv2d`` and that ``LatticeConv2d.is_compatible``
    rejects.
    """
    module = get_module(node)
    # Modules that are already ``LatticeConv2d`` are never rewritten again.
    is_plain_conv = isinstance(module, nn.Conv2d) and not isinstance(
        module, LatticeConv2d
    )
    if not is_plain_conv:
        return False
    return not LatticeConv2d.is_compatible(module)


def _make_lattice_conv(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Create the ``LatticeConv2d`` that replaces the matched ``Conv2d``.

    The first node of `tree_match` is resolved to its ``nn.Conv2d`` and
    wrapped in a ``LatticeConv2d``.  A warning is logged when the wrapper
    ends up with a different weight shape, or otherwise with a different
    stride or padding, than the source module.

    Returns:
        A one-element tuple holding the replacement module.
    """
    matched = tree_match.get_node(0)
    source = resolve_module(matched, nn.Conv2d)
    snapped = LatticeConv2d(source)

    # A changed weight shape takes precedence: only one warning is emitted.
    if snapped.weight.shape != source.weight.shape:
        _LOG.warning(
            "%s: kernel_size snapped from %s to %s — "
            "weights reinitialized, output will differ; "
            "retraining recommended.",
            matched.name,
            source.kernel_size,
            snapped.kernel_size,
        )
        return (snapped,)

    same_geometry = (
        snapped.stride == source.stride and snapped.padding == source.padding
    )
    if not same_geometry:
        _LOG.warning(
            "%s: stride/padding snapped from %s/%s to %s/%s — "
            "output will differ; retraining recommended.",
            matched.name,
            source.stride,
            source.padding,
            snapped.stride,
            snapped.padding,
        )
    return (snapped,)


class LatticeConv2dPattern(Pattern):
    """Snap an out-of-spec ``Conv2d`` to Lattice's supported set.

    Lattice hardware accepts only ``1×1`` and ``3×3`` convolutions with
    stride 1 or 2 (and stride 1 is mandatory for the ``1×1`` kernel).
    Convolutions outside this set are replaced by a
    :class:`~embedl_deploy._internal.lattice.modules.conv.LatticeConv2d`,
    whose constructor performs the snapping.

    Weights and bias are preserved when only stride and/or padding
    needed snapping (the weight tensor shape is unchanged).  When the
    kernel size itself is snapped the replacement has freshly
    initialized weights, since the source kernel cannot be meaningfully
    reused at the new shape.
    """

    # Runs during the conversion phase.
    phase = Phase.CONVERSION
    # Single-node tree: matches any ``Conv2d`` flagged as nonconforming.
    tree: Tree = (_is_nonconforming_conv2d,)
    # Single-node graft: the snapped replacement module.
    graft: Graft = (_make_lattice_conv,)
def _is_nonconforming_leaky_relu(node: fx.Node) -> bool:
    """Tell whether `node` holds a ``LeakyReLU`` that must be rewritten.

    The node qualifies when its module is an ``nn.LeakyReLU`` that is not
    already a ``LatticeLeakyReLU`` and that
    ``LatticeLeakyReLU.is_compatible`` rejects.
    """
    module = get_module(node)
    if isinstance(module, LatticeLeakyReLU) or not isinstance(
        module, nn.LeakyReLU
    ):
        return False
    return not LatticeLeakyReLU.is_compatible(module)


def _make_lattice_leaky_relu(
    tree_match: TreeMatch,
) -> tuple[nn.Module, ...]:
    """Create the ``LatticeLeakyReLU`` that replaces the matched node.

    A warning is logged when the replacement's ``negative_slope`` differs
    from the one on the source module.

    Returns:
        A one-element tuple holding the replacement module.
    """
    matched = tree_match.get_node(0)
    original = resolve_module(matched, nn.LeakyReLU)
    replacement = LatticeLeakyReLU(original)

    old_slope = original.negative_slope
    new_slope = replacement.negative_slope
    if old_slope != new_slope:
        _LOG.warning(
            "%s: LeakyReLU negative_slope snapped from %s to %s — "
            "output will differ, retraining is recommended",
            matched.name,
            old_slope,
            new_slope,
        )
    return (replacement,)
class LatticeLeakyReLUPattern(Pattern):
    """Snap a ``LeakyReLU`` negative slope to Lattice's supported value.

    Lattice hardware implements leaky ReLU with a fixed negative slope
    of ``1/16`` (0.0625).  Any other slope is replaced by that value and
    a warning is logged advising retraining.
    """

    # Runs during the conversion phase.
    phase = Phase.CONVERSION
    # Single-node tree: matches any ``LeakyReLU`` flagged as nonconforming.
    tree: Tree = (_is_nonconforming_leaky_relu,)
    # Single-node graft: the snapped replacement module.
    graft: Graft = (_make_lattice_leaky_relu,)
def _is_nonconforming_maxpool2d(node: fx.Node) -> bool:
    """Tell whether `node` holds a ``MaxPool2d`` that must be rewritten.

    The node qualifies when its module is an ``nn.MaxPool2d`` that is not
    already a ``LatticeMaxPool2d`` and that
    ``LatticeMaxPool2d.is_compatible`` rejects.
    """
    module = get_module(node)
    if isinstance(module, LatticeMaxPool2d) or not isinstance(
        module, nn.MaxPool2d
    ):
        return False
    return not LatticeMaxPool2d.is_compatible(module)


def _make_lattice_maxpool(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Create the ``LatticeMaxPool2d`` that replaces the matched node.

    A warning is always logged, reporting the source module's kernel
    size, stride and padding next to the ``LatticeMaxPool2d`` class
    constants they are snapped to.

    Returns:
        A one-element tuple holding the replacement module.
    """
    matched = tree_match.get_node(0)
    source = resolve_module(matched, nn.MaxPool2d)
    # The warning is emitted before the replacement is constructed.
    _LOG.warning(
        "%s: MaxPool2d snapped from kernel_size=%s stride=%s padding=%s "
        "to kernel_size=%s stride=%s padding=%s — output will differ",
        matched.name,
        source.kernel_size,
        source.stride,
        source.padding,
        LatticeMaxPool2d.KERNEL_SIZE,
        LatticeMaxPool2d.STRIDE,
        LatticeMaxPool2d.PADDING,
    )
    replacement = LatticeMaxPool2d(source)
    return (replacement,)
class LatticeMaxPool2dPattern(Pattern):
    """Snap an out-of-spec ``MaxPool2d`` to Lattice's supported set.

    Lattice hardware supports only a ``2×2`` max-pool with stride 2 and
    zero padding.  Any other configuration is replaced by that single
    canonical form, matching the stem rewrite performed by the Lattice
    export reference script.
    """

    # Runs during the conversion phase.
    phase = Phase.CONVERSION
    # Single-node tree: matches any ``MaxPool2d`` flagged as nonconforming.
    tree: Tree = (_is_nonconforming_maxpool2d,)
    # Single-node graft: the canonical replacement module.
    graft: Graft = (_make_lattice_maxpool,)
def _as_hw_pair(value):
    """Expand a scalar-or-sequence pooling argument to an ``(h, w)`` pair."""
    if isinstance(value, (tuple, list)):
        height, width = value
        return height, width
    return value, value


def _avg_pool_output_is_global(node: fx.Node) -> bool:
    """Return ``True`` when the pool is a true global average over H×W.

    Requires a known 4-D input shape on `node`; otherwise ``False`` is
    returned.

    * ``AdaptiveAvgPool2d`` qualifies when its resolved output size is
      ``1×1`` (a ``None`` axis resolves to the input extent).
    * ``AvgPool2d`` qualifies when its kernel spans the whole spatial
      extent *and* it uses zero padding and no ``divisor_override``.
      With padding the output need not be ``1×1`` and padded zeros can
      enter the average; with a divisor override the result is not the
      mean.  Neither case is equivalent to a global average pool, so
      both are rejected.
    """
    mod = get_module(node)
    shape = get_input_shape(node)
    if shape is None or len(shape) != 4:
        return False
    in_h, in_w = int(shape[2]), int(shape[3])

    if isinstance(mod, nn.AdaptiveAvgPool2d):
        out_h, out_w = _as_hw_pair(mod.output_size)
        # ``None`` means "keep the input size along this axis".
        out_h = out_h if out_h is not None else in_h
        out_w = out_w if out_w is not None else in_w
        return out_h == 1 and out_w == 1

    if isinstance(mod, nn.AvgPool2d):
        if getattr(mod, "divisor_override", None) is not None:
            return False
        pad_h, pad_w = _as_hw_pair(mod.padding)
        if pad_h != 0 or pad_w != 0:
            return False
        # With zero padding and a full-extent kernel the output is 1×1
        # whatever the stride or ceil_mode.
        kh, kw = _as_hw_pair(mod.kernel_size)
        return kh == in_h and kw == in_w

    return False


def _is_nonconforming_global_avg_pool(node: fx.Node) -> bool:
    """Return ``True`` when `node` is a global pool not in canonical form.

    Only ``AdaptiveAvgPool2d`` / ``AvgPool2d`` modules that are not
    already a ``LatticeAdaptiveAvgPool2d`` are considered; the decision
    is then delegated to :func:`_avg_pool_output_is_global`.
    """
    mod = get_module(node)
    if not isinstance(mod, (nn.AdaptiveAvgPool2d, nn.AvgPool2d)):
        return False
    if isinstance(mod, LatticeAdaptiveAvgPool2d):
        return False
    return _avg_pool_output_is_global(node)


def _make_lattice_global_pool(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Build the replacement global average pool for the matched node.

    The replacement takes no configuration from the match, so
    `tree_match` is intentionally unused.
    """
    del tree_match
    return (LatticeAdaptiveAvgPool2d(),)
class LatticeGlobalAvgPoolPattern(Pattern):
    """Normalize any global average pool to ``AdaptiveAvgPool2d((1, 1))``.

    Lattice hardware represents the classification-tail pool as an
    :class:`~torch.nn.AdaptiveAvgPool2d` with ``output_size == (1, 1)``.
    Any other module that effectively performs a global average pool
    over a 4-D input (an :class:`~torch.nn.AdaptiveAvgPool2d` with the
    scalar shorthand ``1`` or with a ``None`` axis equal to the input
    spatial size, or an :class:`~torch.nn.AvgPool2d` whose kernel covers
    the full spatial extent) is rewritten to this canonical form.

    The pattern is a no-op for inputs already in canonical form, so it
    converges in a single conversion pass without infinite looping.

    Requires shape metadata on the input node, propagated by
    :class:`~torch.fx.passes.shape_prop.ShapeProp` (the ``transform``
    driver runs this automatically between conversion iterations).
    """

    # Runs during the conversion phase.
    phase = Phase.CONVERSION
    # Single-node tree: matches any non-canonical global average pool.
    tree: Tree = (_is_nonconforming_global_avg_pool,)
    # Single-node graft: the canonical replacement module.
    graft: Graft = (_make_lattice_global_pool,)
def _flatten_dims(node: fx.Node) -> tuple[int, int]:
    """Extract the ``(start_dim, end_dim)`` pair of a flatten `node`.

    For an ``nn.Flatten`` module the values come from the module's
    attributes.  Otherwise they are read from the node's keyword
    arguments, then from positional arguments 1 and 2, falling back to
    ``0`` and ``-1``.
    """
    module = get_module(node)
    if isinstance(module, nn.Flatten):
        return module.start_dim, module.end_dim

    positional = node.args
    keyword = node.kwargs
    start_fallback = positional[1] if len(positional) > 1 else 0
    end_fallback = positional[2] if len(positional) > 2 else -1
    start_dim = keyword.get("start_dim", start_fallback)
    end_dim = keyword.get("end_dim", end_fallback)
    assert isinstance(start_dim, int)
    assert isinstance(end_dim, int)
    return start_dim, end_dim


def _is_convertible_flatten(node: fx.Node) -> bool:
    """Tell whether `node` is a flatten op with a known input shape.

    Recognizes the ``torch.flatten`` function, the ``flatten`` tensor
    method, and (for every other node kind) an ``nn.Flatten`` module.
    """
    op_kind = node.op
    if op_kind == "call_function":
        is_flatten = node.target is torch.flatten
    elif op_kind == "call_method":
        is_flatten = node.target == "flatten"
    else:
        is_flatten = isinstance(get_module(node), nn.Flatten)

    if is_flatten:
        return get_input_shape(node) is not None
    return False


def _flatten_to_reshape(tree_match: TreeMatch) -> tuple[NodeInserter, ...]:
    """Build a ``NodeInserter`` that emits the equivalent ``reshape`` call.

    The target shape keeps the traced sizes of every dimension outside
    ``[start_dim, end_dim]`` and collapses the span itself to ``-1``.

    Returns:
        A one-element tuple holding the inserter callable.
    """
    flatten_node = tree_match.get_node(0)
    shape = get_input_shape(flatten_node)
    assert shape is not None
    start_dim, end_dim = _flatten_dims(flatten_node)

    # Normalize negative dims against the input rank.
    rank = len(shape)
    first = start_dim + rank if start_dim < 0 else start_dim
    last = end_dim + rank if end_dim < 0 else end_dim

    # Dimensions outside the flattened span are baked in as static
    # values from the traced shape.  Lattice targets a fixed batch
    # size of 1, so this is intentional.
    leading = [int(dim) for dim in shape[:first]]
    trailing = [int(dim) for dim in shape[last + 1 :]]
    target_shape: list[int] = leading + [-1] + trailing

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        graph = graph_module.graph
        # Insert after whichever argument node appears last in the graph.
        position = {n: index for index, n in enumerate(graph.nodes)}
        latest = max(prev_args, key=position.__getitem__)
        with graph.inserting_after(latest):
            reshape_node = graph.call_method(
                "reshape", (prev_args[0], *target_shape)
            )
        return [reshape_node]

    return (_insert,)
class FlattenToReshapePattern(Pattern):
    """Replace ``flatten`` with an equivalent static ``reshape``.

    Lattice does not support ``flatten`` as a stand-alone op but does
    support ``reshape`` to a fully-static target shape.  When the input
    shape of the flatten node is known (via
    :class:`~torch.fx.passes.shape_prop.ShapeProp`) the dimensions
    outside ``[start_dim, end_dim]`` are kept verbatim and the flattened
    span is collapsed to a single ``-1`` axis, producing a
    statically-shaped ``reshape`` call equivalent to the original
    ``flatten``.

    Matches all three flatten spellings: the :class:`~torch.nn.Flatten`
    module, the ``torch.flatten`` function, and the ``Tensor.flatten``
    method.

    .. note::

        The replacement ``reshape`` bakes in the batch dimension from
        the traced shape (typically 1).  This is intentional: Lattice
        hardware targets a fixed batch size of 1, so the resulting model
        is not expected to handle variable batch sizes.
    """

    # Runs during the conversion phase.
    phase = Phase.CONVERSION
    # Single-node tree: matches any flatten with a known input shape.
    tree: Tree = (_is_convertible_flatten,)
    # Single-node graft: an inserter emitting the ``reshape`` call.
    graft: Graft = (_flatten_to_reshape,)