Source code for embedl_deploy._internal.lattice.modules.conv

# Copyright (C) 2026 Embedl AB

"""Lattice-conforming convolution modules.

Lattice hardware accelerators support a restricted set of convolution
parameters.  ``LatticeConv2d`` subclasses ``Conv2d`` and snaps an
arbitrary source convolution to the closest configuration allowed by
the hardware.
"""

import torch
from torch import nn


def _as_pair(value: int | tuple[int, ...]) -> tuple[int, int]:
    """Return `value` as a 2-tuple, broadcasting scalars.

    :param value:
        Value to convert.  If an integer, it is broadcast to both
        spatial dimensions; if a tuple, it must have two entries.
    :returns:
        A 2-tuple of integers.
    """
    if isinstance(value, tuple):
        return (value[0], value[1])
    return (value, value)


def _snap(value: int, allowed: tuple[int, ...]) -> int:
    """Return the entry in `allowed` closest to `value`.

    :param value:
        Value to snap.
    :param allowed:
        Allowed values to snap to.
    :returns:
        The entry in `allowed` closest to `value`.  In case of ties, the
        smaller value is preferred.
    """
    return min(allowed, key=lambda v: (abs(v - value), v))


class LatticeConv2d(nn.Conv2d):
    """``Conv2d`` snapped to Lattice's supported set.

    Lattice hardware accepts only ``1×1`` and ``3×3`` convolutions with
    stride 1 or 2 (and stride 1 is mandatory for the ``1×1`` kernel).

    The constructor takes an arbitrary source :class:`~torch.nn.Conv2d`
    and forwards its ``in_channels``, ``out_channels``, ``dilation``,
    ``groups``, bias presence, device, and dtype.  Kernel size and
    stride are snapped to the nearest supported values:

    * each stride axis is snapped independently to :attr:`STRIDES`;
    * each kernel axis is snapped to :attr:`KERNEL_SIZES` and the result
      is squared up to the larger of the two, so mixed shapes such as
      ``1×3`` become ``3×3``;
    * a ``1×1`` kernel is promoted to ``3×3`` whenever either stride
      axis exceeds 1.

    Padding is set to ``dilation * (kernel_size // 2)`` on each spatial
    axis.  For dilation 1 this is simply ``kernel_size // 2``.  For odd
    kernels it yields ``ceil(in / stride)`` outputs, i.e. the common
    ``"same"``-style convention used by ResNet-family stems and
    downsamples.

    Weights and bias are copied from the source convolution whenever the
    snapped weight tensor has the same shape as the source's (i.e., only
    stride and/or padding changed).  If the kernel size was snapped — and
    the weight tensor shape therefore changed — the instance keeps
    freshly initialized weights, since there is no well-defined way to
    reuse the original kernel values.
    """

    #: Permitted spatial kernel sizes.
    KERNEL_SIZES: tuple[int, ...] = (1, 3)
    #: Permitted spatial strides.
    STRIDES: tuple[int, ...] = (1, 2)

    @classmethod
    def snapped_params(
        cls, conv: nn.Conv2d
    ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int]]:
        """Return ``(kernel_size, stride, padding)`` after Lattice snapping.

        :param conv:
            Source convolution whose parameters are snapped to the
            nearest values accepted by Lattice hardware.
        :returns:
            A three-tuple ``(kernel_size, stride, padding)`` of the
            snapped parameters where each element is itself an
            ``(h, w)`` pair.  The kernel is always square.
        """
        kh, kw = _as_pair(conv.kernel_size)
        sh, sw = _as_pair(conv.stride)
        dh, dw = _as_pair(conv.dilation)

        stride = (_snap(sh, cls.STRIDES), _snap(sw, cls.STRIDES))

        # Only square kernels are supported, so square up mixed shapes
        # (e.g. 1x3) to the larger snapped size rather than emitting a
        # kernel the hardware cannot run.
        size = max(_snap(kh, cls.KERNEL_SIZES), _snap(kw, cls.KERNEL_SIZES))
        # A 1x1 kernel must use stride 1 on both axes; a strided 1x1 is
        # promoted to 3x3.
        if size == 1 and stride != (1, 1):
            size = 3

        # Dilation widens the effective kernel to d * (k - 1) + 1, so the
        # "same"-style padding must scale with it to keep the output shape.
        half = size // 2
        return (
            (size, size),
            stride,
            (dh * half, dw * half),
        )

    @classmethod
    def is_compatible(cls, conv: nn.Conv2d) -> bool:
        """Return ``True`` when `conv` already matches Lattice's supported set.

        A convolution is compatible when its kernel size, stride, and
        padding equal what ``snapped_params`` would return for it.  A
        string padding such as ``"same"`` never compares equal to the
        integer padding returned by ``snapped_params``, so such a
        convolution is reported as incompatible.

        :param conv:
            Convolution to check.
        :returns:
            ``True`` when `conv` already conforms to Lattice constraints;
            ``False`` otherwise.
        """
        kernel_size, stride, padding = cls.snapped_params(conv)
        return (
            _as_pair(conv.kernel_size) == kernel_size
            and _as_pair(conv.stride) == stride
            and _as_pair(conv.padding) == padding  # type: ignore[arg-type]
        )

    def __init__(self, conv: nn.Conv2d) -> None:
        """Create a ``LatticeConv2d`` from an arbitrary ``Conv2d``.

        :param conv:
            Source convolution.  Its ``in_channels``, ``out_channels``,
            ``dilation``, ``groups``, and bias presence are forwarded
            unchanged.  Parameters are created on the same device and
            with the same dtype as ``conv.weight``.  Kernel size, stride,
            and padding are snapped to Lattice's supported set.
            ``padding_mode`` is not forwarded, so the replacement uses
            PyTorch's default (``"zeros"``).
        """
        kernel_size, stride, padding = self.snapped_params(conv)
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=_as_pair(conv.dilation),
            groups=conv.groups,
            bias=conv.bias is not None,
            # Without these, a replacement for a CUDA / reduced-precision
            # conv would be built on CPU in float32 and break at forward.
            device=conv.weight.device,
            dtype=conv.weight.dtype,
        )
        # Preserve weights/bias when the snapped weight tensor has the
        # same shape as the source's (i.e., only stride/padding
        # changed).  If the kernel size was snapped, weight shapes
        # differ and we cannot meaningfully reuse the source kernel.
        if self.weight.shape == conv.weight.shape:
            with torch.no_grad():
                self.weight.copy_(conv.weight)
                if conv.bias is not None and self.bias is not None:
                    self.bias.copy_(conv.bias)