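[SlowFast loss function improvement] A general improvement scheme for deep learning networks: improving SlowFast's loss function (using focal loss to handle imbalanced data)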

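Contents: Introduction; 2. Project setup (2.1 Choosing a platform, 2.2 Setting up); 3. Dataset (3.1 Downloading the dataset, 3.2 Uploading the dataset, 3.3 Dataset statistics); 4. Running the project (4.1 Focal loss, 4.2 Preparation before training, 4.3 Training SlowFast on the dataset, 4.4 Training the improved SlowFast on the dataset, 4.5 Experimental comparison, 4.6 Monitoring GPU usage in real time)

Introduction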





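The example project is the SlowFast project from my earlier posts: 01 [mmaction2 SlowFast action analysis (commercial grade)] project download, and 02 [mmaction2 SlowFast action analysis (commercial grade)] project demo setup.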

2. Project setup

2.1 Choosing a platform

I again use the 极链AI cloud platform (https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite) and create an instance on it.

2.2 Setting up


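Enter the home directory and clone the project, either from GitHub or from the Gitee copy:

```shell
cd home
# from GitHub
git clone https://github.com/Wenhai-Zhu/JN-OpenLib-mmaction2.git
# or from Gitee
git clone https://gitee.com/YFwinston/JN-OpenLib-mmaction2.git
```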

Environment setup (on the AI cloud platform):

```shell
pip install mmcv-full==1.2.7 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.6.0/index.html
pip install mmpycocotools
pip install moviepy opencv-python terminaltables seaborn decord -i https://pypi.douban.com/simple
```


Alternatively, create a dedicated conda environment (this variant uses CUDA 10.1):

```shell
conda create -n JN-OpenLib-mmaction2-pytorch1.6-py3.6 -y python=3.6
conda activate JN-OpenLib-mmaction2-pytorch1.6-py3.6
pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.2.7 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
pip install mmpycocotools
pip install moviepy opencv-python terminaltables seaborn decord -i https://pypi.douban.com/simple
```


Then install the project in development mode:

```shell
cd JN-OpenLib-mmaction2/
python setup.py develop
```

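Note: the `cu102/torch1.6.0` part of the mmcv-full index URL must match the CUDA and PyTorch versions of the environment you created (which is why the conda variant uses `cu101/torch1.6.0`).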
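To avoid a mismatch, the two version parts of the URL can be read from the installed PyTorch. Below is a small sketch that assumes the mmcv wheel index keeps the `{cuda}/torch{version}` layout used in the commands above. `mmcv_index_url` is a hypothetical helper, and mapping CPU-only builds to `cpu` is an assumption:

```python
def mmcv_index_url(torch_version, cuda_version):
    """Build the mmcv-full wheel index URL for a PyTorch/CUDA pair.

    torch_version: e.g. torch.__version__, such as "1.6.0+cu101"
    cuda_version:  e.g. torch.version.cuda, such as "10.1" (None for CPU-only PyTorch)
    Assumes the index layout seen in this post's install commands.
    """
    # Drop the local build tag ("+cu101") and keep "major.minor.patch"
    base = torch_version.split("+")[0]
    # "10.1" -> "cu101"; CPU-only builds report None (assumed to map to "cpu")
    cu = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return "https://download.openmmlab.com/mmcv/dist/{}/torch{}/index.html".format(cu, base)
```

In the activated environment, `import torch` and call `mmcv_index_url(torch.__version__, torch.version.cuda)`. For a `1.6.0+cu101` build with CUDA `10.1`, it returns the `cu101/torch1.6.0` URL used in the conda variant.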

3. Dataset

3.1 Downloading the dataset


Link (Baidu Netdisk): https://pan.baidu.com/s/1wI7PVB9g5k6CcVDOfICW7A  Extraction code: du5o


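The action classes (the option IDs used in the annotations) are:

```python
options = {'0': 'None', '1': 'handshake', '2': 'point', '3': 'hug', '4': 'push', '5': 'kick', '6': 'punch'}
```

3.2 Uploading the dataset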



Go to user-data and create a folder for the dataset:

```shell
cd user-data
mkdir slowfastDataSet
```

Upload the dataset using the method described here: https://cloud.videojj.com/help/docs/data_manage.html#vcloud-oss-cli

3.3 Dataset statistics


Create a notebook on the AI platform. The statistics code goes in it.

Rename the notebook to dataTemp.ipynb. The code is as follows:

```python
import json

# Count how often each action class occurs in the training / test annotations
file_dir = "/user-data/slowfastDataSet/Datasets/Interaction/annotations/train/"
# file_dir = "/user-data/slowfastDataSet/Datasets/Interaction/annotations/test/"

# Annotation file names of the training set / test set
names = ['seq1', 'seq2', 'seq3', 'seq4', 'seq6', 'seq7', 'seq8', 'seq9', 'seq11', 'seq12',
         'seq13', 'seq14', 'seq16', 'seq17', 'seq18', 'seq19']
# names = ['seq5', 'seq10', 'seq15', 'seq20']

# One counter per action class ID '1'..'6'
counts = {str(i): 0 for i in range(1, 7)}

for name in names:
    file_name = file_dir + name + '.json'
    with open(file_name, encoding='utf-8') as f:
        setting = json.load(f)  # parse the JSON file into Python objects
    for key in setting['metadata']:
        parts = key.split("_")  # avoid shadowing the built-in `str`
        # Same filter as before: only keys whose second part is numeric
        if parts[1].isdigit():
            # Attribute '1' holds the comma-separated action IDs of this entry
            actions = setting['metadata'][key]['av']['1'].split(",")
            for cls_id in counts:
                if cls_id in actions:
                    counts[cls_id] += 1

for cls_id, n in counts.items():
    print("action" + cls_id, n)
```

Result for the training set:

```text
action1 1011
action2 709
action3 757
action4 358
action5 250
action6 320
```

Result for the test set:

```text
action1 243
action2 132
action3 209
action4 95
action5 64
action6 94
```

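We then plotted these counts as charts in Excel (charts not reproduced here). The statistics show that the dataset is imbalanced. In the training set, handshake (action1) has 1011 annotated instances while kick (action5) has only 250, roughly 4:1. The test set shows the same pattern (243 vs. 64).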
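The imbalance can be quantified directly from the printed counts. The minimal sketch below maps class IDs 1 to 6 to names through the options mapping in 3.1:

```python
# Per-class counts printed by the statistics script (IDs 1..6 in order)
CLASSES = ["handshake", "point", "hug", "push", "kick", "punch"]
TRAIN = [1011, 709, 757, 358, 250, 320]
TEST = [243, 132, 209, 95, 64, 94]


def summarize(counts):
    """Return (total, per-class share, max/min imbalance ratio)."""
    total = sum(counts)
    shares = [c / total for c in counts]
    ratio = max(counts) / min(counts)
    return total, shares, ratio


for split, counts in (("train", TRAIN), ("test", TEST)):
    total, shares, ratio = summarize(counts)
    print("{}: total={}, max/min={:.2f}".format(split, total, ratio))
    for name, share in zip(CLASSES, shares):
        print("  {:<9} {:6.1%}".format(name, share))
```

For the training split this gives 3405 instances in total and a max/min ratio of about 4.04 (handshake 1011 vs. kick 250). For the test split it gives 837 instances and a ratio of about 3.80.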

4. Running the project

4.1 Focal loss

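In short, focal loss gives more weight to predictions made with low confidence. Each sample's cross-entropy is multiplied by a factor that shrinks toward 0 as the predicted probability of its true label approaches 1. Easy, well-classified samples therefore contribute little, and hard samples dominate the loss. In imbalanced data the rare classes tend to be the hard ones (low predicted scores), so their losses receive relatively larger weight.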
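For reference, here is the binary focal loss as introduced by Lin et al. for dense object detection. It applies to a logit $x$ with label $y \in \{0, 1\}$ and $p = \sigma(x)$. The F_BCE class used in 4.4 is the case $\gamma = 2$, with weight $w$ (`pos_weight`) on the positive term and weight 1 on the negative term. Setting $\gamma = 0$ and $w = 1$ recovers ordinary binary cross-entropy.

$$
\mathrm{FL}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t,
\qquad
p_t = \begin{cases} p & y = 1 \\ 1 - p & y = 0 \end{cases}
$$

$$
\mathcal{L}_{\text{F\_BCE}}(x, y) = -w\,y\,(1 - p)^{2}\log p \;-\; (1 - y)\,p^{2}\log(1 - p),
\qquad p = \sigma(x)
$$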

4.2 Preparation before training

Create a symbolic link named `data` that points to the /user-data/slowfastDataSet/Datasets folder. First enter JN-OpenLib-mmaction2, then create the link:

```shell
cd JN-OpenLib-mmaction2
ln -s /user-data/slowfastDataSet/Datasets data
```
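A quick way to confirm that the link points at the right layout is sketched below. The helpers `missing_annotations` and `link_target` are hypothetical and not part of the project. The directory layout `<root>/Interaction/annotations/<split>/<seq>.json` is taken from the paths in the statistics script and the evaluation log:

```python
import os


def missing_annotations(data_root, split, names):
    """Return the sequence names whose annotation JSON is missing.

    Hypothetical helper; expects <data_root>/Interaction/annotations/<split>/<name>.json,
    the layout used by the paths in this post.
    """
    ann_dir = os.path.join(data_root, "Interaction", "annotations", split)
    return [name for name in names
            if not os.path.isfile(os.path.join(ann_dir, name + ".json"))]


def link_target(path="data"):
    """Return the fully resolved target of a symlink, or None if `path` is not a symlink."""
    return os.path.realpath(path) if os.path.islink(path) else None
```

Run it from inside JN-OpenLib-mmaction2. `link_target()` should resolve to the Datasets folder, and `missing_annotations("data", "test", ["seq5", "seq10", "seq15", "seq20"])` should return an empty list.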


4.3 Training SlowFast on the dataset

```shell
python tools/train.py configs/detection/via3/my_slowfast_kinetics_pretrained_r50_8x8x1_20e_via3_rgb.py --validate
```

The red box in the training-log screenshot (not reproduced here) marks the estimated remaining training time, the `eta` field of the log line.

4.4 Training the improved SlowFast on the dataset


First, add an F_BCE class (binary cross-entropy with a focal-loss modulating factor) to bbox_head.py:

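```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class F_BCE(nn.Module):
    """Binary cross-entropy with a focal-loss modulating factor (gamma = 2)."""

    def __init__(self, pos_weight=1, reduction='mean'):
        super(F_BCE, self).__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, *], target: [N, *]
        probs = torch.sigmoid(logits)
        # log(p) and log(1 - p) via logsigmoid, which stays finite when p saturates
        log_p = F.logsigmoid(logits)
        log_1mp = F.logsigmoid(-logits)
        loss = - self.pos_weight * target * (1 - probs)**2 * log_p - \
            (1 - target) * probs**2 * log_1mp
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
```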
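To see what F_BCE does to each element, here is a scalar, pure-Python sketch of the same formula. It is a hypothetical helper for illustration. `gamma` is an added parameter (F_BCE fixes it to 2), and the log-sigmoid is computed in a numerically stable way, the same idea as `F.logsigmoid`:

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    # For negative x, log(sigmoid(x)) = x - log(1 + e^x) avoids overflow in exp(-x)
    return x - math.log1p(math.exp(x))


def bce_element(logit, target):
    """Plain binary cross-entropy on one logit."""
    return -target * log_sigmoid(logit) - (1.0 - target) * log_sigmoid(-logit)


def f_bce_element(logit, target, pos_weight=1.0, gamma=2.0):
    """Per-element F_BCE: the BCE terms scaled by (1 - p)^gamma and p^gamma."""
    p = math.exp(log_sigmoid(logit))  # sigmoid(logit) without overflow
    pos = -pos_weight * target * (1.0 - p) ** gamma * log_sigmoid(logit)
    neg = -(1.0 - target) * p ** gamma * log_sigmoid(-logit)
    return pos + neg
```

Consider a positive label. An easy logit of ln 9 (p = 0.9) has a BCE of about 0.105 but an F_BCE of only about 0.001. A hard logit of -ln 9 (p = 0.1) keeps most of its loss (about 1.865 of 2.303). The easy example's loss shrinks about 100-fold, while the hard one's shrinks by less than 20%.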

Then create the loss and a BatchNorm layer in the `__init__` of BBoxHeadAVA:

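```python
# in BBoxHeadAVA.__init__
self.f_bce = F_BCE()
# BatchNorm over the action scores; its size must equal num_classes - 1
# (6 action classes in this dataset, as in the full file below)
self.BN = nn.BatchNorm1d(6)
```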

In `loss()`, normalize the scores and compute the classification loss with F_BCE instead of binary cross-entropy:

```python
# in BBoxHeadAVA.loss(), after dropping class 0 from cls_score and labels
cls_score = self.BN(cls_score)
f_bce_loss = self.f_bce
losses['loss_action_cls'] = f_bce_loss(cls_score, labels)
```
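What the added `self.BN` does to the scores during training can be sketched without PyTorch. In training mode, `nn.BatchNorm1d(C)` standardizes each of the C class columns of the N x C score matrix using that column's batch mean and biased variance. It then applies a learnable scale and shift, initialized to 1 and 0. The pure-Python sketch below covers only that default-initialized case. `batchnorm_columns` is a hypothetical helper, and running statistics and eval mode are left out:

```python
import math


def batchnorm_columns(scores, eps=1e-5):
    """Training-mode BatchNorm1d on an N x C list of scores, default affine init.

    Each class column is shifted by its batch mean and divided by sqrt(biased
    batch variance + eps); weight=1 and bias=0 leave the result unchanged.
    """
    n, c = len(scores), len(scores[0])
    out = [[0.0] * c for _ in range(n)]
    for j in range(c):
        col = [row[j] for row in scores]
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n  # biased, as used for normalization
        for i in range(n):
            out[i][j] = (scores[i][j] - mean) / math.sqrt(var + eps)
    return out
```

BatchNorm works per column, so its size must equal the number of score columns left after `cls_score[:, 1:]`, that is, the number of action classes.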

The complete bbox_head.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmaction.core.bbox import bbox_target

try:
    from mmdet.models.builder import HEADS as MMDET_HEADS
    mmdet_imported = True
except (ImportError, ModuleNotFoundError):
    mmdet_imported = False


class F_BCE(nn.Module):
    """Binary cross-entropy with a focal-loss modulating factor (gamma = 2)."""

    def __init__(self, pos_weight=1, reduction='mean'):
        super(F_BCE, self).__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        # logits: [N, *], target: [N, *]
        probs = torch.sigmoid(logits)
        # log(p) and log(1 - p) via logsigmoid, which stays finite when p saturates
        log_p = F.logsigmoid(logits)
        log_1mp = F.logsigmoid(-logits)
        loss = - self.pos_weight * target * (1 - probs)**2 * log_p - \
            (1 - target) * probs**2 * log_1mp
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss


class BBoxHeadAVA(nn.Module):
    """Simplest RoI head, with only two fc layers for classification and
    regression respectively.

    Args:
        temporal_pool_type (str): The temporal pool type. Choices are 'avg' or
            'max'. Default: 'avg'.
        spatial_pool_type (str): The spatial pool type. Choices are 'avg' or
            'max'. Default: 'max'.
        in_channels (int): The number of input channels. Default: 2048.
        num_classes (int): The number of classes. Default: 81.
        dropout_ratio (float): A float in [0, 1], indicates the dropout_ratio.
            Default: 0.
        dropout_before_pool (bool): Dropout Feature before spatial temporal
            pooling. Default: True.
        topk (int or tuple[int]): Parameter for evaluating multilabel accuracy.
            Default: (3, 5)
        multilabel (bool): Whether used for a multilabel task. Default: True.
            (Only support multilabel == True now).
    """

    def __init__(
            self,
            temporal_pool_type='avg',
            spatial_pool_type='max',
            in_channels=2048,
            # The first class is reserved, to classify bbox as pos / neg
            num_classes=81,
            dropout_ratio=0,
            dropout_before_pool=True,
            topk=(3, 5),
            multilabel=True,
            loss_cfg=None):
        super(BBoxHeadAVA, self).__init__()
        assert temporal_pool_type in ['max', 'avg']
        assert spatial_pool_type in ['max', 'avg']
        self.temporal_pool_type = temporal_pool_type
        self.spatial_pool_type = spatial_pool_type

        self.in_channels = in_channels
        self.num_classes = num_classes

        self.dropout_ratio = dropout_ratio
        self.dropout_before_pool = dropout_before_pool

        self.multilabel = multilabel
        if topk is None:
            self.topk = ()
        elif isinstance(topk, int):
            self.topk = (topk, )
        elif isinstance(topk, tuple):
            assert all([isinstance(k, int) for k in topk])
            self.topk = topk
        else:
            raise TypeError('topk should be int or tuple[int], '
                            f'but get {type(topk)}')
        # Class 0 is ignored when calculating multilabel accuracy,
        # so topk cannot be equal to num_classes
        assert all([k < num_classes for k in self.topk])

        # Handle AVA first
        assert self.multilabel

        # Pool by default
        if self.temporal_pool_type == 'avg':
            self.temporal_pool = nn.AdaptiveAvgPool3d((1, None, None))
        else:
            self.temporal_pool = nn.AdaptiveMaxPool3d((1, None, None))
        if self.spatial_pool_type == 'avg':
            self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
        else:
            self.spatial_pool = nn.AdaptiveMaxPool3d((None, 1, 1))

        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)

        self.fc_cls = nn.Linear(in_channels, num_classes)
        self.debug_imgs = None

        # Focal-style BCE loss and a BatchNorm over the action scores;
        # BatchNorm1d size = num_classes - 1 (the 6 action classes here)
        self.f_bce = F_BCE()
        self.BN = nn.BatchNorm1d(6)

    def init_weights(self):
        nn.init.normal_(self.fc_cls.weight, 0, 0.01)
        nn.init.constant_(self.fc_cls.bias, 0)

    def forward(self, x):
        if self.dropout_before_pool and self.dropout_ratio > 0:
            x = self.dropout(x)

        x = self.temporal_pool(x)
        x = self.spatial_pool(x)

        if not self.dropout_before_pool and self.dropout_ratio > 0:
            x = self.dropout(x)

        x = x.view(x.size(0), -1)
        cls_score = self.fc_cls(x)
        # We do not predict bbox, so return None
        return cls_score, None

    def get_targets(self, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg):
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        cls_reg_targets = bbox_target(pos_proposals, neg_proposals,
                                      pos_gt_labels, rcnn_train_cfg)
        return cls_reg_targets

    def recall_prec(self, pred_vec, target_vec):
        """
        Args:
            pred_vec (tensor[N x C]): each element is either 0 or 1
            target_vec (tensor[N x C]): each element is either 0 or 1
        """
        correct = pred_vec & target_vec
        # Seems torch 1.5 has no auto type conversion
        recall = correct.sum(1) / (target_vec.sum(1).float() + 1e-6)
        prec = correct.sum(1) / (pred_vec.sum(1) + 1e-6)
        return recall.mean(), prec.mean()

    def multilabel_accuracy(self, pred, target, thr=0.5):
        pred = pred.sigmoid()
        pred_vec = pred > thr
        # Target is 0 or 1, so using 0.5 as the borderline is OK
        target_vec = target > 0.5
        recall_thr, prec_thr = self.recall_prec(pred_vec, target_vec)

        recalls, precs = [], []
        for k in self.topk:
            _, pred_label = pred.topk(k, 1, True, True)
            pred_vec = pred.new_full(pred.size(), 0, dtype=torch.bool)

            num_sample = pred.shape[0]
            for i in range(num_sample):
                pred_vec[i, pred_label[i]] = 1
            recall_k, prec_k = self.recall_prec(pred_vec, target_vec)
            recalls.append(recall_k)
            precs.append(prec_k)
        return recall_thr, prec_thr, recalls, precs

    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets=None,
             bbox_weights=None,
             reduce=True):
        losses = dict()
        if cls_score is not None:
            # Only use the cls_score
            # labels = labels[:, 1:]
            # pos_inds = torch.sum(labels, dim=-1) > 0
            # cls_score = cls_score[pos_inds, 1:]
            # labels = labels[pos_inds]

            # Drop the reserved class 0, keep the action classes only
            labels = labels[:, 1:]
            cls_score = cls_score[:, 1:]

            # Normalize the scores per class over the batch, then apply F_BCE.
            # Note: BN is applied only here in loss(), not in forward(),
            # so get_det_bboxes() at test time receives un-normalized scores.
            cls_score = self.BN(cls_score)
            f_bce_loss = self.f_bce
            losses['loss_action_cls'] = f_bce_loss(cls_score, labels)
            # Original loss, replaced by F_BCE:
            # bce_loss = F.binary_cross_entropy_with_logits
            # losses['loss_action_cls'] = bce_loss(cls_score, labels)

            recall_thr, prec_thr, recall_k, prec_k = self.multilabel_accuracy(
                cls_score, labels, thr=0.5)
            losses['recall@thr=0.5'] = recall_thr
            losses['prec@thr=0.5'] = prec_thr
            for i, k in enumerate(self.topk):
                losses[f'recall@top{k}'] = recall_k[i]
                losses[f'prec@top{k}'] = prec_k[i]
        return losses

    def get_det_bboxes(self,
                       rois,
                       cls_score,
                       img_shape,
                       flip=False,
                       crop_quadruple=None,
                       cfg=None):

        # might be used by testing w. augmentation
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))

        assert self.multilabel

        scores = cls_score.sigmoid() if cls_score is not None else None
        bboxes = rois[:, 1:]
        assert bboxes.shape[-1] == 4

        # First reverse the flip
        img_h, img_w = img_shape
        if flip:
            bboxes_ = bboxes.clone()
            bboxes_[:, 0] = img_w - 1 - bboxes[:, 2]
            bboxes_[:, 2] = img_w - 1 - bboxes[:, 0]
            bboxes = bboxes_

        # Then normalize the bbox to [0, 1]
        bboxes[:, 0::2] /= img_w
        bboxes[:, 1::2] /= img_h

        def _bbox_crop_undo(bboxes, crop_quadruple):
            decropped = bboxes.clone()

            if crop_quadruple is not None:
                x1, y1, tw, th = crop_quadruple
                decropped[:, 0::2] = bboxes[..., 0::2] * tw + x1
                decropped[:, 1::2] = bboxes[..., 1::2] * th + y1

            return decropped

        bboxes = _bbox_crop_undo(bboxes, crop_quadruple)
        return bboxes, scores


if mmdet_imported:
    MMDET_HEADS.register_module()(BBoxHeadAVA)
```


Then train again with the same command:

```shell
python tools/train.py configs/detection/via3/my_slowfast_kinetics_pretrained_r50_8x8x1_20e_via3_rgb.py --validate
```


4.5 Experimental comparison


First, the results of the original model (before the change):

The evaluation log of 2021-11-04 (validation after epoch 20) is condensed below to one line per test sequence. It shows mAP and per-class AP at IoU 0.5, rounded to four decimals, with the timing lines omitted:

```text
Sequence  Images  mAP     handshake  point   hug     push    kick    punch
seq5      556     0.8223  0.8868     0.6092  0.7931  0.8667  0.7000  0.9945
seq10     557     0.4056  0.4545     0.3918  0.6897  0.3736  0.0000  0.0000
seq15     607     0.5331  0.8571     0.2500  0.7222  0.4530  0.1000  0.3846
seq20     622     0.4989  0.5131     0.5478  0.8907  0.4209  0.0000  0.1765
Final mAP (Epoch(val) [20][2643]): 0.5650
```


Then the improved model:

The evaluation log of 2021-11-04 (validation after epoch 20) is again condensed to one line per test sequence, with mAP and per-class AP at IoU 0.5 rounded to four decimals:

```text
Sequence  Images  mAP     handshake  point   hug     push    kick    punch
seq5      556     0.9070  0.9943     0.6242  0.9607  1.0000  0.9636  0.9396
seq10     557     0.7747  0.9066     0.8147  0.9151  0.9839  0.5193  0.4314
seq15     607     0.7890  0.8909     0.8125  0.9966  0.6568  0.7602  0.4837
seq20     622     0.7211  0.9027     0.4597  0.9709  0.7328  0.7416  0.3528
Final mAP (Epoch(val) [20][2643]): 0.7980
```

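The original model reaches a final mAP of 56.50%, and the improved model reaches 79.80%, a clear improvement of about 23.3 percentage points. The largest gain is on kick, the rarest class. Its AP on seq5, seq10, seq15 and seq20 rises from 0.70, 0.00, 0.10 and 0.00 to 0.96, 0.52, 0.76 and 0.74.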
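The final value in each Epoch(val) line equals the plain average of the four per-sequence mAPs. A few lines of Python confirm this, using the full-precision values from the evaluation logs:

```python
# Per-sequence mAP@0.5IoU values from the two evaluation logs (full precision)
BEFORE = {"seq5": 0.822277391327469, "seq10": 0.40562803653484647,
          "seq15": 0.5331220241603439, "seq20": 0.49892216373873616}
AFTER = {"seq5": 0.9069741460566432, "seq10": 0.7747222320890791,
         "seq15": 0.7889970077502289, "seq20": 0.7211447472837192}


def mean(values):
    """Arithmetic mean of an iterable of numbers."""
    values = list(values)
    return sum(values) / len(values)


before, after = mean(BEFORE.values()), mean(AFTER.values())
print("before: {:.4f}  after: {:.4f}  gain: {:.2f} points".format(
    before, after, 100 * (after - before)))
```

This prints 0.5650 before, 0.7980 after, and a gain of 23.30 points, matching the final values in the logs.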

4.6 Monitoring GPU usage in real time


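Someone wrote a small tool, gpustat, that wraps nvidia-smi. It is very convenient and I recommend it. First install it:

```shell
pip install gpustat
```

Then run:

```shell
gpustat -cp
```

`-c` adds each GPU process's command and `-p` adds its PID to the per-GPU status lines.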


To keep gpustat running and refresh the display every second, add `-i 1`:

```shell
gpustat -cp -i 1
```
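If gpustat cannot be installed, you can read similar numbers from nvidia-smi's CSV query mode. Below is a minimal sketch. The helper names are hypothetical, while the `--query-gpu` fields and `--format=csv,noheader,nounits` are standard nvidia-smi options. The sketch avoids `capture_output` so that it also runs in the Python 3.6 environment created above:

```python
import subprocess

# GPU fields to query; with "nounits", utilization is in % and memory in MiB
FIELDS = "index,name,utilization.gpu,memory.used,memory.total"


def parse_gpu_csv(text):
    """Parse `nvidia-smi --query-gpu=FIELDS --format=csv,noheader,nounits` output.

    Fields reported as [N/A] by some GPUs would need extra handling.
    """
    gpus = []
    for line in text.strip().splitlines():
        index, name, util, used, total = [s.strip() for s in line.split(",")]
        gpus.append({
            "index": int(index),
            "name": name,
            "util_pct": int(util),
            "mem_used_mib": int(used),
            "mem_total_mib": int(total),
        })
    return gpus


def query_gpus():
    """Run nvidia-smi once (needs an NVIDIA driver) and return one dict per GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=" + FIELDS, "--format=csv,noheader,nounits"],
        stdout=subprocess.PIPE, universal_newlines=True, check=True,
    ).stdout
    return parse_gpu_csv(out)
```

Calling `query_gpus()` in a loop with `time.sleep(1)` roughly approximates `gpustat -i 1`.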





