# Copyright (C) 2026 Embedl AB
"""General-purpose Lattice conversion patterns.
These patterns rewrite standard Torch operations into the limited
subset accepted by Lattice hardware.
"""
import logging
import torch
from torch import fx, nn
from embedl_deploy._internal.core.patterns.main import Pattern, Phase
from embedl_deploy._internal.core.tree.types import (
Graft,
NodeInserter,
Tree,
TreeMatch,
)
from embedl_deploy._internal.core.tree.utils import (
get_input_shape,
get_module,
resolve_module,
)
from embedl_deploy._internal.lattice.modules.activation import (
LatticeLeakyReLU,
)
from embedl_deploy._internal.lattice.modules.conv import LatticeConv2d
from embedl_deploy._internal.lattice.modules.pool import (
LatticeAdaptiveAvgPool2d,
LatticeMaxPool2d,
)
# Module-level logger; the conversion helpers in this module use it to warn
# about rewrites that change numerical output.
_LOG: logging.Logger = logging.getLogger(name=__name__)
def _is_nonconforming_conv2d(node: fx.Node) -> bool:
    """Tell whether `node` calls a plain ``Conv2d`` outside Lattice's spec.

    Anything that is not an ``nn.Conv2d`` is rejected. So is a module that
    is already a :class:`LatticeConv2d`, and so is a ``Conv2d`` that
    ``LatticeConv2d.is_compatible`` accepts unchanged. Only the remaining
    ``Conv2d`` modules need rewriting.
    """
    module = get_module(node)
    is_plain_conv = isinstance(module, nn.Conv2d) and not isinstance(
        module, LatticeConv2d
    )
    return is_plain_conv and not LatticeConv2d.is_compatible(module)
def _make_lattice_conv(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Build the replacement ``Conv2d`` for the matched node.

    The matched ``Conv2d`` is passed to :class:`LatticeConv2d`, whose
    constructor snaps the configuration to the supported set. A warning
    is logged for each kind of change that alters the layer's output:

    * a changed weight shape, reported as a kernel-size snap (the
      weights are reinitialized in that case), and
    * a changed stride and/or padding.

    The two checks are independent. When both changes happen, both
    warnings are emitted, so a snapped stride (which changes the output
    spatial shape) is never hidden behind the kernel-size warning.

    Args:
        tree_match: Match whose first node is the ``Conv2d`` to replace.

    Returns:
        A one-element tuple holding the replacement module.
    """
    node = tree_match.get_node(0)
    conv = resolve_module(node, nn.Conv2d)
    new_conv = LatticeConv2d(conv)
    if new_conv.weight.shape != conv.weight.shape:
        _LOG.warning(
            "%s: kernel_size snapped from %s to %s — "
            "weights reinitialized, output will differ; "
            "retraining recommended.",
            node.name,
            conv.kernel_size,
            new_conv.kernel_size,
        )
    # Not an ``elif``: a stride change alters the output shape seen by
    # downstream layers and must be reported even if the kernel changed.
    if new_conv.stride != conv.stride or new_conv.padding != conv.padding:
        _LOG.warning(
            "%s: stride/padding snapped from %s/%s to %s/%s — "
            "output will differ; retraining recommended.",
            node.name,
            conv.stride,
            conv.padding,
            new_conv.stride,
            new_conv.padding,
        )
    return (new_conv,)
class LatticeConv2dPattern(Pattern):
    """Snap an out-of-spec ``Conv2d`` to Lattice's supported set.

    Lattice hardware accepts only ``1×1`` and ``3×3`` convolutions with
    stride 1 or 2 (and stride 1 is mandatory for the ``1×1`` kernel).
    Convolutions outside this set are replaced by a
    :class:`~embedl_deploy._internal.lattice.modules.conv.LatticeConv2d`,
    whose constructor performs the snapping. Modules that are already
    ``LatticeConv2d``, or that it reports as compatible, are not matched.

    Weights and bias are preserved when only stride and/or padding
    needed snapping (the weight tensor shape is unchanged). When the
    kernel size itself is snapped the replacement has freshly
    initialized weights, since the source kernel cannot be meaningfully
    reused at the new shape.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_nonconforming_conv2d,)
    graft: Graft = (_make_lattice_conv,)
