我想把 PyTorch 中的一个 PyramidPooling 类改写为 TensorFlow 2.6 版本，但不知道其中的 nn.ModuleList 和 nn.Sequential 应该如何改写。另外，在类中定义的方法（class 内的 def）和普通函数（def）中使用 conv2d，有什么区别？
import torch
import torch.nn.functional as F
from torch import nn
class PyramidPooling(nn.Module):
    """Pyramid pooling block: multi-scale pooled context fused with the input.

    For every scale ``s`` in ``scales``, one stage does the following:

    - average-pools the input with an ``s x s`` window (the stride defaults
      to the kernel size);
    - projects the result to ``ct_channels`` with a bias-free 1x1
      convolution;
    - applies LeakyReLU(0.2).

    Each stage output is upsampled back to the input's spatial size with
    nearest-neighbour interpolation. All stage outputs are then concatenated
    with the original input along dim 1. Finally, a 1x1 "bottleneck"
    convolution followed by LeakyReLU(0.2) maps the result to
    ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        scales: Pooling window sizes, one stage per entry. Any iterable is
            accepted; it is materialised once so its length can be reused.
        ct_channels: Number of channels produced by each pyramid stage.

    The input is read as ``(N, C, H, W)``: dims 2 and 3 are taken as H and W.
    H and W must be at least ``max(scales)``. Otherwise the pooled output
    size is zero and PyTorch raises an error.
    """

    def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1):
        super().__init__()
        # Materialise once: ``scales`` is iterated to build the stages and its
        # length is needed again for the bottleneck (a generator would be
        # exhausted by then).
        scales = tuple(scales)
        # nn.ModuleList (not a plain Python list) registers each stage as a
        # submodule, so its parameters show up in .parameters(), .to(),
        # state_dict(), etc.
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, scale, ct_channels) for scale in scales]
        )
        # Input channels = original features + one ct_channels slice per stage.
        self.bottleneck = nn.Conv2d(
            in_channels + len(scales) * ct_channels,
            out_channels,
            kernel_size=1,
            stride=1,
        )
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _make_stage(self, in_channels, scale, ct_channels):
        """Build one pyramid stage: AvgPool(scale) -> 1x1 conv (no bias) -> LeakyReLU(0.2)."""
        prior = nn.AvgPool2d(kernel_size=(scale, scale))
        conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False)
        relu = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(prior, conv, relu)

    def forward(self, feats):
        """Fuse multi-scale pooled context with ``feats``.

        Args:
            feats: Feature map whose dims 2 and 3 are the spatial H and W.

        Returns:
            Tensor with ``out_channels`` channels and the same H and W as
            ``feats``.
        """
        h, w = feats.size(2), feats.size(3)
        upsampled = [
            F.interpolate(input=stage(feats), size=(h, w), mode="nearest")
            for stage in self.stages
        ]
        # Stage outputs first, then the untouched input: this is the channel
        # layout the bottleneck's in_channels was sized for.
        priors = torch.cat(upsampled + [feats], dim=1)
        return self.relu(self.bottleneck(priors))
相似问题