admin 发布的文章
GPT算法详解
一、背景
自从transformer出来之后,后面的算法基本上都是基于这个为基础,比如bert是以Encode层,GPT系列的GPT、GPT2、GPT3都是Decode层,下面我们主要讲解一下GPT。
1、论文
论文名字:《Improving Language Understanding by Generative Pre-Training》
论文地址:Improving Language Understanding by Generative Pre-Training
2、论文发表时间
时间:2018年6月
团队:openAI、特斯拉老板马斯克的公司
二、架构
1、架构图
:注意:
GPT 使用 Transformer 的 Decoder 结构,并对 Transformer Decoder 进行了一些改动,原本的 Decoder 包含了两个 Multi-Head Attention 结构,GPT 只保留了 Mask Multi-Head Attention
2、网络结构参数
网络层数:12层
参数个数:
参数文件大小:
训练数据大小:
最大上下文字数:512
三、预训练
预训练阶段就是用海量的文本数据通过无监督学习的方式来获取语言学知识。
目标函数:
$L_\theta(x) = \sum_{i=1}^{m}logP(w_i|w_1,...,w_{i-1})$
四、微调
1、Textual Entailment
这个问题的目标是判断两个句子是包含关系,还是矛盾关系或者中立关系。
对于这个问题,输入是两个句子,premise p和hypothesis h,如上图所示,p和h被拼接了起来,然后输入给transformer。
2、Similarity
这个问题是判断两个句子是否相似。
对于这个问题,输入的两个句子,正向和反向各拼接一次,然后分别输入给transformer,得到的输出拼接在输入给下一层。
3、Question Answer和Commonsense reasoning
对于这两个问题,输入都是一个document d,一个问题q,还有若干个答案answer。
对于这个问题,输入的d和q先拼接起来,然后这个拼接和每个answer都拼接起来,输入给transformer,得到的结果分别输入给一个全连接层,再得到的结果去进行softmax形成概率分布。
Tensorflow 命令行参数定义tf.flags和tf.app.flags
一、版本
[tensorflow1.10 到 tensorflow1.13] 中都有
tf.flags和tf.app.flags
[tensorflow1.14 到 tensorflow1.15] 以及 tensorflow2.0不在有了
二、作用
Tensorflow 采用tf.app.flags 来进行命令行参数传递.
如 - flags_test.py
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
# Settings for some training parameters.
flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'],
'Learning rate policy for training.')
flags.DEFINE_float('base_learning_rate', .0001,
'The base learning rate for model training.')
flags.DEFINE_integer('learning_rate_decay_step', 2000,
'Decay the base learning rate at a fixed step.')
flags.DEFINE_integer('train_batch_size', 12,
'The number of images in each batch during training.')
flags.DEFINE_multi_integer('train_crop_size', [513, 513],
'Image crop size [height, width] during training.')
flags.DEFINE_boolean('upsample_logits', True,
'Upsample logits during training.')
flags.DEFINE_string('dataset', 'dataset_name',
'Name of the test dataset.')
def main(_):
print(FLAGS.learning_policy)
print(FLAGS.base_learning_rate)
print(FLAGS.learning_rate_decay_step)
print(FLAGS.train_batch_size)
print(FLAGS.train_crop_size)
print(FLAGS.upsample_logits)
print(FLAGS.dataset)
if __name__ == '__main__':
tf.app.run()