Source code for embedl_deploy._internal.tensorrt.modules.pool
# Copyright (C) 2026 Embedl AB
"""Fused ``nn.Module`` replacements for pooling-based patterns."""
import torch
from torch import nn
from embedl_deploy._internal.core.modules import FusedModule
[docs]
class FusedAdaptiveAvgPool2d(FusedModule):
    """Fused wrapper for ``AdaptiveAvgPool2d``.

    Wraps an existing :class:`torch.nn.AdaptiveAvgPool2d` so it can be
    treated as a :class:`FusedModule`. The forward pass delegates directly
    to the wrapped pooling module; the computation itself is unchanged.
    """

    # Presumably the indices of ``forward`` inputs that a quantization pass
    # should quantize (TODO confirm against ``FusedModule``); empty here, so
    # no input of this module is selected.
    # NOTE(review): this is a class-level mutable set shared by every
    # instance — never mutate it in place; assign a new set instead.
    inputs_to_quantize: set[int] = set()

    def __init__(self, pool: nn.AdaptiveAvgPool2d) -> None:
        """Store the pooling module to wrap.

        Args:
            pool: Adaptive average pooling module that ``forward`` delegates
                to. It is registered as the submodule ``pool``.
        """
        super().__init__()
        self.pool = pool

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply adaptive average pooling.

        Args:
            x: Input tensor, passed unchanged to the wrapped pooling module.

        Returns:
            The result of ``self.pool(x)``.
        """
        return self.pool(x)

    def __repr__(self) -> str:
        # Use the runtime class name (not a hard-coded label) so subclasses
        # report their own name; identical output for this class itself.
        return f"{type(self).__name__}(output_size={self.pool.output_size})"