DF机器图像算法赛道

您所在的位置:网站首页 天气雷达图像识别方法 DF机器图像算法赛道

DF机器图像算法赛道

2024-07-15 09:03| 来源: 网络整理| 查看: 265

DF机器图像算法赛道-云状识别(天气识别)比赛总结 比赛目的题目简介数据分析类别比例图片尺寸 方案实现代码

比赛目的

之前一直用的tensorflow框架,想上手一下pytorch框架,然后一番搜索之后发现了基于pytorch的fastai高级库,fastai对于pytorch的关系和keras对于tensorflow的关系非常相似。正好拿这个比赛学习一下fastai,扩展一下pytorch方面的知识。最后成绩天气识别56/1054,云状识别31//1318。 因为是练手,所以实现的功能非常简单,基本就是baseline级别的东西。不得不说fastai使用起来非常简便,从开始写到出答案只需要几个小时的时间,并且想法迭代起来也非常快速。 与其说这是一个比赛的总结,不如说这是一个安利fastai的广告,fastai赛高。

题目简介

两个赛题都是分类问题,输入为尺寸不一的图片,输出为图片的类别。 例如天气的类别为:

天气现象编号雨凇1雾凇2雾霾3霜4露5结冰6降雨7降雪8冰雹9

云状的类别和天气差不多,分为1~29个类别,并且包含一张图片对应多个分类的情况,难度比天气大。

数据分析 类别比例

总的数据数量为10665,多标签的数据为454条,占比较少,由于只实现了单标签的网络,因此数据统计的时候去掉了多标签的数据。 最多的5个种类 在这里插入图片描述 数据非常不平衡,多的种类数量有1500+,少的种类数量只有个位数,这里不禁要吐槽一下数据的质量。 虽然做了数据类别的可视化,类别非常不平衡,但是方案中并没有实现,改进一下的话应该是可以提分的。 感觉在强大的数据平衡手段都拯救不了个位数和1500+的差距,深度学习,数据还是第一重要的。

图片尺寸

由于网络在训练是要resize到一个固定的尺寸,因此,统计一下数据图片的尺寸十分有必要。 下图所示,X轴为图片的宽,Y轴为图片的高。 在这里插入图片描述 统计之后可以发现图片大多数分布在0~2000之间,在几次控制变量试验后,选定了resize为512*512的训练尺寸。

方案实现 网络:Densenet201损失函数:FocalLoss在线数据增强 左右翻转 − 10 ° -10\degree −10° ~ 10 ° 10\degree 10°的旋转 1 1 1~ 1.1 1.1 1.1倍的缩放亮度调节 其他Trick: mixup使用fp_16训练,减少显存的消耗增大BatchSizeonecycle学习率策略 代码

代码非常简洁,基本上能想到的常见的trick,fastai都已经给你封装好了。

from fastai.vision import * import pandas as pd import numpy as np import os from fastai.callbacks import * os.environ["CUDA_VISIBLE_DEVICES"] = "1,0" from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True data_path = '/home/sjw/Desktop/cloud/data/Train' test_data_path = '/home/sjw/Desktop/cloud/data/Test' df = pd.read_csv('/home/sjw/Desktop/cloud/data/Train_label.csv') # 去除多标签的数据 classes_list = [str(i) for i in range(1,30)] df_single = df[df.Code.isin(classes_list)] # 重新排列索引 df_single = df_single.reset_index(drop=True) # 洗牌 df_single = df_single.sample(frac=1).reset_index(drop=True) for i in range(len(df_single)): df_single.Code[i] = int(df_single.Code[i]) test_df = pd.read_csv('/home/sjw/Desktop/cloud/data/submit_example.csv') save_dir = '/home/sjw/Desktop/cloud/fastai/densenet201' data_num = len(df_single) for part in range(1,5): val_list = list(range(data_num//5 * part + 1, data_num//5 * (part + 1) + 1)) part_save_dir = os.path.join(save_dir, 'part_{}'.format(part)) # 准备数据 data_set = (ImageList.from_df(df_single, data_path) # 从csv中读取dilenames .split_by_idx(val_list) # 验证集百分之20 .label_from_df() .add_test(ImageList.from_df(test_df, test_data_path)) .transform(get_transforms(do_flip=True, flip_vert=False, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75, xtra_tfms=[*rand_resize_crop(512)]), size=(512, 512), resize_method=ResizeMethod.SQUISH) .databunch(bs=32, num_workers=6) .normalize()) class FocalLoss(nn.Module): def __init__(self, alpha=1., gamma=1.): super().__init__() self.alpha = alpha self.gamma = gamma def forward(self, inputs, targets, **kwargs): CE_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets) pt = torch.exp(-CE_loss) F_loss = self.alpha * ((1 - pt) ** self.gamma) * CE_loss return F_loss.mean() # learn = cnn_learner(data_set, models.densenet201, # metrics=[accuracy], # loss_func=LabelSmoothingCrossEntropy(), # model_dir=part_save_dir) learn = cnn_learner(data_set, models.densenet201, metrics=[accuracy], loss_func=FocalLoss(), model_dir=part_save_dir) learn = learn.mixup().to_fp16() learn.model = nn.DataParallel(learn.model, device_ids=[0,1]) # learn.lr_find(stop_div=True, num_it=250) # learn.recorder.plot(suggestion=True) step_name = 'best_' + 'part' + str(part) learn.freeze() learn.fit(2, 1e-3) learn.unfreeze() unfreeze_min_grad_lr = 1e-3 learn.fit_one_cycle(20, [unfreeze_min_grad_lr / 100, unfreeze_min_grad_lr / 10, unfreeze_min_grad_lr], callbacks=[SaveModelCallback(learn, every='improvement', monitor='accuracy', name=step_name)]) learn.load(step_name) learn = learn.to_fp32() x, y = learn.TTA(ds_type=DatasetType.Test) write_csv = test_df for i in range(len(x)): write_csv.iloc[i, 1] = np.argmax(x[i]).tolist() + 1 write_csv.to_csv(os.path.join(part_save_dir,'part{}_newlabel.csv'.format(part)), index=False) learn.destroy()


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3