Source code for embedl_deploy._internal.lattice.modules.pool
# Copyright (C) 2026 Embedl AB
"""Lattice-conforming pooling modules.
Lattice hardware accelerators support only a small subset of pooling
configurations. In particular, ``MaxPool2d`` is restricted to a
``2×2`` kernel with stride 2 and zero padding, and the only supported
form of average pooling for the classification tail is the global
average pool, expressed as ``nn.AdaptiveAvgPool2d((1, 1))``.
The two classes here subclass the corresponding stock
:mod:`torch.nn` modules, so callers can detect a Lattice-conforming
pool with :func:`isinstance`, and expose the allowed parameters as
class constants.
"""
from torch import nn
[docs]
class LatticeMaxPool2d(nn.MaxPool2d):
    """``MaxPool2d`` snapped to Lattice's supported set.

    Lattice hardware supports only a single max-pool configuration:
    a ``2×2`` kernel with stride 2 and zero padding. The constructor
    takes an arbitrary source :class:`~torch.nn.MaxPool2d` and emits
    that canonical configuration, preserving only ``dilation`` and
    ``ceil_mode`` (``return_indices`` is not carried over and keeps the
    ``nn.MaxPool2d`` default of ``False``).
    """

    #: The single permitted kernel size.
    KERNEL_SIZE: int = 2
    #: The single permitted stride.
    STRIDE: int = 2
    #: The single permitted padding.
    PADDING: int = 0

    @staticmethod
    def _as_pair(value: object) -> tuple:
        """Normalize a ``MaxPool2d`` size attribute to a per-dimension tuple.

        ``nn.MaxPool2d`` stores ``kernel_size``, ``stride`` and ``padding``
        exactly as they were passed to its constructor, so each may be a
        scalar, a tuple, or a list. PyTorch's 2-D pooling broadcasts a
        scalar or a one-element sequence to both spatial dimensions; this
        helper mirrors that so the attributes compare uniformly.

        :param value:
            A scalar, or a tuple/list of per-dimension values.
        :returns:
            ``(value, value)`` for a scalar or a one-element sequence;
            otherwise the sequence converted to a tuple as-is (so a
            sequence of any other length never equals a 2-tuple).
        """
        if isinstance(value, (tuple, list)):
            items = tuple(value)
            # One-element sequences apply to both H and W, as in PyTorch.
            return items * 2 if len(items) == 1 else items
        return (value, value)

    @classmethod
    def is_compatible(cls, pool: nn.MaxPool2d) -> bool:
        """Return ``True`` when `pool` already matches Lattice's supported set.

        Only kernel size, stride, and padding are checked; ``dilation``
        and ``ceil_mode`` are not inspected, matching the attributes the
        constructor preserves from a source pool.

        :param pool:
            Pool to check.
        :returns:
            ``True`` when kernel size, stride, and padding of `pool`
            equal the single supported configuration in both spatial
            dimensions; ``False`` otherwise. Scalar, tuple, and list
            forms of each attribute are all accepted.
        """
        return (
            cls._as_pair(pool.kernel_size) == (cls.KERNEL_SIZE, cls.KERNEL_SIZE)
            and cls._as_pair(pool.stride) == (cls.STRIDE, cls.STRIDE)
            and cls._as_pair(pool.padding) == (cls.PADDING, cls.PADDING)
        )

    def __init__(self, pool: nn.MaxPool2d) -> None:
        """Create a ``LatticeMaxPool2d`` from an arbitrary ``MaxPool2d``.

        :param pool:
            Source pool node. Its ``dilation`` and ``ceil_mode`` are
            preserved; kernel size, stride, and padding are replaced
            with the hardware-fixed values.
        """
        super().__init__(
            kernel_size=self.KERNEL_SIZE,
            stride=self.STRIDE,
            padding=self.PADDING,
            dilation=pool.dilation,
            ceil_mode=pool.ceil_mode,
        )
[docs]
class LatticeAdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Canonical Lattice global average pool.

    The only supported form of average pooling for Lattice hardware
    is :class:`~torch.nn.AdaptiveAvgPool2d` with
    ``output_size == (1, 1)``. The constructor takes no arguments —
    every instance has the canonical configuration declared by
    :attr:`OUTPUT_SIZE`.
    """

    #: The single permitted output size.
    OUTPUT_SIZE: tuple[int, int] = (1, 1)

    @classmethod
    def is_compatible(cls, pool: nn.AdaptiveAvgPool2d) -> bool:
        """Return ``True`` when `pool` already matches Lattice's supported set.

        :param pool:
            Pool to check.
        :returns:
            ``True`` when the output size of `pool` equals ``(1, 1)``,
            whether it was given as the int ``1``, a tuple, or a list;
            ``False`` otherwise (including sizes with a ``None`` entry).
        """
        out = pool.output_size
        # ``nn.AdaptiveAvgPool2d`` stores ``output_size`` exactly as passed:
        # an int (applied to both dims) or a tuple/list whose entries may be
        # ``None`` ("keep the input size"). Normalize to a tuple to compare.
        if isinstance(out, (tuple, list)):
            normalized = tuple(out)
        else:
            normalized = (out, out)
        return normalized == cls.OUTPUT_SIZE

    def __init__(self) -> None:
        """Create a canonical ``LatticeAdaptiveAvgPool2d``."""
        super().__init__(self.OUTPUT_SIZE)