图像分割:3D Unet网络性能一定优于2D Unet吗,如果优于,为什么优于?

您所在的位置:网站首页 split-part函数 图像分割:3D Unet网络性能一定优于2D Unet吗,如果优于,为什么优于?

图像分割:3D Unet网络性能一定优于2D Unet吗,如果优于,为什么优于?

2023-04-13 17:26| 来源: 网络整理| 查看: 265

原文地址:U-Net: Convolutional Networks for Biomedical Image Segmentation

代码地址:Pytorch U-Net

U-Net的设计方法在图像分割领域非常流行。在U-Net发明之后,很多的研究者对其进行了改进,产生了U-Net++, U-Net3+, Attention U-Net等算法。这些算法针对特定的任务,有一定的提升效果。搭配一些数学思想,这些主干网络得以发挥巨大的作用,特别是在医学影像分割领域。除此之外,U-Net的Encoder,Decoder的思想,在CV的其它任务上也有广泛的应用。比如在图像上色的任务中,主干网络也是类似于U-Net的设计,中间采用对称连接(fusion)的方法。这篇文章是U-Net的学习笔记,同时对照分析一些代码(pytorch)。在本文末尾,我对2018 Atrial Segmentation Challenge 左心房分割的数据集做了一些U-Net的消融实验。

这篇笔记会分为几个部分

分析主干网络论文中图像增强的技巧损失函数的设计训练和测试的设计

主干网络U-Net的主干网络

U-Net的主干网络设计简洁。在encoder部分,使用四次下采样。每次采样之后,进行两次卷积和激活的操作。每次采样组成一个block(图中每一排的连续三个layer)。下采样至1024维的隐空间之后,开始上采样。Decoder部分和encoder的设计是对称的。为了得到encoder的信息,decoder的每一个block都有一半的信息来自encoder,并且采用fusion的方式叠加在一起(ResNet使用的是“加”操作)。Fusion的方式存储更多的信息,但是会带来更多的计算量,需要根据任务选择设计。最后得到通道数为2的输出。

首先定义卷积模块,就是每一块横着连续的三个蓝色方块。这个模式出现了很多次,所以定义为一个class。这里的mid_channels和out_channels的通道数是一样的。使用3*3卷积。每一块使用Conv2d,BN, Relu。(原文没有BN,因为BN也是2015年提出来的,它有助于效果提升,Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift)

class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.double_conv(x)

定义下采样模块。使用最大池化层和上文定义的二次卷积block。

class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x)

定义上采样模块。上采样采用2*2卷积核,卷积方式是‘bilinear’,然后使用前文定义的二次卷积block。在每一次上采样的时候需要融合对应下采样同尺度的信息。因此在forward函数里面先把图像使用pad函数像素对齐,然后concat起来。因此decoder同级的维度数是原来encoder的两倍。

