Source code for embedl_deploy._internal.lattice.modules.pool

# Copyright (C) 2026 Embedl AB

"""Lattice-conforming pooling modules.

Lattice hardware accelerators support only a small subset of pooling
configurations.  In particular, ``MaxPool2d`` is restricted to a
``2×2`` kernel with stride 2 and zero padding, and the only supported
form of average pooling for the classification tail is the global
average pool, expressed as ``nn.AdaptiveAvgPool2d((1, 1))``.

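The two classes here subclass the corresponding stock
:mod:`torch.nn` modules and expose the allowed parameters as class
constants, so callers can detect a Lattice-conforming pool with
:func:`isinstance`.  Each class also provides an ``is_compatible``
classmethod that reports whether an existing stock pool already
matches the supported configuration.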
"""

from torch import nn


class LatticeMaxPool2d(nn.MaxPool2d):
    """``MaxPool2d`` snapped to Lattice's supported set.

    Lattice hardware supports only a single max-pool configuration: a
    ``2×2`` kernel with stride 2 and zero padding.  The constructor takes
    an arbitrary source :class:`~torch.nn.MaxPool2d` and emits that
    canonical configuration, preserving only ``dilation`` and
    ``ceil_mode``.
    """

    #: The single permitted kernel size.
    KERNEL_SIZE: int = 2
    #: The single permitted stride.
    STRIDE: int = 2
    #: The single permitted padding.
    PADDING: int = 0

    @staticmethod
    def _as_square(value: object) -> object:
        """Collapse a 2-D pooling hyper-parameter to one value when possible.

        PyTorch stores ``kernel_size``, ``stride`` and ``padding`` exactly as
        they were passed to the constructor, so they may be an ``int``, a
        ``tuple`` or a ``list``.  ``torch.nn.functional.max_pool2d`` accepts
        a sequence of one element (applied to both spatial dimensions) or
        two elements ``(H, W)``.

        :param value: Raw attribute value read from a pool.
        :returns: ``value`` unchanged when it is not a tuple/list; the shared
            entry when a 1- or 2-element sequence has all-equal entries;
            ``None`` otherwise (non-square or malformed).
        """
        if not isinstance(value, (tuple, list)):
            return value
        if len(value) in (1, 2) and all(v == value[0] for v in value):
            return value[0]
        # ``None`` never equals a supported value, so the check fails.
        return None

    @classmethod
    def is_compatible(cls, pool: nn.MaxPool2d) -> bool:
        """Return ``True`` when `pool` already matches Lattice's supported set.

        :param pool: Pool to check.
        :returns: ``True`` when kernel size, stride, and padding of `pool`
            equal the single supported configuration; ``False`` otherwise.
            ``dilation`` and ``ceil_mode`` are not inspected, mirroring the
            constructor, which carries them over unchanged.
        """
        return (
            cls._as_square(pool.kernel_size) == cls.KERNEL_SIZE
            and cls._as_square(pool.stride) == cls.STRIDE
            and cls._as_square(pool.padding) == cls.PADDING
        )

    def __init__(self, pool: nn.MaxPool2d) -> None:
        """Create a ``LatticeMaxPool2d`` from an arbitrary ``MaxPool2d``.

        :param pool: Source pool node.  Its ``dilation`` and ``ceil_mode``
            are preserved; kernel size, stride, and padding are replaced
            with the hardware-fixed values.
        """
        super().__init__(
            kernel_size=self.KERNEL_SIZE,
            stride=self.STRIDE,
            padding=self.PADDING,
            dilation=pool.dilation,
            ceil_mode=pool.ceil_mode,
        )
class LatticeAdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Canonical Lattice global average pool.

    The only supported form of average pooling for Lattice hardware is
    :class:`~torch.nn.AdaptiveAvgPool2d` with ``output_size == (1, 1)``.
    The constructor takes no arguments — every instance has the canonical
    configuration declared by :attr:`OUTPUT_SIZE`.
    """

    #: The single permitted output size.
    OUTPUT_SIZE: tuple[int, int] = (1, 1)

    @classmethod
    def is_compatible(cls, pool: nn.AdaptiveAvgPool2d) -> bool:
        """Return ``True`` when `pool` already matches Lattice's supported set.

        :param pool: Pool to check.
        :returns: ``True`` when the output size of `pool` equals ``(1, 1)``;
            ``False`` otherwise.
        """
        out = pool.output_size
        # PyTorch keeps ``output_size`` exactly as passed to the
        # constructor: a tuple or list spells out (H, W), while a scalar
        # applies to both dimensions.  Lists must be converted rather than
        # wrapped, otherwise ``[1, 1]`` would never compare equal.
        if isinstance(out, (tuple, list)):
            out = tuple(out)
        else:
            out = (out, out)
        return out == cls.OUTPUT_SIZE

    def __init__(self) -> None:
        """Create a canonical ``LatticeAdaptiveAvgPool2d``."""
        super().__init__(self.OUTPUT_SIZE)