This commit is contained in:
erenup 2021-09-02 00:14:50 +08:00
parent d34d1f7a14
commit 3584f014aa
1 changed files with 147 additions and 10 deletions

View File

@ -5,9 +5,51 @@
```python
! pip install datasets transformers sacrebleu sentencepiece
! pip install datasets transformers "sacrebleu>=1.4.12,<2.0.0" sentencepiece
```
Requirement already satisfied: datasets in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (1.6.2)
Requirement already satisfied: transformers in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (4.4.2)
Collecting sacrebleu<2.0.0,>=1.4.12
Downloading sacrebleu-1.5.1-py3-none-any.whl (54 kB)
 |████████████████████████████████| 54 kB 235 kB/s
[?25hCollecting sentencepiece
Downloading sentencepiece-0.1.96-cp38-cp38-macosx_10_6_x86_64.whl (1.1 MB)
 |████████████████████████████████| 1.1 MB 438 kB/s
[?25hRequirement already satisfied: numpy>=1.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.21.1)
Requirement already satisfied: multiprocess in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.70.12.2)
Requirement already satisfied: fsspec in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2021.7.0)
Requirement already satisfied: huggingface-hub<0.1.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.0.15)
Requirement already satisfied: pandas in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.3.1)
Requirement already satisfied: dill in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.3.4)
Requirement already satisfied: tqdm<4.50.0,>=4.27 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (4.49.0)
Requirement already satisfied: pyarrow>=1.0.0<4.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (5.0.0)
Requirement already satisfied: xxhash in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.0.2)
Requirement already satisfied: requests>=2.19.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.26.0)
Requirement already satisfied: packaging in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (20.9)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.10.3)
Requirement already satisfied: regex!=2019.12.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (2021.8.3)
Requirement already satisfied: sacremoses in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.0.45)
Requirement already satisfied: filelock in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (3.0.12)
Collecting portalocker==2.0.0
Downloading portalocker-2.0.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: typing-extensions in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from huggingface-hub<0.1.0->datasets) (3.10.0.0)
Requirement already satisfied: pyparsing>=2.0.2 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from packaging->datasets) (2.4.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (1.26.6)
Requirement already satisfied: certifi>=2017.4.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2021.5.30)
Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2.0.4)
Requirement already satisfied: idna<4,>=2.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (3.2)
Requirement already satisfied: python-dateutil>=2.7.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2017.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2021.1)
Requirement already satisfied: six>=1.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)
Requirement already satisfied: click in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (8.0.1)
Requirement already satisfied: joblib in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (1.0.1)
Installing collected packages: portalocker, sentencepiece, sacrebleu
Successfully installed portalocker-2.0.0 sacrebleu-1.5.1 sentencepiece-0.1.96
WARNING: You are using pip version 21.2.3; however, version 21.2.4 is available.
You should consider upgrading via the '/Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/bin/python3 -m pip install --upgrade pip' command.
如果您正在本地打开这个notebook请确保您认真阅读并安装了transformer-quick-start-zh的readme文件中的所有依赖库。您也可以在[这里](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)找到本notebook的多GPU分布式训练版本。
# 微调transformer模型解决翻译任务
@ -44,6 +86,25 @@ raw_datasets = load_dataset("wmt16", "ro-en")
metric = load_metric("sacrebleu")
```
Downloading: 2.81kB [00:00, 523kB/s]
Downloading: 3.19kB [00:00, 758kB/s]
Downloading: 41.0kB [00:00, 11.0MB/s]
Downloading and preparing dataset wmt16/ro-en (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a...
Downloading: 100%|██████████| 225M/225M [00:18<00:00, 12.2MB/s]
Downloading: 100%|██████████| 23.5M/23.5M [00:16<00:00, 1.44MB/s]
Downloading: 100%|██████████| 38.7M/38.7M [00:03<00:00, 9.82MB/s]
Dataset wmt16 downloaded and prepared to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a. Subsequent calls will reuse this data.
Downloading: 5.40kB [00:00, 2.08MB/s]
这个datasets对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 对于训练集、验证集和测试集只需要使用对应的keytrainvalidationtest即可得到相应的数据。
@ -128,23 +189,23 @@ show_random_elements(raw_datasets["train"])
<tbody>
<tr>
<th>0</th>
<td>{'en': 'The Bulgarian gymnastics team won the gold medal at the traditional Grand Prix series competition in Thiais, France, which wrapped up on Sunday (March 30th).', 'ro': 'Echipa bulgară de gimnastică a câştigat medalia de aur la tradiţionala competiţie Grand Prix din Thiais, Franţa, care s-a încheiat duminică (30 martie).'}</td>
<td>{'en': 'I do not believe that this is the right course.', 'ro': 'Nu cred că acesta este varianta corectă.'}</td>
</tr>
<tr>
<th>1</th>
<td>{'en': 'Being on that committee, however, you will know that this was a very hot topic in negotiations between Norway and some Member States.', 'ro': 'Totuşi, făcând parte din această comisie, ştiţi că acesta a fost un subiect foarte aprins în negocierile dintre Norvegia şi unele state membre.'}</td>
<td>{'en': 'A total of 104 new jobs were created at the European Chemicals Agency, which mainly supervises our REACH projects.', 'ro': 'Un total de 104 noi locuri de muncă au fost create la Agenția Europeană pentru Produse Chimice, care, în special, supraveghează proiectele noastre REACH.'}</td>
</tr>
<tr>
<th>2</th>
<td>{'en': 'The overwhelming vote shows just this.', 'ro': 'Ceea ce demonstrează şi votul favorabil.'}</td>
<td>{'en': 'In view of the above, will the Council say what stage discussions for Turkish participation in joint Frontex operations have reached?', 'ro': 'Care este stadiul negocierilor referitoare la participarea Turciei la operațiunile comune din cadrul Frontex?'}</td>
</tr>
<tr>
<th>3</th>
<td>{'en': '[Photo illustration by Catherine Gurgenidze for Southeast European Times]', 'ro': '[Ilustraţii foto de Catherine Gurgenidze pentru Southeast European Times]'}</td>
<td>{'en': 'We now fear that if the scope of this directive is expanded, the directive will suffer exactly the same fate as the last attempt at introducing 'Made in' origin marking - in other words, that it will once again be blocked by the Council.', 'ro': 'Acum ne temem că, dacă sfera de aplicare a directivei va fi extinsă, aceasta va avea exact aceeaşi soartă ca ultima încercare de introducere a marcajului de origine "Made in”, cu alte cuvinte, că va fi din nou blocată la Consiliu.'}</td>
</tr>
<tr>
<th>4</th>
<td>{'en': '(HU) Mr President, today the specific text of the agreement between the Hungarian Government and the European Commission has been formulated.', 'ro': '(HU) Domnule președinte, textul concret al acordului dintre guvernul ungar și Comisia Europeană a fost formulat astăzi.'}</td>
<td>{'en': 'The country dropped nine slots to 85th, with a score of 6.58.', 'ro': 'Ţara a coborât nouă poziţii, pe locul 85, cu un scor de 6,58.'}</td>
</tr>
</tbody>
</table>
@ -206,13 +267,13 @@ metric.compute(predictions=fake_preds, references=fake_labels)
{'bp': 1.0,
{'score': 0.0,
'counts': [4, 2, 0, 0],
'totals': [4, 2, 0, 0],
'precisions': [100.0, 100.0, 0.0, 0.0],
'ref_len': 4,
'score': 0.0,
'bp': 1.0,
'sys_len': 4,
'totals': [4, 2, 0, 0]}
'ref_len': 4}
@ -236,6 +297,13 @@ from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
```
Downloading: 100%|██████████| 1.13k/1.13k [00:00<00:00, 466kB/s]
Downloading: 100%|██████████| 789k/789k [00:00<00:00, 882kB/s]
Downloading: 100%|██████████| 817k/817k [00:00<00:00, 902kB/s]
Downloading: 100%|██████████| 1.39M/1.39M [00:01<00:00, 1.24MB/s]
Downloading: 100%|██████████| 42.0/42.0 [00:00<00:00, 14.6kB/s]
以我们使用的mBART模型为例我们需要正确设置source语言和target语言。如果您要翻译的是其他双语语料请查看[这里](https://huggingface.co/facebook/mbart-large-cc25)。我们可以检查source和target语言的设置
@ -348,6 +416,11 @@ preprocess_function(raw_datasets['train'][:2])
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
```
100%|██████████| 611/611 [02:32<00:00, 3.99ba/s]
100%|██████████| 2/2 [00:00<00:00, 3.76ba/s]
100%|██████████| 2/2 [00:00<00:00, 3.89ba/s]
更好的是返回的结果会自动被缓存避免下次处理的时候重新计算但是也要注意如果输入有改动可能会被缓存影响。datasets库函数会对输入的参数进行检测判断是否有变化如果没有变化就使用缓存数据如果有变化就重新处理。但如果输入参数不变想改变输入的时候最好清理调这个缓存。清理的方式是使用`load_from_cache_file=False`参数。另外,上面使用到的`batched=True`这个参数是tokenizer的特点以为这会使用多线程同时并行对输入进行处理。
## 微调transformer模型
@ -361,6 +434,9 @@ from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqT
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
```
Downloading: 100%|██████████| 301M/301M [00:19<00:00, 15.1MB/s]
由于我们微调的任务是机器翻译而我们加载的是预训练的seq2seq模型所以不会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数比如预训练语言模型的神经网络head被扔掉了同时随机初始化了机器翻译的神经网络head
@ -452,6 +528,67 @@ trainer = Seq2SeqTrainer(
trainer.train()
```
1%|▏ | 500/38145 [1:05:10<91:20:54, 8.74s/it]
{'loss': 0.8588, 'learning_rate': 1.973784244330843e-05, 'epoch': 0.01}
3%|▎ | 1000/38145 [2:09:32<73:56:07, 7.17s/it]
{'loss': 0.8343, 'learning_rate': 1.947568488661686e-05, 'epoch': 0.03}
4%|▍ | 1500/38145 [3:03:53<57:17:10, 5.63s/it]
{'loss': 0.8246, 'learning_rate': 1.9213527329925285e-05, 'epoch': 0.04}
5%|▌ | 1980/38145 [3:52:54<67:46:36, 6.75s/it]
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
/var/folders/2k/x3py0v857kgcwqvvl00xxhxw0000gn/T/ipykernel_15169/4032920361.py in <module>
----> 1 trainer.train()
~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
1079 else:
1080 # Revert to normal clipping otherwise, handling Apex or full precision
-> 1081 torch.nn.utils.clip_grad_norm_(
1082 amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
1083 self.args.max_grad_norm,
~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py in clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite)
40 total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
41 else:
---> 42 total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
43 if total_norm.isnan() or total_norm.isinf():
44 if error_if_nonfinite:
~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py in <listcomp>(.0)
40 total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
41 else:
---> 42 total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
43 if total_norm.isnan() or total_norm.isinf():
44 if error_if_nonfinite:
~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/functional.py in norm(input, p, dim, keepdim, out, dtype)
1310 if not isinstance(p, str):
1311 _dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
-> 1312 return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
1313
1314 # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
KeyboardInterrupt:
最后别忘了,查看如何上传模型 ,上传模型到](https://huggingface.co/transformers/model_sharing.html) 到[🤗 Model Hub](https://huggingface.co/models)。随后您就可以像这个notebook一开始一样直接用模型名字就能使用您的模型啦。