前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >u-net笔记

u-net笔记

作者头像
Sarlren
发布2022-10-28 11:32:26
3060
发布2022-10-28 11:32:26
举报
文章被收录于专栏:Sarlren的笔记

总览

通篇看下来,u-net的特点有以下几个:

  • 不使用padding,不论是conv还是pooling。这可能减少了padding引入的0对数据的污染。
  • 在预测图像边缘的时候,使用镜像而不是padding。
  • 深层特征和浅层特征结合。
  • 原文的loss加强了对边界的检测。当然,在很多复现中都没有使用原文的loss,因为这需要手工标注w_c(x)。

架构

u型的架构。类似Encoder-decoder结构。


踩坑记录

  • unet原文中使用same卷积,这就是为什么结果出来的时候size比原图像尺寸小184px。但是原文并没有提到这个小了的尺寸是怎么去掉的,也就是怎么用最终的结果进行pixel-wise的预测。我到现在也不知道这一点是怎么运行的。而且能找到的实现都是使用valid卷积,我也被迫这么干了。
  • unet最后的1*1 conv是有可能得到大于1的值的。在我训练使用的数据集中,数据都是0或者1的。我一开始没有留意,然后loss使用了MSELoss,结果就是它收敛了,但是结果并不是想要的。为了解决这个问题,我在1*1conv之后添加了一个sigmoid层以输出概率。训练loss改为了BCELoss。效果就可以如愿了。
  • 在data_loader中使用torchvision.transform的时候使用Compose,里面涉及到随机变换的时候,对data和label进行变换竟然是不一样的。这迫使我使用笨方法的if去进行变换。
  • 最后的输出还要把sigmoid输出的概率使用torch.where转化为只有0和1的。
  • torchvision.transform在使用某些插值选项时会提示warning。我一开始想办法消除了这些warning,但发现label的白色区域在使用warning提示的插值方法后会加上一层黑边。因此最好的办法是忽视这个warning。
  • pratheepan这个数据集竟然data格式不是一致的,还要留意最后的名字。

跑出来的结果示例

data:

label:

predict:

容易发现,好像predict会对人脸上的阴影造成误判。而且泛化能力还有待加强。


我的代码

代码语言:javascript
复制

# 网络结构
import torch
from torch import nn
import torch.nn.functional as F
from torchsummary import summary
from torchvision.transforms.functional import crop

def conv_block(in_channel, out_channel):
    # return nn.Sequential(
    #     nn.Conv2d(in_channel, out_channel, kernel_size=3, bias=False),
    #     nn.BatchNorm2d(out_channel),
    #     nn.ReLU(),
    #     nn.Conv2d(out_channel, out_channel, kernel_size=3, bias=False),
    #     nn.BatchNorm2d(out_channel),
    #     nn.ReLU()
    # )
    
    # 为了妥协一个尺寸 被迫加个padding
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU()
    )

def up_conv_block(in_channel, out_channel):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
        nn.BatchNorm2d(out_channel),
        nn.ReLU()
    )

class UNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = conv_block(3, 64)
        self.conv2 = conv_block(64, 128)
        self.conv3 = conv_block(128, 256)
        self.conv4 = conv_block(256, 512)
        self.conv5 = conv_block(512, 1024)
        self.dropout = nn.Dropout()
        
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.upconv1 = up_conv_block(1024, 512)
        self.conv6 = conv_block(1024, 512)

        self.upconv2 = up_conv_block(512, 256)
        self.conv7 = conv_block(512, 256)

        self.upconv3 = up_conv_block(256, 128)
        self.conv8 = conv_block(256, 128)

        self.upconv4 = up_conv_block(128, 64)
        self.conv9 = conv_block(128, 64)

        self.conv_predict = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, X):
        X1 = self.conv1(X)
        X2 = self.conv2(self.max_pool(X1))
        X3 = self.conv3(self.max_pool(X2))
        X4 = self.conv4(self.max_pool(X3))
        X5 = self.upconv1(self.dropout(self.conv5(self.max_pool(X4))))
        X4_crop = X4.clone().detach()[:, :, 
                            (X4.shape[2] - X5.shape[2]) // 2: X5.shape[2] + (X4.shape[2] - X5.shape[2]) // 2, 
                            (X4.shape[3] - X5.shape[3]) // 2: X5.shape[3] + (X4.shape[3] - X5.shape[3]) // 2 ]
        X6 = self.upconv2(self.conv6(torch.cat([X4_crop, X5], dim=1)))
        X3_crop = X3.clone().detach()[:, :, 
                            (X3.shape[2] - X6.shape[2]) // 2: X6.shape[2] + (X3.shape[2] - X6.shape[2]) // 2, 
                            (X3.shape[3] - X6.shape[3]) // 2: X6.shape[3] + (X3.shape[3] - X6.shape[3]) // 2 ]
        X7 = self.upconv3(self.conv7(torch.cat([X3_crop, X6], dim=1)))
        X2_crop = X2.clone().detach()[:, :, 
                            (X2.shape[2] - X7.shape[2]) // 2: X7.shape[2] + (X2.shape[2] - X7.shape[2]) // 2, 
                            (X2.shape[3] - X7.shape[3]) // 2: X7.shape[3] + (X2.shape[3] - X7.shape[3]) // 2 ]
        X8 = self.upconv4(self.conv8(torch.cat([X2_crop, X7], dim=1)))
        X1_crop = X1.clone().detach()[:, :, 
                            (X1.shape[2] - X8.shape[2]) // 2: X8.shape[2] + (X1.shape[2] - X8.shape[2]) // 2, 
                            (X1.shape[3] - X8.shape[3]) // 2: X8.shape[3] + (X1.shape[3] - X8.shape[3]) // 2 ]
        X9 = self.conv9(torch.cat([X1_crop, X8], dim=1))
        return torch.sigmoid(self.conv_predict(X9))


X = torch.randn([1, 3, 304, 304])
net = UNet()
# summary(net, (3, 304, 304))

# data_loader
from torchvision.transforms.transforms import ToTensor
from PIL import Image
import os
from torchvision import transforms as T
import random
import matplotlib.pyplot as plt 
class FaceDataSet(torch.utils.data.Dataset):
    def __init__(self, path='/content/drive/MyDrive/UNet/Face_Dataset'):
        self.path = path

    def __getitem__(self, index):
        directory_list = os.listdir(self.path + '/Pratheepan_Dataset/FacePhoto')
        pic_name = directory_list[index]
        image = Image.open(self.path + '/Pratheepan_Dataset/FacePhoto/' + pic_name)
        pic_name = pic_name[:pic_name.rfind('.') + 1] + 'png'
        GT = Image.open(self.path + '/Ground_Truth/GroundT_FacePhoto/' + pic_name)
        transform = T.Compose([
            T.ToTensor(),
            # T.RandomHorizontalFlip(),
            # T.RandomAffine(0, scale=(0.9, 1.1))
        ])
        image = transform(image)
        GT = transform(GT)

        if random.random() > 0.5:
            image = T.functional.hflip(image)
            GT = T.functional.hflip(GT)

        if random.random() > 0.5:
            scale = random.uniform(0.7, 1.3)
            transform = T.Compose([T.Resize((int(scale * image.shape[1]), int(scale * image.shape[2])))])
            image = transform(image)
            GT = transform(GT)

        shape = (16 + (image.shape[1] // 16) * 16, 16 + (image.shape[2] // 16) * 16)
        transform = T.Compose([T.Resize(shape)])
        image = transform(image)
        GT = transform(GT)
        return image, GT[:1, :, :]

    def __len__(self):
        return len(os.listdir(self.path + '/Pratheepan_Dataset/FacePhoto'))

data_loader = torch.utils.data.DataLoader(dataset=FaceDataSet(), batch_size=1, shuffle=True)
loss_list = []

# train
def train(dataLoader, trainModel):
    net = trainModel
    optimizer = torch.optim.Adam(net.parameters(), lr=0.002, betas=(0.5, 0.999)) 
    criterion = nn.BCELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1, verbose=True)  # 打印信息
    
    def init_weights(m):  # 初始化参数,极其重要,且极大加快了训练速度
        if type(m) == nn.Conv2d:
            nn.init.kaiming_uniform_(m.weight)  # kaiming初始化 过于厉害
        elif type(m) == nn.BatchNorm2d:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)
    net.apply(init_weights)
    
    for epoch in range(10):  # 跑10个epoch(一个epoch就是对样本集所有样本的遍历)
        runningLoss = 0.0  # 初始化loss
        for i, data in enumerate(dataLoader, 0):  # 枚举loader,写法固定为index,data
            inputs, labels = data  # data中就是我们刚才定义的__getitem__的顺序
            optimizer.zero_grad()  # 初始化梯度,必须要有
            outputs = net(inputs)  # 把data中的样本放入net而不放入标签,得到outputs输出
             
            loss = criterion(outputs, labels)  # 根据outputs和原有的标签计算交叉熵
            loss.backward()  # 反向传播计算更新参数,必须要有
            optimizer.step()  # 更新参数
            runningLoss += float(loss.data)  # 把一个epoch中的loss更新
            scheduler.step(epoch + i / len(dataLoader))  # 更新lr
            print(f'now batch {i}, loss on batch: {loss}')
            
        print(f'epoch{epoch}:', runningLoss)
        loss_list.append(runningLoss)
        # if len(loss_list) > 0:
        #     if runningLoss < loss_list[-1]:
        #         torch.save(net.state_dict(), '/content/drive/MyDrive/UNet/unet.pth')
        # else:
        #     torch.save(net.state_dict(), '/content/drive/MyDrive/UNet/unet.pth')
        torch.save(net.state_dict(), '/content/drive/MyDrive/UNet/unet.pth')

        plt.plot(loss_list)
        plt.show()
    print('finish!')

# use

# train(data_loader, net)
net = UNet()
net.load_state_dict(torch.load('/content/drive/MyDrive/UNet/unet.pth'))

# torch.set_printoptions(profile='full')
data = data_loader.dataset[0]
display(T.Compose([T.ToPILImage()])(data[0]))
display(T.Compose([T.ToPILImage()])(data[1]))
output = net(data[0].unsqueeze(0))[0]
zero = torch.zeros_like(output)
one = torch.ones_like(output)
temp = torch.where(output >= 0.5, one, output)
processed_output = torch.where(temp < 0.5, zero, temp)
display(T.Compose([T.ToPILImage()])(processed_output))
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-04-28,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 总览
  • 架构
  • 踩坑记录
  • 跑出来的结果示例
  • 我的代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档