一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba |
您所在的位置:网站首页 › yoon中文是什么意思 › 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba |
前言
不知读者发现没有,本文标题的信息含量很大,比如 出来了一个新的序列模型:Mamba,其基于SSM或S4发展为S6(S4 models with a selection mechanism and computed with a scan),其对应的论文为《Mamba: Linear-Time Sequence Modeling with Selective State Spaces》该Mamba模型的提出者为Albert Gu、Tri Dao,前者现在是CMU助理教授,多年来一直推动SSM架构发展,曾在DeepMind 工作,后者则为鼎鼎大名的Flash Attention一作 换言之,除了论文中展示的效果确实不错之外,由于提出者的背景不一般,所以关注的人比较多Transformer统治各大领域近7年了,7年来,挑战Transformer的模型其实不少 (比如linear attention, gated convolution and recurrent models, and SSMs),该模型能否真正颠覆Transformer的霸权呢?对此,我们可以细究其原理细节,看看其创新到底是否靠谱、力度是否大加之有一大模型项目开发营的朋友问道,可否在论文100课上解读下Mamba这篇论文,于此,便有了此文,且具备3个特点 清晰易懂:也为「不需要天天看paper的朋友」而写 在ChatGPT诞生后的一年来,以大模型为代表的技术发展特别快,经常一个月会出来很多新的技术、模型 而不一定非得是每天在实验室扎根于科研的人 才有资格去追踪前沿技术发展,还有一大帮可能是出于对前沿技术的了解、兴趣、热爱、应用而想追踪,可这帮朋友平时或因工作或事太多而不一定对每个新技术、新模型都去看一遍论文,即不可能天天看paper 那咋办呢?他们可能通过一些比如公众号之类的文章去了解,但有的公号文章写的不错,有的则写的不够清晰易懂甚至漏洞百出,会因此让读到这种文章的朋友对新技术、新模型产生畏难心理甚至被误导 故,我和我司来了,为帮助更多朋友更好、更快、更细致的了解大模型相关技术及其实践,我个人算是笔耕不辍(我自23年年初以来也史无前例的写了近30篇,详见:大模型与ChatGPT系列:原理、论文、代码、应用)、团队和我算讲课不停中英对比:部分关键的阐述中英文对照学习 考虑到这些新技术、新模型刚推出的时候,论文还是相对最严谨的参考,所以本文会延续前几篇文章的风格:对于一些关键的阐述会把原英文的表述用斜体且淡色的黑体表示,毕竟有的描述对其翻译相比,用原英文阐述更精准足够细致:从SSM、HiPPO、S4起步,逐步推导到Mamba 目前介绍mamba模型的文章,少部分写得很不错,大部分不是这个细节没深入,便是那个细节没深入,考虑到如果很多关键细节没有介绍的话,那没法彻底理解mamba模型 因此,本文会尽可能兼顾所有必须写清楚的细节(比如如果不理解SSM和S4则无法理解mamba模型,故本文会从HiPPO、SSM、S4起步,逐步推导到mamba),尽可能一文通透mamba模型两轮重大更新:考虑到之前本文的早期版本介绍的mamba前置知识不够彻底的清晰易懂,为让文科生都能一眼看明白,故 24年3.2-3.11,这10来天做了第一轮大规模修订 比如把前置知识特别是ssm/S4介绍的更加细致(过程中的核心参考之一是:A Visual Guide to Mamba and State Space Models,有些图来自该文,有些内容翻译自此文)24年3.23-3.25,这几天做了第二轮大规模修订 比如把此节「3.1.1 选择性状态空间模型:从S4到S6」的内容做了全面细致的补充 特别是把A B C三个矩阵分别在S4、mamba中各自所对应的背后含义、维度表示、维度变化一针见血的解释清楚总之,看本文之前,你可能看到的很多关于mamba的文章都不知所云,但看了本文之后,你再看那些文章你会有一种“他如果怎样怎样写,会更加清晰易懂”的感觉,毕竟“好懂的文章”只有一个标准:就是能一直不烧脑的读下去而不卡壳 第一部分 基础梳理:Transformer时间复杂度、RNN、SSM 1.1 Transformer的二次复杂度通过之前本博客内的另一篇文章《通透理解FlashAttention与FlashAttention2:全面降低显存读写、加快计算速度》,可知 简单理解的话,计算复杂度和序列长度的平方 因为我们需要拿第一个矩阵的每一行去与第二个矩阵的每一列做点乘,所以总共就需要 精确理解的话,当输入批次大小为 但这个结果是怎么一步一步计算得到的呢?请看原文 正因为现有的ChatGPT等大模型处理长文本算力消耗巨大,背后原因是Transformer架构中注意力机制的二次复杂度 一方面,有了针对注意力机制的各种所谓魔改,甚至也有S4、FlashAttention及其二代等二方面,S4、FlashAttention等作者提出了新的序列模型:Mamba,在很多语言任务上击败/匹配Transformer性能,具有线性复杂度和5倍推理吞吐量,下文详述 1.2 RNN关于什么是RNN,我之前博客内的这篇文章《如何从RNN起步,一步一步通俗理解LSTM》中做了详细介绍,每一个时刻的隐藏状态 总之,RNN在序列中的每个时间步需要两个输入,即时间步 这一点值得好好体会:先根据输入 至于为何要先介绍RNN呢,很快你就会明白了(RNN和SSM是一个本质) ![]() RNN主要存在两个问题 第一个问题在于,虽然每个隐藏状态都是所有先前隐藏状态的聚合,然随着时间的推移,RNN 往往会忘记某一部分信息,比如下图最后一个隐藏状态在生成名称“ Maarten”时不再包含有关单词“Hello”的信息(说白了,就是如此文所说的,在实践中,如本文开头所说,mamba论文的一作Albert Gu多年来一直在推动SSM的发展 他在SSM的基础上,通过此篇论文《Efficiently Modeling Long Sequences with Structured State Spaces》首次提出了结构化状态空间S4(这里有关于S4的更多论文),但这篇论文的可读性比较差当然,作者在YouTube上有一个关于这篇S4论文的精彩解读,比S4论文的可读性提高很多,且本文中也应用了其中的部分PPT截图,但还可以更加通俗易懂好在如本文开头所述,Maarten Grootendorst写了一篇《A Visual Guide to Mamba and State Space Models》,很通俗,包括本部分中的不少图来自该文,不少内容翻译自此文,至于原英文中有些表述不准确的地方,我则都已修正 1.3.1 什么是状态空间想象一下我们正在穿过一个迷宫,图中每个小框代表迷宫中的一个位置,并附有某个隐式的信息,例如你距离出口有多远 而上述迷宫可以简化建模为一个“状态空间表示state space representation”,每一个小框显示 你当前所在的位置(当前状态current state)下一步可以去哪里(未来可能的状态possible future states)以及哪些变化会将你带到下一个状态(向右或向左)而描述状态的变量(在我们的示例中为 X 和 Y 坐标以及到出口的距离)可以表示为“状态向量state vectors” SSM 是用于描述这些状态表示并根据某些输入预测其下一个状态可能是什么的模型 一般SSMs包括以下组成 映射输入序列x(t),比如在迷宫中向左和向下移动到潜在状态表示h(t),比如距离出口距离和 x/y 坐标并导出预测输出序列y(t),比如再次向左移动以更快到达出口然而,它不使用离散序列(如向左移动一次),而是将连续序列作为输入并预测输出序列 然后,请你再细品一下 上面的第一个方程是不和RNN循环结构:总之,通过求解这些方程,可以根据观察到的数据:输入序列和先前状态,去预测系统的未来状态 1.3.3 SSM的两个方程:状态方程与输出方程 总之,SSM的关键是找到:状态表示(state representation)——最终,我们可以通过下图统一这两个方程 回到我们的简化视角,现在可以关注只矩阵 由于除了连续的输入之外,还会通常碰到离散的输入(如文本序列),不过,就算SSM在离散数据上训练,它仍能学习到底层蕴含的连续信息,因为在SSM眼里,sequence不过是连续信号signal的采样,或者说连续的信号模型是离散的序列模型的概括 那模型如何处理离散化数据呢?答案是可以利用零阶保持技术(Zero-order hold technique) 这些采样值就是我们的离散输出,且可以针对A、B按如下方式做零阶保持(做了零阶保持的在对应变量上面加了个横杠) 最终使我们能够从连续 SSM 转变为离散SSM,使得不再是函数到函数x(t) → y(t),而是序列到序列xₖ → yₖ,所以你看到,矩阵 注意:我们在保存时,仍然保存矩阵 总之,离散 SSM 允许可以用离散时间步长重新表述问题 在每个时间步,都会涉及到隐藏状态的更新(比如 为方便大家理解其中的细节,我再展开一下 有没有眼前一亮?如此,便可以RNN的结构来处理 然后可以这样展开(其中, 在经典的图像识别任务中,我们用过滤器(即卷积核kernels)来导出聚合特征,而SSM也可以表示成卷积的形式 由于我们处理的是文本而不是图像,因此我们需要一维视角 而用来表示这个“过滤器”的内核源自 SSM 公式 但怎么理解这个公式呢?一般的文章可能一带而过,但本文咱们还是通过一个例子一步一步理解 与卷积一样,我们可以使用 SSM 内核来检查每组token并计算输出至此,总结一下,将 SSM 表示为卷积的一个主要好处是它可以像卷积神经网络CNN一样进行并行训练。然而,由于内核大小固定,它们的推理不如 RNN 那样快速 那有没两全其美的办法呢?最终是有的 作为从输入信号到输出信号的参数化映射,SSMs可以当做是RNN与CNN的结合「These models can be interpreted as acombination of recurrent neural networks (RNNs) and convolutional neural networks (CNNs)」,即推理用RNN结构,训练用CNN结构如我们之前在循环表示中看到的那样,矩阵 其实,某种意义上,算是矩阵A产生了隐藏状态(matrix A produces the hidden state) 由于矩阵A只记住之前的几个token和捕获迄今为止看到的每个token之间的区别,特别是在循环表示的上下文中,因为它只回顾以前的状态 那么我们怎样才能以保留比较长的memory的方式创建矩阵A呢? 答案是可以使用Hippo(Hippo的全称是High-order Polynomial Projection Operator,其对应的论文为:HiPPO: Recurrent Memory with Optimal Polynomial Projections),解决如何在有限的存储空间中有效地解决序列建模的长距离依赖问题HiPPO尝试将当前看到的所有输入信号压缩为系数向量(HiPPO attempts to compress all input signals it has seen thus far into a vector of coefficients)它使用矩阵 具体表示可以如下图所示 正由于HiPPO 矩阵可以产生一个隐藏状态来记住其历史(从数学上讲,它是通过跟踪Legendre polynomial的系数来实现的,这使得它能够逼近所有以前的历史),使得在被应用于循环表示和卷积表示中时,可以处理远程依赖性 如此,S4的定义就出来了:序列的结构化状态空间——Structured State Space for Sequences,一类可以有效处理长序列的 SSM(S4所对应的论文为:Efficiently Modeling Long Sequences with Structured State Spaces) 且对矩阵A 做了改进 注,本部分只作为选读,因为本部分要介绍的重点 上文已经介绍过了,但为何还是要增加这个选读部分呢,一者 本部分来自mamba论文的一作Albert Gu的解读,虽然其公式表达不如上文第二部分的表达顺眼(比如状态被他改写成x,输入被他改写成u),但有些论文的表达还是用的Albert Gu的这个表述,故权衡利弊,还是增加本部分 2.2.1 改进transformer不擅长处理超长的序列的问题:输入u到状态x序列数据一般都是离散的数据 比如文本、图、DNA 但现实生活中还有很多连续的数据,比如音频、视频,对于音视频这种信号而言,其一个重要特点就是有极长的context window而在transformer长context上往往会失败,或者注意力机制在有着超长上下文长度的任务上并不擅长(所以你才看到各种对注意力机制的改进,比如flashattention等等,即便如此一般也就32K的上下文长度,在面对100w的序列长度则无能为力),而S4擅长这类任务为了方便大家更好的理解,Albert Gu举了一个金融领域的例子 即根据输入,计算其EMA(如下图所示,黑色的一直在跳跃着的曲线是输入x,输出y是蓝色的线)July注:总之,相比用x 表示对应的summary,其实如果用h表示对应的summary,会更清晰,如此,也和上文的第二部分的表达统一起来了 2.2.2 HiPPO的定义与推导:state compresses the history of input我们已经知道 RNN 被诟病的一个点恰恰是 hidden state 的记忆能力有限(毕竟hidden state 的大小是固定的, 但是需要记忆的内容是随着 sequence length 增加的,用一个有限的容器去装源源不断的水流, 自然要有溢出) 那怎么改善这个问题呢?或者怎么定义一个好的 hidden state 的记忆 假设 发现HiPPO在低阶信号上work后,我们希望将它扩展到高阶信号上。阶数越高,与LLM越相似,工作的价值就越大 但是我们不能直接堆叠HiPPO算子,因为不断增加维度会引起维数爆炸:![]() 我们正式定义下S4 首先,有一个state space model,简称为SSM其次,在下图所示的两个方程中插入特定的矩阵值下图所示的便是S4的三个性质 最终,状态空间模型(SSM)将这些表示作为深度学习管道中的一层(A state-space model (SSM) uses these representations as a layer in a deep learning pipeline),并且矩阵 第二个性质是有效的online计算,这点之前在HiPPO提到了,就是计算下一时刻的state
SSM的一个问题是,当知道未来的signal的时候,训练是低效的。有没有办法并行化SSM?作者提出了使用一个卷积核 输入 问题好像解决了,但SSM还是存在两个问题 一个是计算复杂度的问题,最终通过给SSM做结构化(比如使用HiPPO矩阵,相当于变成了S4),即structured state space can be computed faster首先,Linear Time Invariance(LTI)规定 SSM中的A、B、C始终是固定不变的。这意味着 对于 SSM 生成的每个token,矩阵A 、B、C都是相同的(regardless of what sequence you give the SSM, the values of A,B,and C remain the same. We have a static representation that is not content-aware)使得SSM无法针对输入做针对性的推理「since it treats each token equally as a result of the fixed A, B, and C matrices. This is a problem as we want the SSM to reason about the input (prompt)」此外,如下图所示,无论输入x 是什么,矩阵 B都保持完全相同,因此与x无关 同样,无论输入如何,A和C也保持固定 比如 “I want to order a hamburger.”这句 如果没有选择性,S4会花费相同的“精力”来处理每个单词:凡事也有利有弊,虽然mamba可以“专注于”输入中对于当前任务更重要的部分,但坏处是没法再通过CNN做并行训练了,原因在于: 让我们回想一下之前计算的卷积核![]() 说白了,如果我们想要选择性,得用RNN模式进行训练(If we want selectivity, we’ll need to train with RNN mode),然偏偏RNN的训练速度非常慢,emmm,所以我们需要找到一种无需卷积的并行训练方式(详见下文的3.1.2节) 第三部分 Mamba的三大创新mamba(其对应论文为:Mamba: Linear-Time Sequence Modeling with Selective State Spaces,这是其对应的GitHub代码地址),在语言、音频、DNA序列模态上都实现SOTA,在最受关注的语言任务上,Mamba-3B超越同等规模的Transformer,与两倍大的Transformer匹敌,并且相关代码、预训练模型checkpoint都已开源 简言之,Mamba是一种状态空间模型(SSM),建立在更现代的适用于深度学习的结构化SSM (简称S6)基础上,与经典架构RNN有相似之处 3.1 Mamba = 有选择处理信息 + 硬件感知算法 + 更简单的SSM架构与先前的研究相比,Mamba主要有三点创新: 对输入信息有选择性处理(Selection Mechanism)硬件感知的算法(Hardware-aware Algorithm) 该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算(使得不用CNN也能并行训练),但为了减少GPU内存层次结构中不同级别之间的IO访问,它没有具体化扩展状态 当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发更简单的架构 将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计 3.1.1 选择性状态空间模型:从S4到S6作者认为,序列建模的一个基础问题是把上下文压缩成更小的状态(We argue that a fundamental problem of sequence modeling is compressing context into a smaller state),从这个角度来看 transformer的注意力机制虽然有效果但效率不算很高,毕竟其需要显式地存储整个上下文(storing the entire context,也就是KV缓存),直接导致训练和推理消耗算力大 好比,Transformer就像人类每写一个字之前,都把前面的所有字+输入都复习一遍,所以写的慢RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制On the other hand, recurrent models are efficient because they have a finite state, implying constant-time inference and linear-time training. However, their effectiveness is limited by how well this state has compressed the context. 好比,RNN每次只参考前面固定的字数(仔细体会这句话:When generating the output, the RNN only needs to consider the previous hidden state and current input. It prevents recalculating all previous hidden states which is what a Transformer would do),写的快是快,但容易忘掉更前面的内容而SSM的问题在于其中的矩阵A B C始终是不变的,无法针对不同的输入针对性的推理,详见上文的2.4节总之,序列模型的效率与效果的权衡点在于它们对状态的压缩程度: 高效的模型必须有一个小的状态(比如RNN或S4)而有效的模型必须有一个包含来自上下文的所有必要信息的状态(比如transformer)而mamba为了兼顾效率和效果,选择性的关注必须关注的、过滤掉可以忽略的 为方便大家理解,再进一步阐述mamba与其前身结构化空间模型S4的优势 3.1.1.1 mamba前身S4的4个参数的固定不变性首先,在其前身S4中,其有4个参数(∆, A, B, C) 且它们都是固定的,不随输入变化(即与输入无关),这些参数控制了以下两个阶段 其次,再回顾一下,通过之前的讲解,可知 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |