Source code for embedl_deploy._internal.tensorrt.modules.pool

# Copyright (C) 2026 Embedl AB

"""Fused ``nn.Module`` replacements for pooling-based patterns."""

import torch
from torch import nn

from embedl_deploy._internal.core.modules import FusedModule


class FusedAdaptiveAvgPool2d(FusedModule):
    """Fused wrapper for ``AdaptiveAvgPool2d``.

    Holds a single ``nn.AdaptiveAvgPool2d`` instance as ``self.pool`` and
    delegates the forward pass to it without modification.
    """

    # Input indices listed for quantization. This set is empty, so no input
    # index is listed. How this attribute is consumed lives outside this
    # block (presumably in ``FusedModule`` or the quantization pass; confirm).
    # NOTE(review): this is a class-level mutable set shared by all
    # instances; it must not be mutated in place.
    inputs_to_quantize: set[int] = set()

    def __init__(self, pool: nn.AdaptiveAvgPool2d) -> None:
        """Wrap ``pool``.

        Args:
            pool: The adaptive average pooling layer to delegate to.
        """
        super().__init__()
        self.pool = pool

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped adaptive average pooling on ``x``.

        Args:
            x: Input tensor, passed unchanged to ``self.pool``.

        Returns:
            Whatever ``self.pool(x)`` returns.
        """
        pooled = self.pool(x)
        return pooled

    def __repr__(self) -> str:
        # Compact one-line repr that shows only the wrapped layer's
        # ``output_size``, not the default nn.Module child listing.
        size = self.pool.output_size
        return f"FusedAdaptiveAvgPool2d(output_size={size})"