LlaMa1原理介绍
一、简介
论文:LLaMA: Open and Efficient Foundation Language Models :https://arxiv.org/pdf/2302.13971
github:https://github.com/meta-llama/llama
发布时间: 2023 年 2 月 25
LLaMA 是一系列从 7 B到 65B 参数的基础语言模型。Meta 训练这些模型使用了数万亿个 token,并且证明了完全可以只使用公开可得的数据集来训练最先进的模型,而无需使用专有和不可获取的数据集。特别是,LLaMA-13B 在大多数基准测试中表现优于GPT-3(175B),而 LLaMA-65B 在竞争中与最佳模型 Chinchilla70B 和PaLM-540B 持平。
二、LLaMa 预训练
1.预训练数据
LLaMa 预训练数据大约包含 1.4T tokens,对于绝大部分的训练数据,在训练期间模型只见到过1次,Wikipedia 和 Books 这两个数据集见过2次。
表1所示是 LLaMa 预训练数据的含量和分布,其中包含了 CommonCrawl 和 Books 等不同域的数据。
English CommonCrawl [67%]:对五个 CommonCrawl 数据集进行预处理,时间跨度从2017年到2020年,使用 CCNet 流水线。该过程在行级别进行数据去重,使用 fastText 线性分类器进行语言识别,以删除非英语页面,并使用 n-gram 语言模型过滤低质量内容。此外,还训练了一个线性模型,用于将页面分类为 Wikipedia 中的引用页面与随机抽样页面,并丢弃未被分类为引用的页面。
C4 [15%]:C4的预处理还包括去重和语言识别步骤:与 CCNet 的主要区别在于质量过滤,这主要依赖于标点符号的存在或网页中的词语和句子数量等启发式方法。
Github [4.5%]:使用 Google BigQuery 上可用的公共 GitHub 数据集。只保留了在 Apache、BSD 和 MIT 许可下发布的项目。此外,使用基于行长度或字母数字字符比例的启发式方法过滤低质量文件,并使用正则表达式删除了诸如头文件之类的样板文件。最后,对生成的数据集进行了文件级别的去重,使用完全匹配的方法
Wikipedia [4.5%]:添加了截至2022年6月至8月的 Wikipedia 数据,涵盖20种语言。处理数据以去除超链接、评论和其他格式样板。
Gutenberg and Books3 [4.5%]:添加了两个书的数据集,分别是 Gutenberg 以及 ThePile (训练 LLM 的常用公开数据集) 中的 Book3 部分。处理数据时执行重复数据删除,删除内容重叠超过 90% 的书籍。
ArXiv [2.5%]:处理了arXiv Latex文件,以添加科学数据到数据集中。移除了第一节之前的所有内容,以及参考文献。还移除了.tex文件中的注释,并且内联展开了用户编写的定义和宏,以增加论文之间的一致性。
Stack Exchange [2%]。作者添加了 Stack Exchange,这是一个涵盖各种领域的高质量问题和答案网站,范围从计算机科学到化学。作者从 28 个最大的网站保留数据,从文本中删除 HTML 标签并按分数对答案进行排序
2、Tokenizer
使用字节对编码(BPE)算法对数据进行分词,使用 SentencePiece 的实现。值得注意的是,作者将所有数字分割成单个数字。
三、网络结构
在最近关于大型语言模型的研究中,LLaMa 的网络基于 Transformer 架构。作者利用了随后提出的各种改进,这些改进在不同模型(如PaLM)中得到了应用。以下是与原始架构的主要区别,以及从哪里得到了这种变化的灵感(括号中)。
1、Pre-normalization [受 GPT3 的启发]
均方根:Root Meam Square
为了提高训练稳定性,LLaMa 对每个 Transformer 子层的输入进行归一化,而不是对输出进行归一化。LLaMa 使用了 RMSNorm 归一化函数。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
方差定义:
概率论中方差用来度量随机变量和其数学期望(即均值)之间的偏离程度。
统计中的方差(样本方差)是每个样本值与全体样本值的平均数之差的平方值的平均数。
均方根:RMS:
是root mean square的缩写。RMS值实际就是有效值,就是一组统计数据的平方的平均值的平方根。
常规的 Layer Normalization:
2、SwiGLU 激活函数 [受 PaLM 的启发]:
LLaMa 使用 SwiGLU 激活函数替换 ReLU 以提高性能,维度从4d变为2/3 *4d。
SwiGLU 是2019年提出的新的激活函数,它结合了 SWISH 和 GLU 两种者的特点。SwiGLU 主要是为了提升Transformer 中的 FFN(feed-forward network) 层的实现。
Swish(x) = x*sigmoid(ßx)
σ是sigmoid函数,β betaβ是可学习的参数或者一个固定超参数。
3、RoPE Rotary Embeddings [受 GPTNeo 的启发]:
LLaMa 没有使用之前的绝对位置编码,而是使用了旋转位置编码(RoPE),可以提升模型的外推性。关于 RoPE 的具体细节,可以参考下面的链接:
4、AdamW 优化器
LLaMa 使用了 AdamW 优化器进行训练,超参数为:β1 = 0.9,β2 = 0.95。
使用 cosine 学习率衰减策略,2000 步的 warm-up,最终学习率等于最大学习率的 10%,使用 0.1 的权重衰减和 1.0 的梯度裁剪。
四、总结
总结一下 LLaMa 的技术要点:
1、模型结构:主体结构依然是 GPT,不同之处在于位置编码使用了旋转编码(RoPE),归一化使用了 RMSNorm,激活函数使用了 SwiGLU 。
2、训练优化:选择了小 LLM配大数据的思路,预训练使用了 1.4T 的 token,经过了充分的训练。