首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

医学图像分割入门必看:深入理解U-Net网络结构,并基于Pytorch实现网络,每一行都有详细注释

医学图像分割入门必看:深入理解U-Net网络结构,并基于Pytorch实现网络,每一行都有详细注释

动动您的小手指,关注一下吧!

U-Net是一个被广泛应用于医学图像分割的神经网络(这一点可以查看我之前我分享的综述文章:U-Net在医学图像分割中的成功)。U-Net的结构虽然很简单,但是它在医学图像分割领域的效果确实极好的,分析其原因在于:

(1)关键的跳跃连接:在U-Net中每一次Down Sample都连接跳跃连接结构与对应的上采样进行级联,这种不同尺度的特征融合对上采样恢复像素大有帮助。更详细的说,就是高层(浅层)下采样倍数小,特征图具备更加细致的图特征,底层(深层)下采样倍数大,信息经过大量浓缩,空间损失大,但有助于目标区域(分类)判断,当深层和浅层的特征进行融合时,分割效果往往会非常好。

(2)医学图像的小样本特性:众所周知,医学图像的样本是十分稀少的。医学影像数据因其专业性和隐私性,相较于自然图像数据要难获取的多,所以,一般一个项目能用到的数据不过数百例,小样本是其典型特征。对于深度学习而言,小样本不能用大模型(杀鸡焉用宰牛刀),容易造成过拟合,使得网络性能下降。所以,网络结构复杂和参数量大的模型并不适合于医学影像。而原始U-Net的参数量为28M,这是非常轻量级的网络,即使数据量不够用,利用一些数据增强手段,一般都会有一个很好的适配性。

因此,基于上述两个原因,U-Net被广泛的应用在医学图像分割中,并且衍生出了许许多多的变体,可以查看我之前分享的文章(还未看过的可以回过头去看看,你一定会有很大的收获,别忘了点赞+关注哦)。

图1 U-Net结构

02 代码

以下是U-Net代码是实现,安装好torch之后,可以直接调试运行,无需其他修改,方便大家学习,如有其他不明白之处,欢迎大家留言讨论,相互学习。大家可以利用断点调试的方式,逐步执行,查看每一行的输入和输出。

import torchimport torchvision.transforms.functionalfrom torch import nnclass DoubleConvolution(nn.Module):    def __init__(self,in_channels, out_channels):        super().__init__()        # 第一,3x3卷积        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)        # 卷积过后激活,标准步骤        self.act1 = nn.ReLU()        # 第二,3x3卷积        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)        self.act2 = nn.ReLU()    def forward(self, x):        x = self.first(x)        x = self.act1(x)        x = self.second(x)        x = self.act2(x)        return xclass DownSample(nn.Module):    """收缩路径中的每一步都使用2×2最大池化层对特征图进行下采样。"""    def __init__(self):        super().__init__()        # 最大池化层        self.pool = nn.MaxPool2d(2)    def forward(self, x):        return self.pool(x)class UpSample(nn.Module):    """扩展路径中的每一步都用2×2上卷积对特征映射进行上采样。"""    def __init__(self, in_channels, out_channels):        super().__init__()        # ConvTranspose2d函数:该函数是用来进行转置卷积的,它主要做了这几件事:        # 首先,对输入的feature map进行padding操作,得到新的feature map;        # 然后,随机初始化一定尺寸的卷积核;        # 最后,用随机初始化的一定尺寸的卷积核在新的feature map上进行卷积操作。        # 当s=1时,对于原feature map不进行插值操作,只进行padding操作;        # 当s>1是还要进行插值操作,也就是这里采用的上采样操作        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)    def forward(self, x):        print(x.shape)  # [2, 1024, 16, 16]        x = self.up(x)        print(x.shape)  # [2, 512, 32, 32]        return xclass CropAndConcat(nn.Module):    """    裁剪和连接特征映射    在扩张路径的每一步,收缩路径的相应特征映射与当前特征映射相连接。   x:扩展路径中的当前特征映射   contracting_x:收缩路径对应的特征映射    """    def forward(self, x, contracting_x):        # 将收缩路径上的特征映射裁剪为当前特征映射的大小        contracting_x = torchvision.transforms.functional.center_crop\            (contracting_x, [x.shape[2], x.shape[3]])        # 连接特征映射        x = torch.cat([x, contracting_x], dim=1)        return xclass UNet(nn.Module):    def __init__(self, in_channels, out_channels):        super().__init__()        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in                                        [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])        self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])        self.middle_conv = DoubleConvolution(512, 1024)        self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]])        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in                                      [(1024, 512), (512, 256), (256, 128), (128, 64)]])        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)    def forward(self, x):        pass_through = []        for i in range(len(self.down_conv)):            x = self.down_conv[i](x)            pass_through.append(x)            x = self.down_sample[i](x)        x = self.middle_conv(x)        for i in range(len(self.up_conv)):            x = self.up_sample[i](x)            # pop()为取出pass_through中的元素            x = self.concat[i](x, pass_through.pop())            x = self.up_conv[i](x)        x = self.final_conv(x)        return xif __name__ == '__main__':    # 定义一个batch size为2,通道数维3,长款为256的图片    x = torch.randn(2, 3, 256, 256)    # 初始化模型    net = UNet(3, 64)    # 开始执行    result = net(x)    # 打印输出    print(result.shape)

03 图片(防止手机端代码加载不出来)

  • 发表于:
  • 原文链接https://page.om.qq.com/page/ON2OrNiwXRRr_ofq1_2gOazw0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券