医学图像分割入门必看:深入理解U-Net网络结构,并基于Pytorch实现网络,每一行都有详细注释
动动您的小手指,关注一下吧!
U-Net是一个被广泛应用于医学图像分割的神经网络(这一点可以查看我之前我分享的综述文章:U-Net在医学图像分割中的成功)。U-Net的结构虽然很简单,但是它在医学图像分割领域的效果确实极好的,分析其原因在于:
(1)关键的跳跃连接:在U-Net中每一次Down Sample都连接跳跃连接结构与对应的上采样进行级联,这种不同尺度的特征融合对上采样恢复像素大有帮助。更详细的说,就是高层(浅层)下采样倍数小,特征图具备更加细致的图特征,底层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断,当深层和浅层的特征进行融合时,分割效果往往会非常好。
(2)医学图像的小样本特性:众所周知,医学图像的样本是十分稀少的。医学影像数据因其专业性和隐私性,相较于自然图像数据要难获取的多,所以,一般一个项目能用到的数据不过数百例,小样本是其典型特征。对于深度学习而言,小样本不能用大模型(杀鸡焉用宰牛刀),容易造成过拟合,使得网络性能下降。所以,网络结构复杂和参数量大的模型并不适合于医学影像。而原始U-Net的参数量为28M,这是非常轻量级的网络,即使数据量不够用,利用一些数据增强手段,一般都会有一个很好的适配性。
因此,基于上述两个原因,U-Net被广泛的应用在医学图像分割中,并且衍生出了许许多多的变体,可以查看我之前分享的文章(还未看过的可以回过头去看看,你一定会有很大的收获,别忘了点赞+关注哦)。
图1 U-Net结构
02 代码
以下是U-Net代码是实现,安装好torch之后,可以直接调试运行,无需其他修改,方便大家学习,如有其他不明白之处,欢迎大家留言讨论,相互学习。大家可以利用断点调试的方式,逐步执行,查看每一行的输入和输出。
import torchimport torchvision.transforms.functionalfrom torch import nnclass DoubleConvolution(nn.Module): def __init__(self,in_channels, out_channels): super().__init__() # 第一,3x3卷积 self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) # 卷积过后激活,标准步骤 self.act1 = nn.ReLU() # 第二,3x3卷积 self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.act2 = nn.ReLU() def forward(self, x): x = self.first(x) x = self.act1(x) x = self.second(x) x = self.act2(x) return xclass DownSample(nn.Module): """收缩路径中的每一步都使用2×2最大池化层对特征图进行下采样。""" def __init__(self): super().__init__() # 最大池化层 self.pool = nn.MaxPool2d(2) def forward(self, x): return self.pool(x)class UpSample(nn.Module): """扩展路径中的每一步都用2×2上卷积对特征映射进行上采样。""" def __init__(self, in_channels, out_channels): super().__init__() # ConvTranspose2d函数:该函数是用来进行转置卷积的,它主要做了这几件事: # 首先,对输入的feature map进行padding操作,得到新的feature map; # 然后,随机初始化一定尺寸的卷积核; # 最后,用随机初始化的一定尺寸的卷积核在新的feature map上进行卷积操作。 # 当s=1时,对于原feature map不进行插值操作,只进行padding操作; # 当s>1是还要进行插值操作,也就是这里采用的上采样操作 self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) def forward(self, x): print(x.shape) # [2, 1024, 16, 16] x = self.up(x) print(x.shape) # [2, 512, 32, 32] return xclass CropAndConcat(nn.Module): """ 裁剪和连接特征映射 在扩张路径的每一步,收缩路径的相应特征映射与当前特征映射相连接。 x:扩展路径中的当前特征映射 contracting_x:收缩路径对应的特征映射 """ def forward(self, x, contracting_x): # 将收缩路径上的特征映射裁剪为当前特征映射的大小 contracting_x = torchvision.transforms.functional.center_crop\ (contracting_x, [x.shape[2], x.shape[3]]) # 连接特征映射 x = torch.cat([x, contracting_x], dim=1) return xclass UNet(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in [(in_channels, 64), (64, 128), (128, 256), (256, 512)]]) self.down_sample = nn.ModuleList([DownSample() for _ in range(4)]) self.middle_conv = DoubleConvolution(512, 1024) self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in [(1024, 512), (512, 256), (256, 128), (128, 64)]]) self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in [(1024, 512), (512, 256), (256, 128), (128, 64)]]) self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)]) self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1) def forward(self, x): pass_through = [] for i in range(len(self.down_conv)): x = self.down_conv[i](x) pass_through.append(x) x = self.down_sample[i](x) x = self.middle_conv(x) for i in range(len(self.up_conv)): x = self.up_sample[i](x) # pop()为取出pass_through中的元素 x = self.concat[i](x, pass_through.pop()) x = self.up_conv[i](x) x = self.final_conv(x) return xif __name__ == '__main__': # 定义一个batch size为2,通道数维3,长款为256的图片 x = torch.randn(2, 3, 256, 256) # 初始化模型 net = UNet(3, 64) # 开始执行 result = net(x) # 打印输出 print(result.shape)
03 图片(防止手机端代码加载不出来)
领取专属 10元无门槛券
私享最新 技术干货