import torch.nn as nn
# --- OPÇÃO 1: MaxPool Tradicional (Saída depende da entrada) ---
# Se a entrada for 10x10, com stride 1, a saída será 9x9.
pool_fixed = nn.MaxPool2d(kernel_size=2, stride=1)
# --- OPÇÃO 2: Average Pooling (Média dos valores) ---
pool_avg = nn.AvgPool2d(kernel_size=2, stride=2)
# --- OPÇÃO 3: Adaptive Pooling (Saída GARANTIDA em 7x7) ---
# Não importa se a entrada é 224x224 ou 32x32, o resultado será 7x7.
pool_adaptive = nn.AdaptiveMaxPool2d(output_size=(7, 7))
# Exemplo de uso no seu código:
# pool = nn.AdaptiveMaxPool2d(output_size=(7, 7))