浅理解小样本学习:原型网络

您所在的位置:网站首页 is的原形是啥 浅理解小样本学习:原型网络

浅理解小样本学习:原型网络

2024-06-29 08:35| 来源: 网络整理| 查看: 265

!写在前面:本篇完全是当作对原型网络的一个学习笔记,只是简单的理解。当然如果存在错误欢迎批评指正!

原型网络的概念是2017年由Snell等人首次提出,论文:

Prototypical Networks for Few-shot Learninghttps://arxiv.org/abs/1703.05175

1. 小样本学习背景

        Few-Shot Learning,国外一般叫缩写FSL,国内翻译为小样本学习。但是我觉得翻译的并不是很好,并没有体现FSL的核心思想。我的理解FSL的核心是通过某种方法(现在通常是元学习的方法)利用通用数据得到泛化能力较强的预训练模型,然后在下游任务中根据预训练模型微调或者其他方法得到新模型。所以FSL其实是“少次的学习”,再者才是“小样本”。FSL的问题是如何利用少量样本适用在新的问题中,我们知道样本少会出现过拟合问题。还有个问题是跨领域适应问题(DA)问题,简单说就是比如利用一些金融、科学等数据得到通用模型,模型如何泛化到只有少量数据的医学类的问题中。要在数据量较少的情况下完成NLP任务,如果采用深度学习的方法,一方面会产生过拟合,另一方面是会导致模型的泛化性较差。而小样本学习采用元学习的方法旨在学习人类的学习能力。

        可以说FSL是要实现的目标,即要实现用少量样本学习一个模型。通常实现FSL的手段是元学习,即Meta-Learning。 Meta-X的意思是X about X,所以Meta-Learning是learn about learning,学习关于学习,当然也可以Mata-Meta-Learning,嵌套就是了。回到正题,元学习考虑的是学会人类“触类旁通”的能力,例如:一个小朋友去参观动物园的熊猫,虽然他之前从来没有见过熊猫更不认识熊猫,但是给他几张不同动物但是包含熊猫图片的卡片,小朋友可以很轻松的从这些照片中认出这个动物是卡片上的熊猫。元学习的方法是创建不同的情景式训练。情景式训练是将训练集划分为支持集(Support Set)和查询集(Query Set)。支撑集是用来更新模型的参数,而查询集是用来验证算法的好坏,并确定最好的算法。这就相当于学习人类的学习方式,在 训练过程中一边训练一边测试。支持集就是上述例子中动物园所有的动物,查询集就是给小朋友的动物卡片。查询集就可以当成是机器学习中的验证集。FSL对于任务分类的概念是N-way-K-shot,其中N-way是分为N个类,K-shot是每个类k个数据样本。所以当训练时会分为很多个Eposide,每个Eposide有N个类,每个类K+1(支持集k个查询集1个)个样本,即总共N*(K+1)个样本。

元学习与有监督学习输入和输出的区别 

        元学习主要分为基于参数优化的方法和基于度量的方法(不需要模型,非参数的方法),原型网络就是基于度量的方法。对于不需要模型且非参数的方法,能想到的最简单有效的方法就是KNN,即K个最近邻的方法,原型网络就是采用了这个方法。

        这里介绍一些关于元学习的学习网址或视频课程等:

Stanford CS330课程   Standford CS330课件

台大李宏毅教授视频课程   课程主页

火炉课堂

2. 原型网络

        主要分析原论文的摘要、介绍和原理部分,实验部分跳过。

摘要

        作者提出了原型网络这个概念来解决小样本分类问题。分类器必须泛化到训练集中没有出现的新类,而每个新类只有少量的样本。原型网络需要学习一个度量空间,在这个度量空间中,可通过计算到每个类的原型表示的距离(作者采用的是欧式距离而不是余弦距离)来进行分类。与近期的FSL方法相比,原型网络反映了一种更简单的归纳偏差,在这个有限的数据区域是有益的,并取得了很好的效果。作者对其进行分析表明,一些简单的设计决策可以比涉及复杂体系结构选择和元学习的最新方法产生实质性的改进。进一步将原型网络扩展到零样本,在CU-Birds数据集上获得了非常好的结果。

