Image Segmentation: Does a 3D U-Net Always Outperform a 2D U-Net? If So, Why?



2023-04-13 17:26 | Source: compiled from the web

Original paper: U-Net: Convolutional Networks for Biomedical Image Segmentation

Code: Pytorch U-Net

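The U-Net design is very popular in image segmentation. Since U-Net was introduced, many researchers have built on it, producing U-Net++, UNet 3+, Attention U-Net and other variants. Each of these brings some improvement on specific tasks. Combined with further mathematical ideas, these backbones can be very effective, especially in medical image segmentation. U-Net's encoder-decoder idea is also widely used in other computer vision tasks: in image colorization, for example, the backbone is likewise U-Net-like, with symmetric (fusion) connections between the two halves. This article is my study notes on U-Net, with an analysis of the corresponding PyTorch code. At the end, I run several U-Net ablation experiments on the left-atrium dataset of the 2018 Atrial Segmentation Challenge.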





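First, define the convolution block: the three consecutive blue boxes in each row of the U-Net architecture figure. This pattern appears many times, so it is defined as a class. By default, mid_channels is set equal to out_channels. Each block uses 3×3 convolutions, each followed by BN and ReLU (Conv2d, BN, ReLU). (The original paper has no BN: Batch Normalization was also proposed in 2015, and it helps improve results; see Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.)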

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            # bias=False: the following BatchNorm already provides a learnable shift
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
```
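Because every 3×3 convolution in DoubleConv uses padding=1, the block keeps the spatial size unchanged. The original paper instead used unpadded ("valid") convolutions, which trim one pixel per side per convolution; this is why its architecture figure maps a 572×572 input tile to a 388×388 output. The plain-Python sketch below tracks that size arithmetic, assuming four 2×2 max-poolings and four 2×2 up-convolutions as in the paper; valid_unet_output_size is a hypothetical helper, not part of the referenced repository.

```python
def valid_unet_output_size(n, depth=4):
    """Output side length of a U-Net built from unpadded 3x3 convolutions."""
    for _ in range(depth):
        n -= 4  # two unpadded 3x3 convs: each removes 1 pixel per side
        if n % 2:
            raise ValueError("2x2 max-pooling needs an even size at every level")
        n //= 2  # 2x2 max-pooling halves the size
    n -= 4  # double conv at the bottleneck
    for _ in range(depth):
        n = 2 * n - 4  # 2x2 up-convolution doubles, then two unpadded 3x3 convs
    return n


print(valid_unet_output_size(572))  # 388, matching the paper's figure
```

With padding=1, as in the PyTorch code, a 512×512 input instead comes out as 512×512, so no overlap-tile cropping is needed.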


```python
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
```


```python
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is NCHW, so dim 2 is height and dim 3 is width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
```
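The F.pad call in Up handles inputs whose size is not divisible by 16: MaxPool2d(2) floors odd sizes, so after 2× upsampling the decoder feature map can be one pixel smaller than its skip connection. The plain-Python sketch below reproduces the (left, right) padding that Up computes at each decoder level, assuming a square input and four down/up levels as in UNet; skip_pad_amounts is a hypothetical helper for illustration.

```python
def skip_pad_amounts(input_size, depth=4):
    """(left, right) padding applied by each Up block, deepest level first."""
    sizes = [input_size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)  # MaxPool2d(2) floors odd sizes
    pads = []
    cur = sizes[-1]
    for skip in reversed(sizes[:-1]):
        up = cur * 2  # scale_factor=2 (or a stride-2 transposed conv)
        diff = skip - up
        pads.append((diff // 2, diff - diff // 2))  # same split as F.pad in Up
        cur = skip  # after padding, the map matches the skip connection
    return pads


print(skip_pad_amounts(100))  # [(0, 0), (0, 1), (0, 0), (0, 0)]
```

For a 100×100 input the encoder sizes are 100, 50, 25, 12, 6; upsampling 12 gives 24 while the skip map is 25, so one column (and one row) of zeros is added on the right (bottom).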


```python
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
```


```python
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```
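To relate this constructor to the parameter counts in the ablation table at the end, the sketch below counts trainable parameters by hand, mirroring the layer widths above: bias-free 3×3 convolutions, BatchNorm weight and bias, transposed convolutions with bias when bilinear=False, and a 1×1 output convolution with bias. unet_params and its base argument are hypothetical names; base=32 is my assumption of what "half channels" means. Treat it as a cross-check, not a replacement for sum(p.numel() for p in net.parameters()).

```python
def double_conv_params(cin, cout, cmid=None):
    """Two bias-free 3x3 convs, each followed by BatchNorm2d (weight + bias)."""
    cmid = cmid or cout
    return cin * cmid * 9 + 2 * cmid + cmid * cout * 9 + 2 * cout


def unet_params(n_channels, n_classes, base=64, bilinear=True):
    f = 2 if bilinear else 1
    b = base
    total = double_conv_params(n_channels, b)        # inc
    total += double_conv_params(b, 2 * b)            # down1
    total += double_conv_params(2 * b, 4 * b)        # down2
    total += double_conv_params(4 * b, 8 * b)        # down3
    total += double_conv_params(8 * b, 16 * b // f)  # down4
    for cin, cout in [(16 * b, 8 * b // f), (8 * b, 4 * b // f),
                      (4 * b, 2 * b // f), (2 * b, b)]:  # up1 .. up4
        if bilinear:
            total += double_conv_params(cin, cout, cin // 2)
        else:
            total += cin * (cin // 2) * 4 + cin // 2  # ConvTranspose2d(k=2), with bias
            total += double_conv_params(cin, cout)
    total += b * n_classes + n_classes  # OutConv: 1x1 conv with bias
    return total


print(unet_params(3, 2))                  # 17263042
print(unet_params(3, 2, bilinear=False))  # 31037698
print(unet_params(3, 2, base=32))         # 4318434
```

The bilinear model comes to about 17.3M parameters (about 31.0M with transposed convolutions), and halving every width gives about 4.3M, which is consistent with the 17M and 4M entries in the table.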



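In addition, the authors place drop-out layers at the end of the encoder (the contracting path). The PyTorch code above does not include them.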


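In the paper, the loss is an energy function: a pixel-wise softmax over the final feature map combined with a cross-entropy loss, weighted per pixel. The PyTorch code used here combines two terms instead: cross-entropy (easy to understand: it measures the difference between the predicted and ground-truth masks) and a Dice loss.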


```python
criterion = nn.CrossEntropyLoss()
loss_part1 = criterion(masks_pred, true_masks)
```

Energy function


p_{k}(\mathbf{x})=\exp \left(a_{k}(\mathbf{x})\right) /\left(\sum_{k^{\prime}=1}^{K} \exp \left(a_{k^{\prime}}(\mathbf{x})\right)\right)

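Here a_{k}(\mathbf{x}) denotes the activation in feature channel k at pixel position \mathbf{x}, and K is the number of classes. The value of p_{k}(\mathbf{x}) lies between 0 and 1.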

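Combined with cross-entropy, this penalizes at each position the deviation of p_{\ell(\mathbf{x})}(\mathbf{x}) from 1, where \ell(\mathbf{x}) is the true label of each pixel and w(\mathbf{x}) is a weight map. Summing over all positions gives the energy E: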

E=\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \log \left(p_{\ell(\mathbf{x})}(\mathbf{x})\right)
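A minimal NumPy sketch of the pixel-wise softmax and the weighted energy E, assuming activations of shape (K, H, W), integer labels \ell(\mathbf{x}) of shape (H, W) and a weight map w of shape (H, W); the function names are illustrative. As written, E is a weighted sum of log-probabilities and is therefore at most 0, so a loss to minimize would be -E, which with w(\mathbf{x}) = 1 reduces to the summed pixel-wise cross-entropy.

```python
import numpy as np


def pixelwise_softmax(a):
    """p_k(x) for activations a of shape (K, H, W): softmax over the class axis."""
    a = a - a.max(axis=0, keepdims=True)  # numerical stability; p_k is unchanged
    e = np.exp(a)
    return e / e.sum(axis=0, keepdims=True)


def energy(a, labels, w):
    """E = sum_x w(x) * log p_{l(x)}(x); labels and w have shape (H, W)."""
    p = pixelwise_softmax(a)
    # pick p_{l(x)}(x): the probability assigned to the true class at each pixel
    p_true = np.take_along_axis(p, labels[None, ...], axis=0)[0]
    return float(np.sum(w * np.log(p_true)))


rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4, 4))           # K = 3 classes on a 4x4 image
labels = rng.integers(0, 3, size=(4, 4))
loss = -energy(a, labels, np.ones((4, 4)))
```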

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            sets_sum = 2 * inter

        return (2 * inter + epsilon) / (sets_sum + epsilon)
    else:
        # compute and average metric for each batch element
        dice = 0
        for i in range(input.shape[0]):
            dice += dice_coeff(input[i, ...], target[i, ...])
        return dice / input.shape[0]


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all classes
    assert input.size() == target.size()
    dice = 0
    for channel in range(input.shape[1]):
        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)

    return dice / input.shape[1]


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    assert input.size() == target.size()
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)


# Part 2 of the loss: Dice between softmax probabilities and one-hot targets, both (N, K, H, W)
loss_part2 = dice_loss(F.softmax(masks_pred, dim=1).float(),
                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                       multiclass=True)
loss = loss_part1 + loss_part2
```
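The loss_part2 line one-hot encodes the (N, H, W) label map and permutes it to (N, K, H, W) so that it lines up channel by channel with the softmax output. The NumPy sketch below mirrors that path with hypothetical helper names: np.eye(k)[labels] plays the role of F.one_hot, and the Dice is computed per class over the whole batch (as reduce_batch_first=True does) and then averaged over classes.

```python
import numpy as np


def one_hot_nchw(labels, k):
    """(N, H, W) int labels -> (N, K, H, W), like F.one_hot(...).permute(0, 3, 1, 2)."""
    return np.eye(k)[labels].transpose(0, 3, 1, 2)


def multiclass_dice(probs, target, eps=1e-6):
    """Per-class Dice over the whole batch, averaged over classes."""
    scores = []
    for c in range(probs.shape[1]):
        p, t = probs[:, c], target[:, c]
        inter = (p * t).sum()
        scores.append((2 * inter + eps) / (p.sum() + t.sum() + eps))
    return float(np.mean(scores))


labels = np.array([[[0, 1], [1, 1]]])  # N = 1, a 2x2 label map
target = one_hot_nchw(labels, 2)       # shape (1, 2, 2, 2)
loss = 1 - multiclass_dice(target, target)  # perfect prediction -> loss close to 0
```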


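```python
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
# 'max' mode: the scheduler monitors the validation Dice score, where higher is better
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
```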




w(\mathbf{x})=w_{c}(\mathbf{x})+w_{0} \cdot \exp \left(-\frac{\left(d_{1}(\mathbf{x})+d_{2}(\mathbf{x})\right)^{2}}{2 \sigma^{2}}\right)
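In this weight map, w_{c} is a weight map that balances the class frequencies, d_{1} is the distance to the border of the nearest cell and d_{2} the distance to the border of the second nearest cell; the paper sets w_{0} = 10 and \sigma \approx 5 pixels. The second term is large in narrow background gaps between touching cells, which forces the network to learn the separating borders. The NumPy sketch below is a brute-force illustration for small images only, under assumptions the paper does not spell out: distances are measured to the nearest pixel of each labeled cell, w_{c} uses inverse class frequency, and the border term is added on background pixels only. unet_weight_map is a hypothetical name.

```python
import numpy as np


def unet_weight_map(instances, w0=10.0, sigma=5.0):
    """Brute-force w(x) = w_c(x) + w0 * exp(-(d1 + d2)^2 / (2 sigma^2)).

    instances: (H, W) int array, 0 = background, 1..N = individual cells.
    Memory grows with H * W * cell size, so this is for small examples only.
    """
    H, W = instances.shape
    fg = instances > 0
    n = fg.size
    # Assumed w_c: inverse class frequency ("balanced" weights) for fg / bg
    w_fg = n / (2.0 * max(fg.sum(), 1))
    w_bg = n / (2.0 * max((~fg).sum(), 1))
    wc = np.where(fg, w_fg, w_bg)

    cell_ids = [i for i in np.unique(instances) if i != 0]
    if len(cell_ids) < 2:
        return wc  # d2 requires at least two cells

    yy, xx = np.indices((H, W))
    coords = np.stack([yy.ravel(), xx.ravel()], axis=1).astype(float)  # (H*W, 2)
    dists = []
    for i in cell_ids:
        cell = coords[instances.ravel() == i]  # pixels of cell i, (n_i, 2)
        diff = coords[:, None, :] - cell[None, :, :]
        # Euclidean distance from every pixel to the closest pixel of cell i
        dists.append(np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1).reshape(H, W))
    dists = np.sort(np.stack(dists), axis=0)  # ascending over cells, per pixel
    d1, d2 = dists[0], dists[1]

    border = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    # Assumption: the border term is applied to background pixels only
    return wc + np.where(fg, 0.0, border)


inst = np.zeros((5, 9), dtype=int)
inst[:, :2] = 1   # cell 1: columns 0-1
inst[:, 7:] = 2   # cell 2: columns 7-8
w = unet_weight_map(inst)
```

In this toy example every background pixel in the 5-pixel gap has d_{1} + d_{2} = 6, so the whole gap receives the same extra weight 10 \cdot e^{-36/50}.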

(a) raw image, (b) ground truth, (c) generated segmentation mask, (d) pixel-wise loss weight map (heat map)


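The paper evaluates on ISBI challenge data. For the EM segmentation challenge (started at ISBI 2012), the training set contains only 30 images (512×512). The model is compared with earlier methods on three metrics: warping error, Rand error and pixel error.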

Segmentation results (IoU) on the ISBI cell tracking challenge 2015


```python
for epoch in range(epochs):
    net.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images = batch['image']
            true_masks = batch['mask']

            assert images.shape[1] == net.n_channels, \
                f'Network has been defined with {net.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            with torch.cuda.amp.autocast(enabled=amp):
                masks_pred = net(images)
                loss = criterion(masks_pred, true_masks) \
                       + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                   F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                                   multiclass=True)

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            experiment.log({
                'train loss': loss.item(),
                'step': global_step,
                'epoch': epoch
            })
            pbar.set_postfix(**{'loss (batch)': loss.item()})

            # Evaluation round
            division_step = (n_train // (10 * batch_size))
            if division_step > 0:
                if global_step % division_step == 0:
                    histograms = {}
                    for tag, value in net.named_parameters():
                        tag = tag.replace('/', '.')
                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    experiment.log({
                        'learning rate': optimizer.param_groups[0]['lr'],
                        'validation Dice': val_score,
                        'images': wandb.Image(images[0].cpu()),
                        'masks': {
                            'true': wandb.Image(true_masks[0].float().cpu()),
                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()),
                        },
                        'step': global_step,
                        'epoch': epoch,
                        **histograms
                    })
```


```python
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)[0]
        else:
            probs = torch.sigmoid(output)[0]

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])

        full_mask = tf(probs.cpu()).squeeze()

    if net.n_classes == 1:
        return (full_mask > out_threshold).numpy()
    else:
        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
```

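This is a classic paper in medical image segmentation, and most later medical image segmentation papers build on U-Net. Despite many modifications and improvements, the U-Net topology has been carried forward throughout. Combined with newer ideas such as GANs, domain adaptation and VAEs, U-Net can perform even better.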

2018 Atrial Segmentation Challenge




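In fact, each case is a 3D MRI volume, and the winning solution was two cascaded 3D V-Nets. In my experiments I nonetheless first converted each volume into 2D images (88 slices per case) and trained U-Net on them, to see how large the gap to the 3D approaches is.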
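To make the 2D setup concrete, the sketch below assumes each case is stored as a (D, H, W) NumPy array with D = 88 slices, runs a 2D predictor slice by slice, and restacks the masks so that Dice can be computed over the whole volume, the way a 3D method's output would be scored. The helper names are hypothetical, the thresholding lambda is only a stand-in for a trained 2D U-Net, and computing Dice per volume is an assumption about how the scores are compared.

```python
import numpy as np


def predict_volume_slicewise(volume, predict_slice):
    """Run a 2D predictor over each slice of a (D, H, W) volume and restack to 3D.

    Each slice is predicted independently, so no information flows
    between neighbouring slices.
    """
    return np.stack([predict_slice(volume[d]) for d in range(volume.shape[0])], axis=0)


def dice_3d(pred, target, eps=1e-6):
    """Binary Dice computed over the whole volume."""
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return float((2 * inter + eps) / (pred.sum() + target.sum() + eps))


rng = np.random.default_rng(0)
vol = rng.random((88, 16, 16))  # 88 slices, small toy in-plane size
pred = predict_volume_slicewise(vol, lambda s: s > 0.5)  # stand-in for a 2D U-Net
score = dice_3d(pred, vol > 0.5)
```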


| Experiment method | Setting | Parameters | Dice |
| --- | --- | --- | --- |
| 3D V-Net | 2 × 3D V-Net | 38M | 0.923 (train validation), 0.932 (test) |
| Multi-task learning (pre/post ablation) | U-Net + pre/post label | 18M+ | 0.901 (train validation), 0.921 (test) |
| Naive U-Net | Original | 17M | 0.883 (train validation) |
| Naive U-Net | 0.6 × training samples | 17M | 0.882 (train validation) |
| Naive U-Net | Half channels | 4M | 0.899 (train validation) |
| Naive U-Net | Deepest layer cut | 4M | 0.870 (train validation) |
| Naive U-Net | Without skip connections | 25M | 0.854 (train validation) |

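The experiments compare two challenge entries, the 3D V-Net and the multi-task learning approach, with my own U-Net ablations. The 3D nature of these MRI volumes is clearly important information for the network to learn: once a volume is cut into 2D slices, each slice is processed independently, and performance drops (the best 2D U-Net ablation, with half channels, reaches 0.899 on train validation). The first-place cascaded V-Net reached a Dice of 0.923. The multi-task learning design shows that the atrial masks before and after radiofrequency ablation follow somewhat different distributions, and this information gives useful guidance for segmentation. The idea is similar to domain adaptation: a branch is attached to the intermediate feature-extraction layers to classify whether the patient has undergone radiofrequency ablation, and during backpropagation the gradient from that branch is reversed. This method also ranked well in the challenge with 0.901, behind most of the 3D V-Net solutions. Multi-task learning is therefore also an important approach in this field.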
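The reversed-gradient branch described above can be sketched with a gradient reversal layer, as used in domain-adversarial training: it is the identity in the forward pass and multiplies the gradient by -λ in the backward pass. This is a minimal PyTorch sketch, not the challenge entry's code; AblationHead, its pooling-plus-linear design and lambd are my assumptions. One could, for example, feed it the bottleneck features x5 of the UNet above (512 channels with bilinear=True) and add its cross-entropy on the pre/post label to the segmentation loss.

```python
import torch
from torch import nn


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # lambd is a plain float and needs no gradient, hence the trailing None
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


class AblationHead(nn.Module):
    """Hypothetical auxiliary branch: classifies pre- vs post-ablation from encoder features."""

    def __init__(self, in_channels, lambd=1.0):
        super().__init__()
        self.lambd = lambd
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, 2)  # two logits: pre / post ablation

    def forward(self, feats):
        z = grad_reverse(feats, self.lambd)
        return self.fc(self.pool(z).flatten(1))
```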