def _is_nonconforming_leaky_relu(node: fx.Node) -> bool:
    """Tell whether `node` calls a ``LeakyReLU`` that must be rewritten.

    Only plain ``nn.LeakyReLU`` modules qualify. Modules that are already
    :class:`LatticeLeakyReLU` are skipped, as are those that
    ``LatticeLeakyReLU.is_compatible`` accepts as-is.
    """
    activation = get_module(node)
    already_ok = not isinstance(activation, nn.LeakyReLU) or isinstance(
        activation, LatticeLeakyReLU
    )
    if already_ok:
        return False
    return not LatticeLeakyReLU.is_compatible(activation)
def _make_lattice_leaky_relu(
    tree_match: TreeMatch,
) -> tuple[nn.Module, ...]:
    """Wrap the matched ``LeakyReLU`` in :class:`LatticeLeakyReLU`.

    When the wrapper's ``negative_slope`` differs from the source module's,
    a warning naming both slopes is logged before the replacement is
    returned as a one-element tuple.
    """
    matched = tree_match.get_node(0)
    source = resolve_module(matched, nn.LeakyReLU)
    snapped = LatticeLeakyReLU(source)
    old_slope, new_slope = source.negative_slope, snapped.negative_slope
    if old_slope == new_slope:
        return (snapped,)
    _LOG.warning(
        "%s: LeakyReLU negative_slope snapped from %s to %s — "
        "output will differ, retraining is recommended",
        matched.name,
        old_slope,
        new_slope,
    )
    return (snapped,)
class LatticeLeakyReLUPattern(Pattern):
    """Snap a ``LeakyReLU`` negative slope to Lattice's supported value.

    Lattice hardware implements leaky ReLU with a fixed negative slope
    of ``1/16`` (0.0625). Any other slope is replaced by that value
    and a warning is logged advising retraining. Modules that are
    already
    :class:`~embedl_deploy._internal.lattice.modules.activation.LatticeLeakyReLU`
    are not matched.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_nonconforming_leaky_relu,)
    graft: Graft = (_make_lattice_leaky_relu,)
def _is_nonconforming_maxpool2d(node: fx.Node) -> bool:
    """Tell whether `node` calls a ``MaxPool2d`` that must be rewritten.

    Non-``MaxPool2d`` modules and existing :class:`LatticeMaxPool2d`
    instances are skipped. Any other ``MaxPool2d`` is flagged unless
    ``LatticeMaxPool2d.is_compatible`` accepts it.
    """
    pool = get_module(node)
    if not isinstance(pool, nn.MaxPool2d) or isinstance(pool, LatticeMaxPool2d):
        return False
    return not LatticeMaxPool2d.is_compatible(pool)
def _make_lattice_maxpool(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Replace the matched ``MaxPool2d`` with :class:`LatticeMaxPool2d`.

    A warning listing the original kernel size, stride and padding next
    to the fixed ``LatticeMaxPool2d`` class constants is always logged.
    The pattern only matches incompatible pools, so every rewrite
    changes the configuration.
    """
    matched = tree_match.get_node(0)
    source_pool = resolve_module(matched, nn.MaxPool2d)
    before = (source_pool.kernel_size, source_pool.stride, source_pool.padding)
    after = (
        LatticeMaxPool2d.KERNEL_SIZE,
        LatticeMaxPool2d.STRIDE,
        LatticeMaxPool2d.PADDING,
    )
    _LOG.warning(
        "%s: MaxPool2d snapped from kernel_size=%s stride=%s padding=%s "
        "to kernel_size=%s stride=%s padding=%s — output will differ",
        matched.name,
        *before,
        *after,
    )
    replacement = LatticeMaxPool2d(source_pool)
    return (replacement,)