介绍

 用的原文中的例子介绍下原型网络的结构:这里是一个3-way-5-shot的例子,即共3个类(c1,c2,c3),每个类5个样本。原型网络首先要学习这个度量空间:即如何将相同类的样本映射到同一个区域,而不同类的样本尽可能离得更远。然后根据这个度量空间,计算每个类的类原型。计算5个样本的均值得到c1,c2,c3这三个类原型,对于查询样本x,通过欧式距离计算x与c1,c2,c3的距离,距离哪个近就是哪个类,比如上图x距离c2类原型更近,所以属于c2类。所以其实就是个聚类的思想。 

原理        

集合S是训练集,N表示训练集共有N个数据样本,xi表示D维的i类的特征向量,yi表示i类的标签。K表示训练集有K个类,yi\epsilon{1,...,K}表示在这K个类对应的类标签。Sk表示k这个类对应的数据样本集合。

 这一步就是计算类原型,f\phi (x)是嵌入函数,即将样本数据xi转换为嵌入向量,ck就是计算Sk这个集合的均值。

这一步是利用softmax函数计算x是k这个类的概率,softmax函数实际上就是将数据归一化到(0,1)这个区间,首先要保证是正数,所以采取了指数,然后再除以求和的总值。这里的参数是距离,即样本x与类原型ck的距离,这里用-d是因为离的近,指数值就越大,概率越大。

最后就是计算损失函数,采用 随机梯度下降更新参数\phi。作者给出的伪代码:

输入:训练集共N个样本,每个样本x对应一个类标签y,其中y表示的是K个类中的某个类。

输出:对一个Episode的损失值

首先第一步是用RandomSample函数选出一个episode——Vk。即从K个类中随机抽出Nc个类,其实这个Nc也就是N-way-K-shot中的N。

然后第二步就是选出支持集Sk和查询集Qk,支持集从集合Vk中选出Ns个样本,这Ns就是N-way-K-shot中的K。

第三步就是计算类原型,之前有介绍。不过这里论文原文估计是打错了,分母应该是Ns而不是Nc。 

最后计算损失值,先初始换,然后在查询集中做参数更新,也就是找到最好的一个度量空间,这样可以使得相同的类被映射的更靠近而不同的类尽可能的被映射的更远,这也是原型网络唯一需要关心的问题,那么通常Embedding可以选 CNN、RNN、BERT等等。最终损失函数如下:

 把公式简单的化简下就是上面那一长串了,只不过还是那个问题:分母应该是NsNq而不是NcNq。 

 后面还有一部分是作者解释为什么用欧氏距离更好而不是余弦距离,总的来说是个数学问题,我觉得没必要深究。

原型网络这篇论文的发表无疑对小样本学习又起到了推波助澜的作用,因为它简单的方法以及不错的性能后续作为了baseline模型。

以上就是我对原型网络大致的总结,大部分还没理解的透彻,所以后续我还是会从代码的角度剖析原型网络,以及还没有提到原型网络的缺点、后续更多人对它的改进等等的内容。

【------------------------------------------------------分割线------------------------------------------------------------】

        更新下后续对原型网络的改进:

        原型网络实际上就是一个KNN的方法:将输入通过嵌入函数映射到度量空间,计算出每个类的原型,然后利用欧式距离计算出预测点x与各个类原型的距离判断是哪个类。

        所以本质上原型网络能学习的就是这个嵌入函数。原作者才用的是CNN模型,所以改进可以采用强大的BERT模型。

        其次是如果不同类之间关系比较复杂,那么是否可以考虑计算类与类之间的关系?是否可以更换通过直接计算距离的方式进行预测?主要有三种改进:

         (从左至右),第一个是关系网络,它没有直接计算距离,而是将查询样本也作嵌入后得到的嵌入向量加到输入样本嵌入向量后面得到新嵌入向量后再经过一个网络层进行预测。



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3