class Up(nn.Module): """Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x)

输出维度回归。使用之前定义的二次卷积模块。

class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x)

主干网络整合。首先定义输入,输出维度在不同采样层的大小。在forward中,连续四次下采样,再连续四次的上采样,采用对称跳层连接。最后得到输出维度为2的图像。

class UNet(nn.Module): def __init__(self, n_channels, n_classes, bilinear=True): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) self.up1 = Up(1024, 512 // factor, bilinear) self.up2 = Up(512, 256 // factor, bilinear) self.up3 = Up(256, 128 // factor, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits

数据增强

这篇文章提出的U-Net网络可以适用于很少的数据。但是神经网络依旧对数据的位置,明暗等因素非常敏感。因此作者专门写了一个part来讲述数据增强。一般常用的数据增强方法比如旋转,平移。作者同时还使用了平滑变形(3*3网格随机替换,高斯分布,方差为10)。(代码太长了,就不放了)

除此之外,作者还使用了Drop-out layer在encoder的末尾。

损失函数

作者一共设计了两种损失函数。一个是交叉信息熵(这个很好理解,做两个mask之间的差异),一个是Energy function。

交叉信息熵

criterion = nn.CrossEntropyLoss() loss_part1 = criterion(masks_pred, true_masks)

Energy function

首先定义softmax函数

p_{k}(\mathbf{x})=\exp \left(a_{k}(\mathbf{x})\right) /\left(\sum_{k^{\prime}=1}^{K} \exp \left(a_{k^{\prime}}(\mathbf{x})\right)\right)

这里的 a_{k}(\mathbf{x})\ 表示在通道中激活的像素, K 是classes的数量。函数大小在0-1之间。

结合交叉信息熵,计算每个位置mask分类的偏差,并求和得到能量 E 。

E=\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \log \left(p_{\ell(\mathbf{x})}(\mathbf{x})\right)

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): # Average of Dice coefficient for all batches, or for a single mask assert input.size() == target.size() if input.dim() == 2 and reduce_batch_first: raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})') if input.dim() == 2 or reduce_batch_first: inter = torch.dot(input.reshape(-1), target.reshape(-1)) sets_sum = torch.sum(input) + torch.sum(target) if sets_sum.item() == 0: sets_sum = 2 * inter return (2 * inter + epsilon) / (sets_sum + epsilon) else: # compute and average metric for each batch element dice = 0 for i in range(input.shape[0]): dice += dice_coeff(input[i, ...], target[i, ...]) return dice / input.shape[0] def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6): # Average of Dice coefficient for all classes assert input.size() == target.size() dice = 0 for channel in range(input.shape[1]): dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon) return dice / input.shape[1] def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 assert input.size() == target.size() fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) loss_part2 = dice_loss(F.softmax(masks_pred, dim=1).float(), F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(), multiclass=True) loss = loss_part1 + loss_part2

采用动量梯度下降,定义0.99的初始动量,和学习率衰减。

optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

为了减少梯度下降的波动性,使用RMSProp

优化器:RMSProp

计算损失热力图。

w(\mathbf{x})=w_{c}(\mathbf{x})+w_{0} \cdot \exp \left(-\frac{\left(d_{1}(\mathbf{x})+d_{2}(\mathbf{x})\right)^{2}}{2 \sigma^{2}}\right)

a 原数据,b ground truth, c 分割mask,d 损失热力图

训练和测试

这篇文章使用的ISBI 2015数据集。训练集一共只有30张(512*512)。在warping error, rand error and pixel error这三个指标上比较和历史模型的优劣。

分割结果(IOU)在IBSI cell tracking challenge 2015

训练主函数(部分)

for epoch in range(epochs): net.train() epoch_loss = 0 with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar: for batch in train_loader: images = batch['image'] true_masks = batch['mask'] assert images.shape[1] == net.n_channels, \ f'Network has been defined with {net.n_channels} input channels, ' \ f'but loaded images have {images.shape[1]} channels. Please check that ' \ 'the images are loaded correctly.' images = images.to(device=device, dtype=torch.float32) true_masks = true_masks.to(device=device, dtype=torch.long) with torch.cuda.amp.autocast(enabled=amp): masks_pred = net(images) loss = criterion(masks_pred, true_masks) \ + dice_loss(F.softmax(masks_pred, dim=1).float(), F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(), multiclass=True) optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss).backward() grad_scaler.step(optimizer) grad_scaler.update() pbar.update(images.shape[0]) global_step += 1 epoch_loss += loss.item() experiment.log({ 'train loss': loss.item(), 'step': global_step, 'epoch': epoch }) pbar.set_postfix(**{'loss (batch)': loss.item()}) # Evaluation round division_step = (n_train // (10 * batch_size)) if division_step > 0: if global_step % division_step == 0: histograms = {} for tag, value in net.named_parameters(): tag = tag.replace('/', '.') histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) val_score = evaluate(net, val_loader, device) scheduler.step(val_score) logging.info('Validation Dice score: {}'.format(val_score)) experiment.log({ 'learning rate': optimizer.param_groups[0]['lr'], 'validation Dice': val_score, 'images': wandb.Image(images[0].cpu()), 'masks': { 'true': wandb.Image(true_masks[0].float().cpu()), 'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()), }, 'step': global_step, 'epoch': epoch, **histograms })

测试主函数(部分)

def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5): net.eval() img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False)) img = img.unsqueeze(0) img = img.to(device=device, dtype=torch.float32) with torch.no_grad(): output = net(img) if net.n_classes > 1: probs = F.softmax(output, dim=1)[0] else: probs = torch.sigmoid(output)[0] tf = transforms.Compose([ transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor() ]) full_mask = tf(probs.cpu()).squeeze() if net.n_classes == 1: return (full_mask > out_threshold).numpy() else: return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()

这是一篇非常经典的医学影像分割论文。之后的医学影像分割论文也大多基于U-Net。尽管有很多改进和提升的地方,但是U-Net的拓扑结构是一直传承的。结合最新的数学(AI)思想,比如GAN, Domain Adaptation, VAE等等,U-Net可以把性能发挥地更好。

2018 Atrial Segmentation Challenge

2018年左心房分割挑战赛的数据集一共分为两类。提供100例MRI训练数据和54例MRI测试数据。

原数据集格式是nrrd(核磁共振图),把它读取成png切片,(读取的时候)通道为1。

MRI分割mask一个训练例子的所有切片

实际上,每一个例子是一个3D的核磁共振图。冠军方案也是两个级联的3D V-Net。我们在实验中,还是把它先做成2D的图(每个例子一共有88张切片),提供给U-Net训练,看看它和3D方案的差距。

数据增强的方式就是Crop然后放大。其他采用原来U-Net的设计。我训练的都是5个epoch。

Experiment MethodSettingParamatersDICE3D VNet2*3D V-Net38M0.923(train validation) 0.932(test)Multi-task learning(whether post/pre)U-Net + post/pre label18M+0.901(train validation)0.921(test)Naive U-NetOriginal17M0.883(train validation)Naive U-Net0.6*training sample17M0.882(train validation)Naive U-NetHalf Channel4M0.899(train validation)Naive U-NetCut the deepest layer4M0.870(train validation)Naive U-Netw/o skip connection25M0.854(train validation)

实验对比了 3D V-Net,多任务学习的两个挑战赛参赛者的算法 和 我自己做的U-Net的一些消融实验。很明显地发现,这个核磁共振影像的3D性质是一个非常重要的信息需要网络去学习。如果把它切片成2D的,每一张图的信息就是独立的了,这样的效果也有所下降。第一名设计的级联VNet,DICE达到了0.923。多任务学习的设计方法,展现了心脏在射频消融前后,mask是有一定的分布差异的,这部分的信息是给予影像分割很好的指导。它和领域自适应的设计思路差不多,中间的特征提取层,加了一个网络的分枝,来分辨患者是否做过了射频消融,反向传播的时候,使用反转梯度的设计。该方法在挑战赛中的排名也不错,达到了0.901,仅次于大部分其他3D-VNet的方案。多任务学习的方法在这个领域也十分重要。

本篇文章是探索U-Net的,所以我对U-Net做了一些消融实验以展示它的性质。按照原来的参数和结构做出来的结果准确率为0.883。我将训练样本减去40%,网络效果竟然有所提升。有些怀疑数据量太大,且相似性高导致了一定的过拟合。看来该任务相对比较简单。然后我把所有的通道数减一半,网络效果竟然提升了1.7%个点(0.899),原来网络过拟合的问题确凿无疑了。并且网络参数只有4M,在我的2080Ti上训练只需要6分钟。在第四个实验中,我把最后一层删掉了,网络参数也减少了很多。但是对比原来的拓扑结构,少了一个层次的信息,网络精度下降到了0.870。最后一个实验中,我删掉了跳层连接,网络精度下降地更多,只有0.854。并且我把原来跳层的一半信息用conv2d填补上了导致了多了8M的参数量。跳层连接的设计非常重要。

以上就是对U-Net网络的分析。(求赞赞)



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3