class LatticeMaxPool2dPattern(Pattern):
    """Snap an out-of-spec ``MaxPool2d`` to Lattice's supported set.

    Lattice hardware supports only a ``2×2`` max-pool with stride 2 and
    zero padding. Any other configuration is replaced by that single
    canonical form, matching the stem rewrite performed by the Lattice
    export reference script. A warning is logged for every rewritten
    pool, since the output will differ.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_nonconforming_maxpool2d,)
    graft: Graft = (_make_lattice_maxpool,)
def _as_hw_pair(value: object) -> tuple[object, object]:
    """Return `value` as an ``(h, w)`` pair, repeating a scalar for both axes.

    Size-like pooling attributes (``output_size``, ``kernel_size``,
    ``padding``) may be a scalar or a two-element tuple/list.
    """
    if isinstance(value, (tuple, list)):
        height, width = value
        return height, width
    return value, value


def _avg_pool_output_is_global(node: fx.Node) -> bool:
    """Return ``True`` when the pool is a global average over a 4-D input.

    The pool qualifies when it collapses the spatial dims to ``1×1`` *and*
    each output value is the plain mean over the full spatial extent. Such
    a pool can be swapped for ``AdaptiveAvgPool2d((1, 1))`` without
    changing the result.

    * ``AdaptiveAvgPool2d``: both output axes resolve to ``1``. A ``None``
      axis keeps the input size, so it only counts when that input axis
      is already ``1``.
    * ``AvgPool2d``: the kernel equals the input's spatial size, padding is
      zero and no ``divisor_override`` is set.

    Returns ``False`` for any other module, and when the input shape is
    unknown or not 4-D.
    """
    mod = get_module(node)
    shape = get_input_shape(node)
    if shape is None or len(shape) != 4:
        return False
    in_h, in_w = int(shape[2]), int(shape[3])
    if isinstance(mod, nn.AdaptiveAvgPool2d):
        out_h, out_w = _as_hw_pair(mod.output_size)
        out_h = in_h if out_h is None else out_h
        out_w = in_w if out_w is None else out_w
        return out_h == 1 and out_w == 1
    if isinstance(mod, nn.AvgPool2d):
        kh, kw = _as_hw_pair(mod.kernel_size)
        # A full-extent kernel alone is not enough. With non-zero padding
        # the window can slide to several positions (output > 1×1), and
        # even a 1×1 output may average in padded zeros.
        # ``divisor_override`` replaces the element count in the
        # denominator, so the result would no longer be a mean.
        return (
            kh == in_h
            and kw == in_w
            and _as_hw_pair(mod.padding) == (0, 0)
            and mod.divisor_override is None
        )
    return False
def _is_nonconforming_global_avg_pool(node: fx.Node) -> bool:
    """Tell whether `node` is a global average pool to canonicalize.

    Candidates are ``nn.AdaptiveAvgPool2d`` and ``nn.AvgPool2d`` modules
    that are not already :class:`LatticeAdaptiveAvgPool2d`. A candidate is
    flagged only if ``_avg_pool_output_is_global`` confirms that it pools
    over the whole spatial extent.
    """
    module = get_module(node)
    candidate = isinstance(
        module, (nn.AdaptiveAvgPool2d, nn.AvgPool2d)
    ) and not isinstance(module, LatticeAdaptiveAvgPool2d)
    return candidate and _avg_pool_output_is_global(node)
def _make_lattice_global_pool(tree_match: TreeMatch) -> tuple[nn.Module, ...]:
    """Return a fresh :class:`LatticeAdaptiveAvgPool2d` as the replacement.

    Nothing is read from the matched node. The canonical pool is built
    from its defaults and returned as a one-element tuple.
    """
    del tree_match  # unused: the replacement does not depend on the match
    canonical = LatticeAdaptiveAvgPool2d()
    return (canonical,)
class LatticeGlobalAvgPoolPattern(Pattern):
    """Normalize any global average pool to ``AdaptiveAvgPool2d((1, 1))``.

    Lattice hardware represents the classification-tail pool as an
    :class:`~torch.nn.AdaptiveAvgPool2d` with ``output_size == (1, 1)``.
    The pattern rewrites to this canonical form any other module that
    effectively performs a global average pool over a 4-D input:

    * an :class:`~torch.nn.AdaptiveAvgPool2d` with the scalar shorthand
      ``1``, or with a ``None`` axis whose input spatial size is ``1``;
    * an :class:`~torch.nn.AvgPool2d` whose kernel covers the full
      spatial extent.

    The replacement is a
    :class:`~embedl_deploy._internal.lattice.modules.pool.LatticeAdaptiveAvgPool2d`.

    Modules that are already ``LatticeAdaptiveAvgPool2d`` are never
    matched, so the rewrite converges in a single conversion pass
    instead of looping. A plain ``nn.AdaptiveAvgPool2d((1, 1))`` is still
    matched and swapped for the Lattice class.

    Requires shape metadata on the input node, propagated by
    :class:`~torch.fx.passes.shape_prop.ShapeProp` (the
    ``transform`` driver runs this automatically between conversion
    iterations). Without it, the pool is left untouched.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_nonconforming_global_avg_pool,)
    graft: Graft = (_make_lattice_global_pool,)
