元学习代码解析

您所在的位置:网站首页 代码的框架是什么意思 元学习代码解析

元学习代码解析

2023-04-23 11:03| 来源: 网络整理| 查看: 265

前言

本文是专门针对深度学习初学者的代码解析教程。代码地址:dragen1860/MAML-Pytorch

对于非初学者,根本不需要看代码解析,自己去分析效率更高。

我比较认可的pytorch学习路线是:

读官方文档的入门部分,对pytorch大体框架(有哪些必须的组成部分)有个感性的认识。读论文的代码,目标是逐行读懂,遇到不懂的地方就查文档、查csdn,这一过程会有很多收获。读官方文档的全部章节,加深对pytorch框架的认识,思考框架各个组成部分之间的联系,思考过程中,遇到不懂的地方就读源代码,这一过程收获最大。不适合python初学者,不适合不聪明的人。不建议用windows系统,很多库对windows极其不友好,比如PIL

这一学习路线,1+2大约需要15个小时;3属于基本功,花多长时间纯靠个人兴趣,不花时间也能做一个合格的调参侠。

首先,深度学习框架中的“框架”是什么意思?

一个深度学习项目,就是根据要实现的深度学习算法,(1)要先定义好网络,(2)然后从数据集中随机取出一个batch送入网络中,(3)最后通过输出结果与真实结果之间的误差,更新参数。

其中,(1)和(2)都各自需要继承pytorch中的某一个定制类。所谓定制类,就是类中有一些特殊的函数,我们继承这些类,就必须要针对自己的算法实现这些特殊函数。当我们把这些函数实现好了,(1)和(2)的代码就接近写完了。

这就是框架的力量,当我们调用pytorch框架,必须根据自己算法实现其中的定制类,定制的东西写好了,其它部分只需要交给框架完成,省时省力。

所以,本文源码解析,就是要向读者介绍(1)、(2)、(3)是如何依据框架实现的,下面先贴一下代码总体结构图。

数据加载

这一模块要实现:针对从网络上下载好数据集,(1)从这一数据集中随机取出一组数据组成一个batch,(2)把得到的batch转变为合法输入,具体来说,要得到能直接送进神经网络中的张量。

数据加载需要继承torch.utils.data.Dataset类,通过继承它,再在主函数配合以torch.utils. data.DataLoader,就可以定义出一个迭代器。随后,在每一次主函数的训练中都会从这个迭代器中取出一个batch的数据,送到神经网络中训练。

这个类是定制类,以它构造的子类,一定要定义__init__(),以实现上面的(1),要定义__getitem(),以实现上面的(2)。

网络搭建

这一模块要实现神经网络前向传播的整个过程,搭建好后,输入数据,就可以得到结果。

如果要用pytorch定义自已的网络,就一定要继承torch.nn.Module类,它是专门为神经网络设计的模块化接口(nn构建于autograd之上,可以用来定义和运行神经网络)。nn.Module是nn中十分重要的类, 包含网络各层的定义及forward方法。

继承之后,必须要实现__init__()方法,一般会在这一方法中定义好网络要用到的各层。

随后,必须要实现forward方法,具体来说,就是定义输入数据是如何在网络中前向传播的。

网络训练

这一模块要实现神经网络反向传播的整个过程,搭建好后,输入有标签数据,就能更新神经网络参数。

它依然要继承torch.nn.Module类,

首先会在__init__()中,实例化上一小节创建的神经网络类net=learner()

其次会在forward()中,先喂数据给net,进行正向传播,再根据结果,结合MAML算法,进行反向传播,更新参数。

forward具体怎么实现,三言两句讲不清楚,建议直接读源代码,只要不是傻子,最多花你一个小时。

主函数

没啥好写的,这部分不属于框架,它要求你按照逻辑,调用上面实现的三个模块,训练网络,打印结果。

只要读过pytorch文档的入门章节,能不费力得写出leetcode中一半的medium难度题目,就可以毫不费力地写出主函数。

总结

虽说是对MAML的源码解析,但涉及MAML的部分很少,因为对于其它深度学习算法,也可以使用这个源码,只需要重新写一下网络训练文件,即meta.py就好。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3