fix
This commit is contained in:
parent 7e7c2c30ae
commit 65f7a1f92a
@@ -2,7 +2,6 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preface\n",
"This article contains a large amount of source code with detailed commentary. The modules are separated into paragraphs and divided by horizontal rules, and the site provides a sidebar so you can jump quickly between sections; we hope that after reading it you will have a thorough understanding of BERT. We also recommend single-stepping through the BERT source code in a debugger such as PyCharm or VS Code and comparing each module against the explanation in this chapter.\n",
@@ -27,22 +26,21 @@
" - BertIntermediate\n",
" - BertOutput\n",
" - BertPooler"
]
],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"## 1-Tokenization-BertTokenizer\n",
"The Tokenizer used by BERT is mainly implemented in [`models/bert/tokenization_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py)."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import os\n",
@@ -459,11 +457,12 @@
" else:\n",
" output_tokens.extend(sub_tokens)\n",
" return output_tokens"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"class BertTokenizer(PreTrainedTokenizer):\n",
@@ -493,16 +492,21 @@
"- encode: for a single-sentence input, tokenize it, add the special tokens to form the structure “[CLS], x, [SEP]”, and convert it into a list of vocabulary indices; for a two-sentence input (only the first two sentences are used if more are given), tokenize, add the special tokens to form “[CLS], x1, [SEP], x2, [SEP]”, and convert it into a list of indices;\n",
"- decode: turns the output of encode back into a full sentence.\n",
"And the class's own methods (a brief encode/decode sketch is shown below):"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"source": [
"bt = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"bt('I like natural language progressing!')\n",
"# {'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading: 100%|██████████| 232k/232k [00:00<00:00, 698kB/s]\n",
"Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 11.1kB/s]\n",
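To make the encode/decode behaviour above concrete, here is a small usage sketch. It is my own addition rather than a cell from the original notebook, and it assumes the same `bert-base-uncased` checkpoint used in the surrounding cells; the single-sentence ids echo the notebook's own output, the sentence-pair strings are illustrative.

```python
from transformers import BertTokenizer

bt = BertTokenizer.from_pretrained('bert-base-uncased')

# Single sentence: "[CLS] x [SEP]" converted to vocabulary indices.
ids = bt.encode('I like natural language progressing!')
print(ids)        # [101, 1045, 2066, 3019, 2653, 27673, 999, 102]

# decode turns the indices back into a readable (lower-cased) sentence.
print(bt.decode(ids))       # [CLS] i like natural language progressing! [SEP]

# Sentence pair: "[CLS] x1 [SEP] x2 [SEP]".
pair_ids = bt.encode('How are you?', 'I am fine.')
print(bt.decode(pair_ids))  # [CLS] how are you? [SEP] i am fine. [SEP]
```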
@@ -510,25 +514,20 @@
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
"execution_count": 4
}
],
"source": [
"bt = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"bt('I like natural language progressing!')\n",
"# {'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}"
]
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"## 2-Model-BertModel\n",
@@ -639,13 +638,12 @@
"- _prune_heads: provides a function for pruning attention heads. Its input is a dict of {layer_num: list of heads to prune in this layer}, and it prunes the specified heads in the specified layers.\n",
"\n",
"**Pruning is a delicate operation: the weights of the retained heads in Wq/Wk/Wv, and of the dense projection applied after concatenation, must be copied into new, smaller weight matrices (disable grad before copying), and the pruned heads must be recorded so that later indices stay correct. See the prune_heads method in the BertAttention section for details.**"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from transformers.models.bert.modeling_bert import *\n",
"class BertModel(BertPreTrainedModel):\n",
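As a small illustration of the pruning interface described above (my own sketch, not part of the original notebook), the public `prune_heads` method of a loaded model takes exactly the {layer_num: [head indices]} dict mentioned in the text; the printed shapes assume bert-base sizes (12 heads of 64 dims).

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# Prune heads 0 and 11 of layer 0, and head 2 of layer 5.
model.prune_heads({0: [0, 11], 5: [2]})

# The affected BertSelfAttention modules now carry fewer heads, and their
# query/key/value projections have been shrunk accordingly.
print(model.encoder.layer[0].attention.self.num_attention_heads)  # 10
print(model.encoder.layer[0].attention.self.query.weight.shape)   # torch.Size([640, 768])
```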
@@ -819,11 +817,12 @@
" attentions=encoder_outputs.attentions,\n",
" cross_attentions=encoder_outputs.cross_attentions,\n",
" )"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"***\n",
"### 2.1-BertEmbeddings\n",
@@ -837,13 +836,12 @@
"The three embeddings are summed without weighting and then passed through a LayerNorm + dropout layer; the output has shape (batch_size, sequence_length, hidden_size).\n",
"\n",
"**Why LayerNorm + Dropout here? And why LayerNorm rather than BatchNorm? A good answer for reference: [Why does the Transformer use layer normalization instead of other normalization methods?](https://www.zhihu.com/question/395811291/answer/1260290120)**"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class BertEmbeddings(nn.Module):\n",
" \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n",
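A minimal sketch of the "three embeddings summed, then LayerNorm + dropout" step described above. This is my own illustration with assumed bert-base sizes (30522 vocab, 512 positions, 768 hidden), not the library code itself.

```python
import torch
import torch.nn as nn

vocab_size, max_pos, type_vocab, hidden = 30522, 512, 2, 768
word_emb = nn.Embedding(vocab_size, hidden)
pos_emb = nn.Embedding(max_pos, hidden)
type_emb = nn.Embedding(type_vocab, hidden)
layer_norm = nn.LayerNorm(hidden, eps=1e-12)
dropout = nn.Dropout(0.1)

input_ids = torch.randint(0, vocab_size, (2, 128))            # (batch_size, seq_len)
position_ids = torch.arange(128).unsqueeze(0).expand(2, -1)   # 0..127 for every sample
token_type_ids = torch.zeros(2, 128, dtype=torch.long)        # single-sentence input

# Unweighted sum of the three lookups, then LayerNorm + dropout.
embeddings = word_emb(input_ids) + pos_emb(position_ids) + type_emb(token_type_ids)
embeddings = dropout(layer_norm(embeddings))
print(embeddings.shape)  # torch.Size([2, 128, 768])
```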
@@ -903,11 +901,12 @@
" embeddings = self.LayerNorm(embeddings)\n",
" embeddings = self.dropout(embeddings)\n",
" return embeddings"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"### 2.2-BertEncoder\n",
@@ -920,13 +919,12 @@
"In BertEncoder, gradient checkpointing is implemented via torch.utils.checkpoint.checkpoint, which is quite convenient to use; see the documentation: torch.utils.checkpoint - PyTorch 1.8.1 documentation. The internal implementation of this mechanism is fairly involved, so it is not covered here.\n",
"\n",
"Going one level deeper, we arrive at a single Encoder layer:"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"class BertEncoder(nn.Module):\n",
" def __init__(self, config):\n",
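A small standalone sketch (my own addition, not from the notebook) of the torch.utils.checkpoint mechanism mentioned above: the checkpointed block does not cache its intermediate activations during the forward pass and recomputes them during backward, trading compute for memory.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)
x = torch.randn(2, 128, 768, requires_grad=True)

y = checkpoint(block, x)   # forward pass without storing intermediates
y.sum().backward()         # intermediates are recomputed here
print(x.grad.shape)        # torch.Size([2, 128, 768])
```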
@@ -1023,11 +1021,12 @@
" attentions=all_self_attentions,\n",
" cross_attentions=all_cross_attentions,\n",
" )"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"#### 2.2.1.1 BertAttention\n",
@@ -1067,13 +1066,12 @@
"\n",
"- `prune_linear_layer` keeps, according to the given index, the un-pruned dimensions of the Wq/Wk/Wv weight matrices (together with their bias) and copies them into a new, smaller matrix.\n",
"Next comes the main event: the concrete implementation of Self-Attention."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class BertAttention(nn.Module):\n",
" def __init__(self, config):\n",
@@ -1122,11 +1120,12 @@
" attention_output = self.output(self_outputs[0], hidden_states)\n",
" outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them\n",
" return outputs"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"##### 2.2.1.1.1 BertSelfAttention\n",
@@ -1344,13 +1343,12 @@
"- head_mask is the mask on the multi-head computation mentioned earlier; if it is not set it defaults to all ones and has no effect here;\n",
"- context_layer is the product of the attention matrix and the value matrix; its raw shape is (batch_size, num_attention_heads, sequence_length, attention_head_size);\n",
"- after the transpose and view operations, context_layer recovers the shape (batch_size, sequence_length, hidden_size).\n"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"class BertSelfAttention(nn.Module):\n",
" def __init__(self, config):\n",
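To make the shape bookkeeping above concrete, here is a small sketch (my own, with assumed bert-base sizes of 12 heads and 64 dims per head) of how the per-head context is merged back into hidden_size:

```python
import torch

batch_size, num_heads, seq_len, head_size = 2, 12, 128, 64
# (batch_size, num_attention_heads, sequence_length, attention_head_size)
context_layer = torch.randn(batch_size, num_heads, seq_len, head_size)

# Swap the head and sequence axes, then flatten the heads into hidden_size.
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(batch_size, seq_len, num_heads * head_size)
print(context_layer.shape)  # torch.Size([2, 128, 768])
```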
@@ -1475,11 +1473,12 @@
" if self.is_decoder:\n",
" outputs = outputs + (past_key_value,)\n",
" return outputs"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"##### 2.2.1.1.2 BertSelfOutput\n",
@@ -1499,13 +1498,12 @@
"```\n",
"\n",
"**Here the combination of LayerNorm and Dropout shows up again, except that this time Dropout comes first, then the residual connection, and LayerNorm last. As for why a residual connection is used at all: its most direct purpose is to reduce the training difficulty that comes with very deep networks, keeping each layer more sensitive to its original input.**"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class BertSelfOutput(nn.Module):\n",
@@ -1520,11 +1518,12 @@
" hidden_states = self.dropout(hidden_states)\n",
" hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
" return hidden_states"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"#### 2.2.1.2 BertIntermediate\n",
@@ -1548,13 +1547,12 @@
"\n",
"- The dense layer here expands the dimension: for bert-base, the expanded dimension is 3072, four times the original 768;\n",
"- The activation function here defaults to gelu (Gaussian Error Linear Units, GELU). It cannot be evaluated directly in closed form, but it can be approximated with a tanh-based expression (a short sketch of this approximation is given below)."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class BertIntermediate(nn.Module):\n",
" def __init__(self, config):\n",
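Here is the tanh-based approximation of GELU referred to above, as a short sketch of my own, compared against PyTorch's built-in gelu:

```python
import math
import torch

def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.linspace(-3, 3, steps=7)
print(gelu_tanh_approx(x))
print(torch.nn.functional.gelu(x))  # erf-based GELU for comparison
```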
@@ -1569,11 +1567,12 @@
" hidden_states = self.dense(hidden_states)\n",
" hidden_states = self.intermediate_act_fn(hidden_states)\n",
" return hidden_states"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"#### 2.2.1.3 BertOutput\n",
@@ -1596,13 +1595,12 @@
"\n",
"The operations here are, shall we say, not merely related to BertSelfOutput: they are exactly the same. These two components are very easy to confuse.\n",
"The remaining content covers application models built on top of BERT, as well as BERT-related optimizers and usage, which will be introduced in detail in the next article."
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class BertOutput(nn.Module):\n",
" def __init__(self, config):\n",
@@ -1616,11 +1614,12 @@
" hidden_states = self.dropout(hidden_states)\n",
" hidden_states = self.LayerNorm(hidden_states + input_tensor)\n",
" return hidden_states"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*** \n",
"### 2.2.3 BertPooler\n",
@@ -1641,22 +1640,12 @@
" pooled_output = self.activation(pooled_output)\n",
" return pooled_output\n",
"```"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input to bert pooler size: 768\n",
"torch.Size([1, 768])\n"
]
}
],
"source": [
"class BertPooler(nn.Module):\n",
" def __init__(self, config):\n",
@@ -1682,18 +1671,28 @@
"x = torch.rand(batch_size, seq_len, hidden_size)\n",
"y = bert_pooler(x)\n",
"print(y.size())"
]
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"input to bert pooler size: 768\n",
"torch.Size([1, 768])\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [],
"outputs": [],
"source": []
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Brief Summary\n",
"This section analyzed the implementation of the BERT model; we hope it has given readers a more detailed grasp of how BERT is implemented.\n",
@@ -1704,12 +1703,12 @@
"- BertModel contains complex wrappers and quite a few components. Taking bert-base as an example, the main counts are:\n",
" - Dropout appears 1+(1+1+1)x12=37 times in total;\n",
" - LayerNorm appears 1+(1+1)x12=25 times in total;\n",
" - dense (fully connected) layers appear (1+1+1)x12+1=37 times in total, and not every dense layer is paired with an activation function;\n",
"BertModel has a very large number of parameters; for bert-base, about 109M (a quick way to verify this is sketched below).\n",
"\n",
"## Acknowledgements\n",
"This article was written mainly by Li Luoqiu (李泺秋) of Zhejiang University, and organized and compiled by members of this project."
]
],
"metadata": {}
}
],
"metadata": {
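A quick way to check the ~109M figure quoted in the summary, as my own sketch rather than a cell from the notebook:

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # roughly 109M for bert-base
```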
@@ -1,11 +1,9 @@
## Preface
This article contains a large amount of source code with detailed commentary. The modules are separated into paragraphs and divided by horizontal rules, and the site provides a sidebar so you can jump quickly between sections; we hope that after reading it you will have a thorough understanding of BERT. We also recommend single-stepping through the BERT source code in a debugger such as PyCharm or VS Code and comparing each module against the explanation in this chapter.

The Jupyter notebooks involved can be downloaded from the code repository: Chapter 3 - Writing a Transformer model: [BERT download](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A03-%E7%BC%96%E5%86%99%E4%B8%80%E4%B8%AATransformer%E6%A8%A1%E5%9E%8B%EF%BC%9ABERT)
The Jupyter notebooks involved can be downloaded from the [code repository: Chapter 3 - Writing a Transformer model: BERT](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A03-%E7%BC%96%E5%86%99%E4%B8%80%E4%B8%AATransformer%E6%A8%A1%E5%9E%8B%EF%BC%9ABERT)

This chapter is based on [HuggingFace/Transformers, 48.9k Star](https://github.com/huggingface/transformers). All of the code for this chapter is in [huggingface bert](https://github.com/huggingface/transformers/tree/master/src/transformers/models/bert); note that because the library is updated frequently there may be differences, so please refer to version 4.4.2.

HuggingFace is a chatbot startup headquartered in New York that caught the BERT wave early and set out to implement a PyTorch-based BERT model. The project, originally named pytorch-pretrained-bert, reproduced the original results while also providing easy-to-use methods for all kinds of experimentation and research on top of this powerful model.
This chapter is based on H[HuggingFace/Transformers, 48.9k Star](https://github.com/huggingface/transformers). All of the code for this chapter is in [huggingface bert; note that because the library is updated frequently there may be differences, so please refer to version 4.4.2](https://github.com/huggingface/transformers/tree/master/src/transformers/models/bert)HuggingFace is a chatbot startup headquartered in New York that caught the BERT wave early and set out to implement a PyTorch-based BERT model. The project, originally named pytorch-pretrained-bert, reproduced the original results while also providing easy-to-use methods for all kinds of experimentation and research on top of this powerful model.

As its user base grew, the project developed into a sizable open-source community, merged in a variety of pretrained language models, added a TensorFlow implementation, and was renamed Transformers in the second half of 2019. As of this writing (March 30, 2021) the project has 43k+ stars; it is fair to say that Transformers has become the de facto standard NLP toolkit.
@@ -603,7 +601,7 @@ def forward(
- set_input_embeddings: assigns a value to word_embeddings inside the embedding;
- _prune_heads: provides a function for pruning attention heads. Its input is a dict of {layer_num: list of heads to prune in this layer}, and it prunes the specified heads in the specified layers.

**Pruning is a delicate operation: the weights of the retained heads in Wq/Wk/Wv, and of the dense projection applied after concatenation, must be copied into new, smaller weight matrices (disable grad before copying), and the pruned heads must be recorded so that later indices stay correct. See the prune_heads method in the BertAttention section for details.**
** Pruning is a delicate operation: the weights of the retained heads in Wq/Wk/Wv, and of the dense projection applied after concatenation, must be copied into new, smaller weight matrices (disable grad before copying), and the pruned heads must be recorded so that later indices stay correct. See the prune_heads method in the BertAttention section for details.**


```python
@@ -788,11 +786,11 @@ class BertModel(BertPreTrainedModel):

1. word_embeddings, the embedding corresponding to each subword discussed above.
2. token_type_embeddings, which indicate which sentence the current token belongs to, helping to distinguish sentences from padding and to tell the two sentences of a pair apart.
3. position_embeddings, the position embedding of each token in the sentence, used to encode word order. Unlike the design in the Transformer paper, these are learned rather than fixed embeddings computed from sinusoidal functions; this is generally considered to hurt extensibility (it is hard to transfer them directly to longer sentences).
3。 position_embeddings, the position embedding of each token in the sentence, used to encode word order. Unlike the design in the Transformer paper, these are learned rather than fixed embeddings computed from sinusoidal functions; this is generally considered to hurt extensibility (it is hard to transfer them directly to longer sentences).

The three embeddings are summed without weighting and then passed through a LayerNorm + dropout layer; the output has shape (batch_size, sequence_length, hidden_size).

**Why LayerNorm + Dropout here? And why LayerNorm rather than BatchNorm? A good answer for reference: [Why does the Transformer use layer normalization instead of other normalization methods?](https://www.zhihu.com/question/395811291/answer/1260290120)**
** [Why LayerNorm + Dropout here? And why LayerNorm rather than BatchNorm? A good answer for reference: Why does the Transformer use layer normalization instead of other normalization methods?](https://www.zhihu.com/question/395811291/answer/1260290120)**


```python
@@ -1107,7 +1105,7 @@ $$SDPA(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

And these attention heads, as is well known, are computed in parallel, so the query, key, and value weights above are each a single matrix: it is not that all heads share one set of weights, but that their weights are "concatenated" together.

**The original paper's rationale for multiple heads is: Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this. Another fairly solid analysis is: [Why does the Transformer need Multi-head Attention?](https://www.zhihu.com/question/341222779/answer/814111138)**
**[The original paper's rationale for multiple heads is: Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this. Another fairly solid analysis is: Why does the Transformer need Multi-head Attention?](https://www.zhihu.com/question/341222779/answer/814111138)**

Let's look at the forward method:
```
@@ -1163,7 +1161,7 @@ def transpose_for_scores(self, x):
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
# ...
```
**For the Einstein summation convention, see the documentation: [torch.einsum - PyTorch 1.8.1 documentation](https://pytorch.org/docs/stable/generated/torch.einsum.html)**
**[For the Einstein summation convention, see the documentation: torch.einsum - PyTorch 1.8.1 documentation](https://pytorch.org/docs/stable/generated/torch.einsum.html)**
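As a concrete illustration of the torch.einsum usage referenced above, here is my own sketch with small made-up sizes. The pattern "bhld,lrd->bhlr" is the one I recall from the HuggingFace relative-position code, so treat it as an assumption; it contracts the head dimension between the query and a table of relative position embeddings.

```python
import torch

batch, heads, seq_len, head_dim = 2, 12, 16, 64
query_layer = torch.randn(batch, heads, seq_len, head_dim)
positional_embedding = torch.randn(seq_len, seq_len, head_dim)   # (l, r, d)

# Contract d between the query and the relative-position table.
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(relative_position_scores.shape)  # torch.Size([2, 12, 16, 16])
```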


Depending on the positional_embedding_type, there are three options:
@@ -1582,7 +1580,6 @@ print(y.size())
- BertModel contains complex wrappers and quite a few components. Taking bert-base as an example, the main counts are:
  - Dropout appears 1+(1+1+1)x12=37 times in total;
  - LayerNorm appears 1+(1+1)x12=25 times in total;
  - dense (fully connected) layers appear (1+1+1)x12+1=37 times in total, and not every dense layer is paired with an activation function;
BertModel has a very large number of parameters; for bert-base, about 109M.

## Acknowledgements