def _flatten_dims(node: fx.Node) -> tuple[int, int]:
    """Return the ``(start_dim, end_dim)`` of a flatten `node`.

    For an ``nn.Flatten`` module the values come from its attributes.
    Otherwise `node` is treated as a ``torch.flatten`` / ``Tensor.flatten``
    call with positional layout ``(input, start_dim, end_dim)``.
    Keyword arguments take precedence over positional ones, and missing
    values default to ``0`` and ``-1``. Non-integer dims trip an
    ``assert``.
    """
    module = get_module(node)
    if isinstance(module, nn.Flatten):
        return module.start_dim, module.end_dim

    args, kwargs = node.args, node.kwargs

    def _lookup(name: str, position: int, default: int) -> object:
        if name in kwargs:
            return kwargs[name]
        return args[position] if position < len(args) else default

    start_dim = _lookup("start_dim", 1, 0)
    end_dim = _lookup("end_dim", 2, -1)
    assert isinstance(start_dim, int)
    assert isinstance(end_dim, int)
    return start_dim, end_dim
def _is_convertible_flatten(node: fx.Node) -> bool:
    """Tell whether `node` is a flatten whose input shape is recorded.

    All three flatten spellings are recognized:

    * a ``call_function`` node targeting ``torch.flatten``;
    * a ``call_method`` node named ``"flatten"``;
    * any other node that resolves to an ``nn.Flatten`` module.

    A recognized flatten is convertible only if ``get_input_shape`` can
    supply its input shape.
    """
    kind = node.op
    if kind == "call_function":
        recognised = node.target is torch.flatten
    elif kind == "call_method":
        recognised = node.target == "flatten"
    else:
        recognised = isinstance(get_module(node), nn.Flatten)
    # The static reshape target is derived from the traced input shape.
    return recognised and get_input_shape(node) is not None
def _flatten_to_reshape(tree_match: TreeMatch) -> tuple[NodeInserter, ...]:
    """Return a ``NodeInserter`` that emits an equivalent ``reshape`` call.

    The target shape is computed eagerly from the flatten node's traced
    input shape. Leading dims (before ``start_dim``) and trailing dims
    (after ``end_dim``) are copied as integers, and the flattened span
    becomes a single ``-1``. The returned inserter places a
    ``reshape`` call on the first incoming argument, directly after
    whichever incoming argument appears latest in the graph.
    """
    flat = tree_match.get_node(0)
    in_shape = get_input_shape(flat)
    assert in_shape is not None
    rank = len(in_shape)
    lo, hi = _flatten_dims(flat)
    # Normalize negative dims against the traced rank.
    lo = lo + rank if lo < 0 else lo
    hi = hi + rank if hi < 0 else hi
    # Dimensions outside the flattened span are baked in as static
    # values from the traced shape. Lattice targets a fixed batch
    # size of 1, so this is intentional.
    leading = [int(dim) for dim in in_shape[:lo]]
    trailing = [int(dim) for dim in in_shape[hi + 1 :]]
    new_shape: list[int] = leading + [-1] + trailing

    def _insert(
        graph_module: fx.GraphModule,
        prev_args: tuple[fx.Node, ...],
    ) -> list[fx.Node]:
        graph = graph_module.graph
        position = {n: idx for idx, n in enumerate(graph.nodes)}
        # Insert after the last-defined input so every argument is in scope.
        anchor = max(prev_args, key=position.__getitem__)
        with graph.inserting_after(anchor):
            reshape = graph.call_method("reshape", (prev_args[0], *new_shape))
        return [reshape]

    return (_insert,)
class FlattenToReshapePattern(Pattern):
    """Replace ``flatten`` with an equivalent static ``reshape``.

    Lattice does not support ``flatten`` as a stand-alone op but does
    support ``reshape`` to a static target shape. When the input shape
    of the flatten node is known (via
    :class:`~torch.fx.passes.shape_prop.ShapeProp`), the dimensions
    outside ``[start_dim, end_dim]`` are copied verbatim as integer
    literals. The flattened span is collapsed to a single ``-1`` axis.
    That ``-1`` is the only inferred entry, so the resulting ``reshape``
    has a shape fully determined by the traced input and is equivalent
    to the original ``flatten``.

    Matches all three flatten spellings: the
    :class:`~torch.nn.Flatten` module, the ``torch.flatten`` function,
    and the ``Tensor.flatten`` method. Flatten nodes without a known
    input shape are left untouched.

    .. note::
        The replacement ``reshape`` bakes in the batch dimension from
        the traced shape (typically 1). This is intentional: Lattice
        hardware targets a fixed batch size of 1, so the resulting model
        is not expected to handle variable batch sizes.
    """

    phase = Phase.CONVERSION
    tree: Tree = (_is_convertible_flatten,)
    graft: Graft = (_flatten_to_reshape,)