fix
This commit is contained in: parent d34d1f7a14 · commit 3584f014aa
@@ -5,9 +5,51 @@
```python
! pip install datasets transformers "sacrebleu>=1.4.12,<2.0.0" sentencepiece
```
Requirement already satisfied: datasets in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (1.6.2)
Requirement already satisfied: transformers in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (4.4.2)
Collecting sacrebleu<2.0.0,>=1.4.12
  Downloading sacrebleu-1.5.1-py3-none-any.whl (54 kB)
     |████████████████████████████████| 54 kB 235 kB/s
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp38-cp38-macosx_10_6_x86_64.whl (1.1 MB)
     |████████████████████████████████| 1.1 MB 438 kB/s
Requirement already satisfied: numpy>=1.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.21.1)
Requirement already satisfied: multiprocess in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.70.12.2)
Requirement already satisfied: fsspec in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2021.7.0)
Requirement already satisfied: huggingface-hub<0.1.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.0.15)
Requirement already satisfied: pandas in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.3.1)
Requirement already satisfied: dill in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.3.4)
Requirement already satisfied: tqdm<4.50.0,>=4.27 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (4.49.0)
Requirement already satisfied: pyarrow>=1.0.0<4.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (5.0.0)
Requirement already satisfied: xxhash in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.0.2)
Requirement already satisfied: requests>=2.19.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.26.0)
Requirement already satisfied: packaging in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (20.9)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.10.3)
Requirement already satisfied: regex!=2019.12.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (2021.8.3)
Requirement already satisfied: sacremoses in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.0.45)
Requirement already satisfied: filelock in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (3.0.12)
Collecting portalocker==2.0.0
  Downloading portalocker-2.0.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: typing-extensions in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from huggingface-hub<0.1.0->datasets) (3.10.0.0)
Requirement already satisfied: pyparsing>=2.0.2 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from packaging->datasets) (2.4.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (1.26.6)
Requirement already satisfied: certifi>=2017.4.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2021.5.30)
Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2.0.4)
Requirement already satisfied: idna<4,>=2.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (3.2)
Requirement already satisfied: python-dateutil>=2.7.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2017.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2021.1)
Requirement already satisfied: six>=1.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)
Requirement already satisfied: click in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (8.0.1)
Requirement already satisfied: joblib in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (1.0.1)
Installing collected packages: portalocker, sentencepiece, sacrebleu
Successfully installed portalocker-2.0.0 sacrebleu-1.5.1 sentencepiece-0.1.96
WARNING: You are using pip version 21.2.3; however, version 21.2.4 is available.
You should consider upgrading via the '/Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/bin/python3 -m pip install --upgrade pip' command.
If you are opening this notebook locally, please make sure you have carefully read the readme of transformer-quick-start-zh and installed all the dependencies listed there. You can also find a multi-GPU distributed training version of this notebook [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).
# Fine-tuning a transformer model for the translation task
@@ -44,6 +86,25 @@ raw_datasets = load_dataset("wmt16", "ro-en")
metric = load_metric("sacrebleu")
```
Downloading: 2.81kB [00:00, 523kB/s]
Downloading: 3.19kB [00:00, 758kB/s]
Downloading: 41.0kB [00:00, 11.0MB/s]

Downloading and preparing dataset wmt16/ro-en (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a...

Downloading: 100%|██████████| 225M/225M [00:18<00:00, 12.2MB/s]
Downloading: 100%|██████████| 23.5M/23.5M [00:16<00:00, 1.44MB/s]
Downloading: 100%|██████████| 38.7M/38.7M [00:03<00:00, 9.82MB/s]

Dataset wmt16 downloaded and prepared to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a. Subsequent calls will reuse this data.

Downloading: 5.40kB [00:00, 2.08MB/s]
This datasets object is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict). To access the training, validation, or test set, simply index it with the corresponding key (train, validation, test).
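For example, a minimal sketch of that indexing (the printed pair is only illustrative, your example will differ):

```python
# Pick a split by its key, then an example by its position.
raw_datasets["train"][0]
# e.g. {'translation': {'en': '...', 'ro': '...'}}
```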
@@ -128,23 +189,23 @@ show_random_elements(raw_datasets["train"])
  <tbody>
    <tr>
      <th>0</th>
      <td>{'en': 'I do not believe that this is the right course.', 'ro': 'Nu cred că acesta este varianta corectă.'}</td>
    </tr>
    <tr>
      <th>1</th>
      <td>{'en': 'A total of 104 new jobs were created at the European Chemicals Agency, which mainly supervises our REACH projects.', 'ro': 'Un total de 104 noi locuri de muncă au fost create la Agenția Europeană pentru Produse Chimice, care, în special, supraveghează proiectele noastre REACH.'}</td>
    </tr>
    <tr>
      <th>2</th>
      <td>{'en': 'In view of the above, will the Council say what stage discussions for Turkish participation in joint Frontex operations have reached?', 'ro': 'Care este stadiul negocierilor referitoare la participarea Turciei la operațiunile comune din cadrul Frontex?'}</td>
    </tr>
    <tr>
      <th>3</th>
      <td>{'en': 'We now fear that if the scope of this directive is expanded, the directive will suffer exactly the same fate as the last attempt at introducing 'Made in' origin marking - in other words, that it will once again be blocked by the Council.', 'ro': 'Acum ne temem că, dacă sfera de aplicare a directivei va fi extinsă, aceasta va avea exact aceeaşi soartă ca ultima încercare de introducere a marcajului de origine "Made in”, cu alte cuvinte, că va fi din nou blocată la Consiliu.'}</td>
    </tr>
    <tr>
      <th>4</th>
      <td>{'en': 'The country dropped nine slots to 85th, with a score of 6.58.', 'ro': 'Ţara a coborât nouă poziţii, pe locul 85, cu un scor de 6,58.'}</td>
    </tr>
  </tbody>
</table>
@@ -206,13 +267,13 @@ metric.compute(predictions=fake_preds, references=fake_labels)
{'score': 0.0,
 'counts': [4, 2, 0, 0],
 'totals': [4, 2, 0, 0],
 'precisions': [100.0, 100.0, 0.0, 0.0],
 'bp': 1.0,
 'sys_len': 4,
 'ref_len': 4}
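As a reminder of the input format that produced the output above, here is a minimal sketch (the toy strings are only illustrative): predictions are plain strings, while each reference is itself a list, because sacrebleu supports multiple references per prediction.

```python
# Illustrative toy inputs for the sacrebleu metric.
fake_preds = ["hello there", "general kenobi"]
fake_labels = [["hello there"], ["general kenobi"]]  # one list of references per prediction
metric.compute(predictions=fake_preds, references=fake_labels)
```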
@@ -236,6 +297,13 @@ from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
```
Downloading: 100%|██████████| 1.13k/1.13k [00:00<00:00, 466kB/s]
Downloading: 100%|██████████| 789k/789k [00:00<00:00, 882kB/s]
Downloading: 100%|██████████| 817k/817k [00:00<00:00, 902kB/s]
Downloading: 100%|██████████| 1.39M/1.39M [00:01<00:00, 1.24MB/s]
Downloading: 100%|██████████| 42.0/42.0 [00:00<00:00, 14.6kB/s]
Taking the mBART model as an example, we need to set the source language and target language correctly. If you are translating a different language pair, please check [here](https://huggingface.co/facebook/mbart-large-cc25). We can check the source and target language settings:
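A hedged sketch of that setup for an mBART-style checkpoint; the language codes below ("en_XX", "ro_RO") follow the facebook/mbart-large-cc25 model card and are illustrative, so verify them against the card for your checkpoint:

```python
# Only mBART-style tokenizers need explicit language codes; Marian models do not.
if "mbart" in model_checkpoint:
    tokenizer.src_lang = "en_XX"   # source language code (illustrative)
    tokenizer.tgt_lang = "ro_RO"   # target language code (illustrative)
```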
@@ -348,6 +416,11 @@ preprocess_function(raw_datasets['train'][:2])
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
```
100%|██████████| 611/611 [02:32<00:00, 3.99ba/s]
100%|██████████| 2/2 [00:00<00:00, 3.76ba/s]
100%|██████████| 2/2 [00:00<00:00, 3.89ba/s]
Even better, the results are automatically cached, so they are not recomputed the next time the processing runs (but be careful: if your inputs change, a stale cache can affect the result!). The datasets library inspects the arguments passed to `map` to decide whether anything has changed; if nothing changed it reuses the cached data, otherwise it reprocesses. If you change the inputs in a way the library cannot detect, it is best to clear this cache by passing the `load_from_cache_file=False` argument (see the sketch below). Also, the `batched=True` argument used above plays to the strengths of the fast tokenizer, which processes batches of inputs in parallel across multiple threads.
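A minimal sketch of forcing the map to re-run instead of reusing the cache:

```python
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    load_from_cache_file=False,  # ignore any previously cached result
)
```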
## Fine-tuning the transformer model
@@ -361,6 +434,9 @@ from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqT
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
```
Downloading: 100%|██████████| 301M/301M [00:19<00:00, 15.1MB/s]
Since our fine-tuning task is machine translation and we load a pretrained seq2seq model, we are not warned that mismatched parameters were discarded when the model was loaded (for example, a warning that the pretrained language-model head was dropped and a randomly initialized machine-translation head was added in its place).
@@ -452,6 +528,67 @@ trainer = Seq2SeqTrainer(
trainer.train()
```
  1%|▏         | 500/38145 [1:05:10<91:20:54, 8.74s/it]
{'loss': 0.8588, 'learning_rate': 1.973784244330843e-05, 'epoch': 0.01}
  3%|▎         | 1000/38145 [2:09:32<73:56:07, 7.17s/it]
{'loss': 0.8343, 'learning_rate': 1.947568488661686e-05, 'epoch': 0.03}
  4%|▍         | 1500/38145 [3:03:53<57:17:10, 5.63s/it]
{'loss': 0.8246, 'learning_rate': 1.9213527329925285e-05, 'epoch': 0.04}
  5%|▌         | 1980/38145 [3:52:54<67:46:36, 6.75s/it]
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
/var/folders/2k/x3py0v857kgcwqvvl00xxhxw0000gn/T/ipykernel_15169/4032920361.py in <module>
----> 1 trainer.train()

~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
   1079                     else:
   1080                         # Revert to normal clipping otherwise, handling Apex or full precision
-> 1081                         torch.nn.utils.clip_grad_norm_(
   1082                             amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
   1083                             self.args.max_grad_norm,

~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py in clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite)
     40         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     41     else:
---> 42         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
     43     if total_norm.isnan() or total_norm.isinf():
     44         if error_if_nonfinite:

~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py in <listcomp>(.0)
     40         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     41     else:
---> 42         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
     43     if total_norm.isnan() or total_norm.isinf():
     44         if error_if_nonfinite:

~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/functional.py in norm(input, p, dim, keepdim, out, dtype)
   1310         if not isinstance(p, str):
   1311             _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
-> 1312             return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
   1313
   1314 # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed

KeyboardInterrupt:
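The run above was interrupted by hand (hence the KeyboardInterrupt). As the `train()` signature in the traceback shows, the Trainer accepts a `resume_from_checkpoint` argument, so a hedged sketch of picking up from the last saved checkpoint instead of restarting might look like this; the checkpoint path is hypothetical and depends on your `output_dir` and `save_steps`:

```python
# Hypothetical path: replace with an actual checkpoint directory inside your output_dir.
trainer.train(resume_from_checkpoint="path/to/output_dir/checkpoint-1500")
```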
Finally, don't forget to check out [how to upload your model](https://huggingface.co/transformers/model_sharing.html) to the [🤗 Model Hub](https://huggingface.co/models). After that you can use your model simply by its name, exactly as at the beginning of this notebook.
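With a sufficiently recent transformers release (newer than the 4.4.2 shown in the pip log above), a minimal sketch of the upload could look like the following; the repository name is purely illustrative, and you need to be logged in first (`huggingface-cli login`):

```python
# Hypothetical repo name; requires a logged-in Hugging Face account and push_to_hub support.
model.push_to_hub("opus-mt-en-ro-finetuned-en-to-ro")
tokenizer.push_to_hub("opus-mt-en-ro-finetuned-en-to-ro")
```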