Yolov3(Mxnet)训练自己的数据集

您所在的位置:网站首页 mobilenet训练自己的数据集 Yolov3(Mxnet)训练自己的数据集

Yolov3(Mxnet)训练自己的数据集

2023-08-08 00:09| 来源: 网络整理| 查看: 265

        Mxnet中的Gluoncv提供标准VOC和COCO数据集上的预训练模型、数据读取类和训练程序,如果我们想使用model_zoo里面的预训练模型,并在自己的数据集上微调,则需要调整一些程序,下面介绍在自有的VOC格式的数据集上训练Yolov3的方法。

        主干网络仍使用Gluoncv提供的官方Yolov3训练程序,下面链接中的train_yolo3.py:

        https://gluon-cv.mxnet.io/model_zoo/detection.html#yolo-v3

修改数据读取方式

        官方程序中给出了读取VOC和COCO数据集的方法,假如我们已经有类VOC或者COCO数据集的自有数据集,可以通过继承Gluoncv中的数据读取类的方式来读取自己的数据集,并且保留所有属性。原始训练文件train_yolo3中的get_dataset函数如下所示,其中使用了VOCDetection和COCODetection类,我们以VOC为例。

def get_dataset(dataset, args): if dataset.lower() == 'voc': train_dataset = gdata.VOCDetection( splits=[(2007, 'trainval'), (2012, 'trainval')]) val_dataset = gdata.VOCDetection( splits=[(2007, 'test')]) val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes) #省略COCO和其它处理 return train_dataset, val_dataset, val_metric

        VOC采用的数据读取为VOCDetection类,定义方式如下所示,可以看到其中定义了类别name,并且有指定数据路径(root),因此我们可以继承VOCDetection类来满足需求。

class VOCDetection(VisionDataset): CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'voc'), splits=((2007, 'trainval'), (2012, 'trainval')), transform=None, index_map=None, preload_label=True): super(VOCDetection, self).__init__(root) …… @property def classes(self): """Category names.""" try: self._validate_class_names(self.CLASSES) except AssertionError as e: raise RuntimeError("Class names must not contain {}".format(e)) return type(self).CLASSES

        首先我们在训练文件中需要import VOCDetection类,假设我们自己的VOC格式的数据集种类为['1', '2','3', '4'],可以按照如下方式定义VOCLike类读取自己的数据。

from gluoncv.data import VOCDetection classes_name = ['1', '2','3', '4'] class VOCLike(VOCDetection): CLASSES = classes_name def __init__(self, root, splits, transform=None, index_map=None, preload_label=True): super(VOCLike, self).__init__(root, splits, transform, index_map, preload_label)

        然后将训练文件中的get_dataset函数修改为如下所示,root参数可指定数据集路径,可根据需求指定splits中的名字,如splits=[(2007, 'trainval'), (2012, 'trainval')]等方式。

def get_dataset(dataset, args): if dataset.lower() == 'voc': train_dataset = VOCLike(root='VOCdevkit', splits=((2007, 'trainval'),)) val_dataset = VOCLike(root='VOCdevkit', splits=((2007, 'test'),)) val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes) elif dataset.lower() == 'coco': #…… if args.num_samples < 0: args.num_samples = len(train_dataset) if args.mixup: from gluoncv.data import MixupDetection train_dataset = MixupDetection(train_dataset) return train_dataset, val_dataset, val_metric

2. 修改网络输出类别

        Gluoncv提供在VOC和COCO上的预训练模型,因此我们可以方便地使用预训练模型在自己的数据集上微调参数。官方提供两种方法,第一种为get_model下VOC预训练模型,然后通过reset_class设置成自己需要的类别;第二种使用get_model中的cuctom直接设置成需要的类别,并复用VOC参数。

(1)使用VOC然后reset_class

        train_yolo3训练程序中有get_model部分,原始代码如下:

if args.syncbn and len(ctx) > 1: net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm, norm_kwargs={'num_devices': len(ctx)}) async_net = get_model(net_name, pretrained_base=False) else: net = get_model(net_name, pretrained_base=True) async_net = net

        其中get_model的参数pretrained_base表示加载imagenet上预训练的基础主干网络参数,如果想加载VOC上训好的检测模型参数,测需要将pretrained_base改为pretrained,在加载模型之后,需要使用reset_class函数更改预训练模型以满足自己数据集类别。

if args.syncbn and len(ctx) > 1: net = get_model(net_name, pretrained=True, norm_layer=gluon.contrib.nn.SyncBatchNorm, norm_kwargs={'num_devices': len(ctx)}) async_net = get_model(net_name, pretrained=True) else: net = get_model(net_name, pretrained=True) async_net = net net.reset_class(classes_name) async_net.reset_class(classes_name)

(2)使用cuctom,根据官方finetune_detection.py示例,效果和上面一样(未测试)。

net = gcv.model_zoo.get_model(net_name, classes= classes_name, pretrained_base=False, transfer='voc')

        cuctom定义如下,可以发现实现和上面一样,也是根据transfer参数get_model预训练模型,然后reset_class。

from ...model_zoo import get_model net = get_model( 'yolo3_mobilenet0.25_' + str(transfer), pretrained=True, **kwargs) reuse_classes = [x for x in classes if x in net.classes] net.reset_class(classes, reuse_weights=reuse_classes)

        reset_class中的resue_weights参数,可以根据需要修改输出层类别数量,并且选择复用的参数。比如我们自己的数据集中有部分种类和VOC一样,那这部分输出层参数就可以直接复用VOC预训练模型的参数,其它不一样的类的输出分支重新初始化。再比如我们现在只想要一个检测行人的模型,如果模型中包含其它无用类,会使得模型有冗余,我们可以使用resue_weights参数将reset_class后的输出层直接复用VOC训练参数,将模型修改为只检测行人而不用重新微调。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3