# Copyright (C) 2026 Embedl AB
"""Lattice-conforming convolution modules.
Lattice hardware accelerators support a restricted set of convolution
parameters. ``LatticeConv2d`` subclasses ``Conv2d`` and snaps an
arbitrary source convolution to the closest configuration allowed by
the hardware.
"""
import torch
from torch import nn
def _as_pair(value: int | tuple[int, ...]) -> tuple[int, int]:
"""Return `value` as a 2-tuple, broadcasting scalars.
:param value:
Value to convert. If an integer, it is broadcast to both
spatial dimensions; if a tuple, it must have two entries.
:returns:
A 2-tuple of integers.
"""
if isinstance(value, tuple):
return (value[0], value[1])
return (value, value)
def _snap(value: int, allowed: tuple[int, ...]) -> int:
"""Return the entry in `allowed` closest to `value`.
:param value:
Value to snap.
:param allowed:
Allowed values to snap to.
:returns:
The entry in `allowed` closest to `value`. In case of ties, the
smaller value is preferred.
"""
return min(allowed, key=lambda v: (abs(v - value), v))
class LatticeConv2d(nn.Conv2d):
    """``Conv2d`` snapped to Lattice's supported set.

    Lattice hardware accepts only ``1×1`` and ``3×3`` convolutions with
    stride 1 or 2 (and stride 1 is mandatory for the ``1×1`` kernel).
    The constructor takes an arbitrary source :class:`~torch.nn.Conv2d`
    and forwards its ``in_channels``, ``out_channels``, ``dilation``,
    ``groups``, and bias presence. It also forwards the device and
    dtype of its weight. Kernel size and stride are snapped to the
    nearest supported values. ``padding_mode`` is not forwarded, so the
    new module uses the ``Conv2d`` default.

    A ``1×1`` kernel is promoted to ``3×3`` whenever its stride exceeds
    1. Padding is set to ``kernel_size // 2`` on each spatial axis to
    preserve the output shape under the common ``"same"``-style
    convention used by ResNet-family stems and downsamples. This
    holds for ``dilation == 1``.

    Weights and bias are copied from the source convolution whenever
    the snapped weight tensor has the same shape as the source's
    (i.e., only stride and/or padding changed). If the kernel size
    was snapped — and the weight tensor shape therefore changed — the
    instance keeps freshly initialized weights, since there is no
    well-defined way to reuse the original kernel values.
    """

    #: Permitted spatial kernel sizes.
    KERNEL_SIZES: tuple[int, ...] = (1, 3)
    #: Permitted spatial strides.
    STRIDES: tuple[int, ...] = (1, 2)

    @classmethod
    def snapped_params(
        cls, conv: nn.Conv2d
    ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
        """Return ``(kernel_size, stride, padding)`` after Lattice snapping.

        Each spatial axis is handled independently:

        - Stride and kernel size are moved to the nearest entry of
          ``STRIDES`` and ``KERNEL_SIZES``; ties resolve to the smaller
          value.
        - A kernel of ``1`` on an axis whose snapped stride is not ``1``
          is promoted to ``3``.
        - Padding is ``kernel_size // 2``.

        :param conv:
            Source convolution whose parameters are snapped to the
            nearest values accepted by Lattice hardware.
        :returns:
            A three-tuple ``(kernel_size, stride, padding)`` of the snapped
            parameters where each element is itself an ``(h, w)`` pair.
        """
        kh, kw = _as_pair(conv.kernel_size)
        sh, sw = _as_pair(conv.stride)
        new_sh = _snap(sh, cls.STRIDES)
        new_sw = _snap(sw, cls.STRIDES)
        new_kh = _snap(kh, cls.KERNEL_SIZES)
        new_kw = _snap(kw, cls.KERNEL_SIZES)
        # A strided 1×1 is not accepted by the hardware, so promote the
        # kernel to 3 on that axis rather than dropping the stride.
        if new_kh == 1 and new_sh != 1:
            new_kh = 3
        if new_kw == 1 and new_sw != 1:
            new_kw = 3
        # NOTE(review): snapping is per axis, so non-square results such
        # as a (3, 1) kernel are possible (e.g. a 1×1 source with stride
        # (2, 1)). Confirm whether the hardware accepts those.
        # NOTE(review): ``k // 2`` is size-preserving padding only for
        # dilation 1. With a forwarded dilation ``d`` and a 3×3 kernel,
        # stride-1 outputs shrink by ``2 * (d - 1)`` per axis. Confirm
        # whether dilated sources can reach this path.
        return (
            (new_kh, new_kw),
            (new_sh, new_sw),
            (new_kh // 2, new_kw // 2),
        )

    @classmethod
    def is_compatible(cls, conv: nn.Conv2d) -> bool:
        """Return ``True`` when `conv` already matches Lattice's supported set.

        A convolution is compatible when its kernel size, stride, and
        padding equal what ``snapped_params`` would return for it.
        String paddings (e.g. ``"same"``) never compare equal to the
        integer pair and therefore report ``False``.

        :param conv:
            Convolution to check.
        :returns:
            ``True`` when `conv` already conforms to Lattice
            constraints; ``False`` otherwise.
        """
        kernel_size, stride, padding = cls.snapped_params(conv)
        return (
            _as_pair(conv.kernel_size) == kernel_size
            and _as_pair(conv.stride) == stride
            and _as_pair(conv.padding) == padding  # type: ignore[arg-type]
        )

    def __init__(self, conv: nn.Conv2d) -> None:
        """Create a ``LatticeConv2d`` from an arbitrary ``Conv2d``.

        :param conv:
            Source convolution. Its ``in_channels``,
            ``out_channels``, ``dilation``, ``groups``, bias
            presence, and weight device/dtype are forwarded
            unchanged. Kernel size, stride, and padding are
            snapped to Lattice's supported set.
        """
        kernel_size, stride, padding = self.snapped_params(conv)
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=_as_pair(conv.dilation),
            groups=conv.groups,
            bias=conv.bias is not None,
            # Allocate on the source's device and dtype. Otherwise a
            # replacement for a conv already moved to CUDA (or cast to
            # half precision) would be created on the default CPU
            # float32 and break the model it is swapped into.
            device=conv.weight.device,
            dtype=conv.weight.dtype,
        )
        # Preserve weights/bias when the snapped weight tensor has the
        # same shape as the source's (i.e., only stride/padding
        # changed). If the kernel size was snapped, weight shapes
        # differ and we cannot meaningfully reuse the source kernel.
        if self.weight.shape == conv.weight.shape:
            with torch.no_grad():
                self.weight.copy_(conv.weight)
                if conv.bias is not None and self.bias is not None:
                    self.bias.copy_(conv.bias)