diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.ipynb b/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.ipynb index f68b2ec..aa1a322 100644 --- a/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.ipynb +++ b/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.ipynb @@ -2,53 +2,50 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "bTAeNLV3WdB0" - }, "source": [ "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n", "\n", "建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。\n", "如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。" - ] + ], + "metadata": { + "id": "bTAeNLV3WdB0" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "bAWdZQdTWdB3" - }, - "outputs": [], "source": [ "# ! pip install datasets transformers \n", "# -i https://pypi.tuna.tsinghua.edu.cn/simple" - ] + ], + "outputs": [], + "metadata": { + "id": "bAWdZQdTWdB3" + } }, { "cell_type": "markdown", - "metadata": { - "id": "7FC_ZTXsWdB3" - }, "source": [ "如果您是在本地机器上打开这个jupyter笔记本,请确保您的环境安装了上述库的最新版本。\n", "\n", "您可以在[这里](https://github.com/huggingface/transformers/tree/master/examples/language-modeling)找到这个jupyter笔记本的具体的python脚本文件,还可以通过分布式的方式使用多个gpu或tpu来微调您的模型。" - ] + ], + "metadata": { + "id": "7FC_ZTXsWdB3" + } }, { "cell_type": "markdown", - "metadata": { - "id": "BgQvrzh3WdB4" - }, "source": [ "# 微调语言模型" - ] + ], + "metadata": { + "id": "BgQvrzh3WdB4" + } }, { "cell_type": "markdown", - "metadata": { - "id": "hqHxDcitWdB4" - }, "source": [ "在当前jupyter笔记本中,我们将说明如何使用语言模型任务微调任意[🤗Transformers](https://github.com/huggingface/transformers) 模型。 \n", "\n", @@ -65,43 +62,73 @@ "接下来,我们将说明如何轻松地为每个任务加载和预处理数据集,以及如何使用“Trainer”API对模型进行微调。\n", "\n", "当然您也可以直接在分布式环境或TPU上运行该jupyter笔记本的python脚本版本,可以在[examples文件夹](https://github.com/huggingface/transformers/tree/master/examples)中找到。" - ] + ], + "metadata": { + "id": "hqHxDcitWdB4" + } }, { "cell_type": "markdown", - "metadata": { - "id": "GobLIFiRWdB5" - }, "source": [ "## 准备数据" - ] + ], + "metadata": { + "id": "GobLIFiRWdB5" + } }, { "cell_type": "markdown", - "metadata": { - "id": "Tfkh562BWdB5" - }, "source": [ "在接下来的这些任务中,我们将使用[Wikitext 2](https://huggingface.co/datasets/wikitext#data-instances)数据集作为示例。您可以通过🤗Datasets库加载该数据集:" - ] + ], + "metadata": { + "id": "Tfkh562BWdB5" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vtDCQuMSWdB5" - }, - "outputs": [], + "execution_count": 1, "source": [ "from datasets import load_dataset\n", "datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 8.33kB [00:00, 1.49MB/s] \n", + "Downloading: 5.83kB [00:00, 1.77MB/s] \n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.91 MiB, post-processed: Unknown size, total: 17.41 MiB) to /Users/niepig/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 100%|██████████| 4.72M/4.72M [00:02<00:00, 1.91MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Dataset wikitext downloaded and prepared to 
/Users/niepig/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20. Subsequent calls will reuse this data.\n" + ] + } + ], + "metadata": { + "id": "vtDCQuMSWdB5" + } }, { "cell_type": "markdown", - "metadata": { - "id": "IqMUxoJrWdB7" - }, "source": [ "如果碰到以下错误:\n", "![request Error](images/request_error.png)\n", @@ -111,83 +138,80 @@ "MAC用户: 在 ```/etc/hosts``` 文件中添加一行 ```199.232.68.133 raw.githubusercontent.com```\n", "\n", "Windowso用户: 在 ```C:\\Windows\\System32\\drivers\\etc\\hosts``` 文件中添加一行 ```199.232.68.133 raw.githubusercontent.com```" - ] + ], + "metadata": { + "id": "IqMUxoJrWdB7" + } }, { "cell_type": "markdown", - "metadata": { - "id": "Wjl5FpYDWdB7" - }, "source": [ "当然您也可以用公开在[hub](https://huggingface.co/datasets)上的任何数据集替换上面的数据集,或者使用您自己的文件。只需取消注释以下单元格,并将路径替换为将导致您的文件路径:" - ] + ], + "metadata": { + "id": "Wjl5FpYDWdB7" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "fncDchlaWdB7" - }, - "outputs": [], "source": [ "# datasets = load_dataset(\"text\", data_files={\"train\": path_to_train.txt, \"validation\": path_to_validation.txt}" - ] + ], + "outputs": [], + "metadata": { + "id": "fncDchlaWdB7" + } }, { "cell_type": "markdown", - "metadata": { - "id": "ODIsscsTWdB8" - }, "source": [ "您还可以从csv或JSON文件加载数据集,更多信息请参阅[完整文档](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files)。\n", "\n", "要访问一个数据中实际的元素,您需要先选择一个key,然后给出一个索引:" - ] + ], + "metadata": { + "id": "ODIsscsTWdB8" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, + "source": [ + "datasets[\"train\"][10]" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'text': ' The game \\'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \\' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific \" Potentials \" , skills unique to each character . They are divided into \" Personal Potential \" , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and \" Battle Potentials \" , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique \" Masters Table \" , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate \" Direct Command \" and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her \" Valkyria Form \" and become invincible , while Imca can target multiple enemy units with her heavy weapon . 
\\n'}" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "H887z9CbWdB8", "outputId": "f9d05402-f99b-40da-b672-887e6a8c5597" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'text': ' The game \\'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \\' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific \" Potentials \" , skills unique to each character . They are divided into \" Personal Potential \" , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and \" Battle Potentials \" , which are grown throughout the game and always grant boons to a character . To learn Battle Potentials , each character has a unique \" Masters Table \" , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special Abilities that grant them temporary boosts on the battlefield : Kurt can activate \" Direct Command \" and move around the battlefield without depleting his Action Point gauge , the character Reila can shift into her \" Valkyria Form \" and become invincible , while Imca can target multiple enemy units with her heavy weapon . \\n'}" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "datasets[\"train\"][10]" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "Y3vbv6yHWdB8" - }, "source": [ "为了快速了解数据的结构,下面的函数将显示数据集中随机选取的一些示例。" - ] + ], + "metadata": { + "id": "Y3vbv6yHWdB8" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "II9ha_LmWdB9" - }, - "outputs": [], + "execution_count": 3, "source": [ "from datasets import ClassLabel\n", "import random\n", @@ -208,22 +232,25 @@ " if isinstance(typ, ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" - ] + ], + "outputs": [], + "metadata": { + "id": "II9ha_LmWdB9" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 427 - }, - "id": "LaCaYQyJWdB9", - "outputId": "8fcf2a87-fa7c-46b1-bd03-26325ce69da9" - }, + "execution_count": 4, + "source": [ + "show_random_elements(datasets[\"train\"])" + ], "outputs": [ { + "output_type": "display_data", "data": { + "text/plain": [ + "" + ], "text/html": [ "\n", " \n", @@ -235,11 +262,11 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -247,27 +274,27 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -275,44 +302,40 @@ " \n", " \n", "
0MD 194D is the designation for an unnamed 0 @.@ 02 @-@ mile ( 0 @.@ 032 km ) connector between MD 194 and MD 853E , the old alignment that parallels the northbound direction of the modern highway south of Angell Road . \\nPlum cakes made with fresh plums came with other migrants from other traditions in which plum cake is prepared using plum as a primary ingredient . In some versions , the plums may become jam @-@ like inside the cake after cooking , or be prepared using plum jam . Plum cake prepared with plums is also a part of Ashkenazi Jewish cuisine , and is referred to as Pflaumenkuchen or Zwetschgenkuchen . Other plum @-@ based cakes are found in French , Italian and Polish cooking . \\n
1My sense , as though of hemlock I had drunk , \\n= = = Language = = = \\n
2
3A mimed stage show , Thunderbirds : F.A.B. , has toured internationally and popularised a staccato style of movement known colloquially as the \" Thunderbirds walk \" . The production has periodically been revived as Thunderbirds : F.A.B. – The Next Generation . \\n
4The town 's population not only recovered but grew ; the 1906 census of the Canadian Prairies listed the population at 1 @,@ 178 . A new study commissioned by the Dominion government determined that the cracks in the mountain continued to grow and that the risk of another slide remained . Consequently , parts of Frank closest to the mountain were dismantled or relocated to safer areas . \\n
5The Litigators is a 2011 legal thriller novel by John Grisham , his 25th fiction novel overall . The Litigators is about a two @-@ partner Chicago law firm attempting to strike it rich in a class action lawsuit over a cholesterol reduction drug by a major pharmaceutical drug company . The protagonist is a Harvard Law School grad big law firm burnout who stumbles upon the boutique and joins it only to find himself litigating against his old law firm in this case . The book is regarded as more humorous than most of Grisham 's prior novels . \\n
6In his 1998 autobiography For the Love of the Game , Jordan wrote that he had been preparing for retirement as early as the summer of 1992 . The added exhaustion due to the Dream Team run in the 1992 Olympics solidified Jordan 's feelings about the game and his ever @-@ growing celebrity status . Jordan 's announcement sent shock waves throughout the NBA and appeared on the front pages of newspapers around the world . \\n
7Research on new wildlife collars may be able to reduce human @-@ animal conflicts by predicting when and where predatory animals hunt . This can not only save human lives and the lives of their pets and livestock but also save these large predatory mammals that are important to the balance of ecosystems . \\nOn December 7 , 2006 , Headquarters Marine Corps released a message stating that 2nd Battalion 9th Marines would be reactivated during 2007 as part of the continuing Global War on Terror . 2nd Battalion 9th Marines was re @-@ activated on July 13 , 2007 and replaced the Anti @-@ Terrorism Battalion ( ATBn ) . In September 2008 , Marines and Sailors from 2 / 9 deployed to Al Anbar Province in support of Operation Iraqi Freedom . They were based in the city of Ramadi and returned in April 2009 without any Marines or Sailors killed in action . July 2010 Marines and Sailors from 2 / 9 deployed to Marjah , Helmand Province , Afghanistan in support of Operation Enduring Freedom . In December 2010 Echo Company from 2 / 9 were attached to 3 / 5 in Sangin , Afghanistan where they earned the notorious nickname of \" Green Hats . \" They returned February 2011 . They redeployed back to Marjah December 2011 and returned July 2012 . Echo and Weapons companies deployed once more to Afghanistan from January through April 2013 , participating in combat operations out of Camp Leatherneck . On April 1 , 2015 the battalion was deactivated in a ceremony at Camp Lejeune . \\n
8\" Love Me Like You \" ( Christmas Mix ) – 3 : 29 \\n( i ) = Indoor \\n
9
" - ], - "text/plain": [ - "" ] }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" + "metadata": {} } ], - "source": [ - "show_random_elements(datasets[\"train\"])" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 427 + }, + "id": "LaCaYQyJWdB9", + "outputId": "8fcf2a87-fa7c-46b1-bd03-26325ce69da9" + } }, { "cell_type": "markdown", - "metadata": { - "id": "LH0Uk_OsWdB9" - }, "source": [ "正如我们所看到的,一些文本是维基百科文章的完整段落,而其他的只是标题或空行。" - ] + ], + "metadata": { + "id": "LH0Uk_OsWdB9" + } }, { "cell_type": "markdown", - "metadata": { - "id": "9Nu5lPu8WdB-" - }, "source": [ "## 因果语言模型(Causal Language Modeling,CLM)" - ] + ], + "metadata": { + "id": "9Nu5lPu8WdB-" + } }, { "cell_type": "markdown", - "metadata": { - "id": "v7gOchUNWdB-" - }, "source": [ "对于因果语言模型(CLM),我们首先获取到数据集中的所有文本,并在它们被分词后将它们连接起来。然后,我们将在特定序列长度的例子中拆分它们。通过这种方式,模型将接收如下的连续文本块:\n", "\n", @@ -327,159 +350,228 @@ "取决于它们是否跨越数据集中的几个原始文本。标签将与输入相同,但向左移动。\n", "\n", "在本例中,我们将使用[`distilgpt2`](https://huggingface.co/distilgpt2) 模型。您同样也可以选择[这里](https://huggingface.co/models?filter=causal-lm)列出的任何一个checkpoint:" - ] + ], + "metadata": { + "id": "v7gOchUNWdB-" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z37txOiBWdB-" - }, - "outputs": [], + "execution_count": 5, "source": [ "model_checkpoint = \"distilgpt2\"" - ] + ], + "outputs": [], + "metadata": { + "id": "z37txOiBWdB-" + } }, { "cell_type": "markdown", - "metadata": { - "id": "mk8BWvYWWdB-" - }, "source": [ "为了用训练模型时使用的词汇对所有文本进行标记,我们必须下载一个预先训练过的分词器(Tokenizer)。而这些操作都可以由AutoTokenizer类完成:" - ] + ], + "metadata": { + "id": "mk8BWvYWWdB-" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mQwZ5UssWdB_" - }, - "outputs": [], + "execution_count": 6, "source": [ "from transformers import AutoTokenizer\n", " \n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 100%|██████████| 762/762 [00:00<00:00, 358kB/s]\n", + "Downloading: 100%|██████████| 1.04M/1.04M [00:04<00:00, 235kB/s]\n", + "Downloading: 100%|██████████| 456k/456k [00:02<00:00, 217kB/s]\n", + "Downloading: 100%|██████████| 1.36M/1.36M [00:05<00:00, 252kB/s]\n" + ] + } + ], + "metadata": { + "id": "mQwZ5UssWdB_" + } }, { "cell_type": "markdown", - "metadata": { - "id": "hAQJGvMxWdB_" - }, "source": [ "我们现在可以对所有的文本调用分词器,该操作可以简单地使用来自Datasets库的map方法实现。首先,我们定义一个在文本上调用标记器的函数:" - ] + ], + "metadata": { + "id": "hAQJGvMxWdB_" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wxhKKMYgWdB_" - }, - "outputs": [], + "execution_count": 7, "source": [ "def tokenize_function(examples):\n", " return tokenizer(examples[\"text\"])" - ] + ], + "outputs": [], + "metadata": { + "id": "wxhKKMYgWdB_" + } }, { "cell_type": "markdown", - "metadata": { - "id": "FM_kMpbCWdB_" - }, "source": [ "然后我们将它应用到datasets对象中的分词,使用```batch=True```和```4```个进程来加速预处理。而之后我们并不需要```text```列,所以将其舍弃。\n" - ] + ], + "metadata": { + "id": "FM_kMpbCWdB_" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rNb1U12YWdCA" - }, - "outputs": [], + "execution_count": 8, "source": [ "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "#0: 0%| | 0/2 [00:00\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, **kwargs)\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrol\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1033\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1034\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1035\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1036\u001b[0m \u001b[0;31m# Skip past any already trained steps if resuming training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 560\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 562\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m 
\u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 45\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyError\u001b[0m: 1" + ] + } + ], + "metadata": { + "id": "a55CO2xGWdCF" + } }, { "cell_type": "markdown", - "metadata": { - "id": "PAFX3mCwWdCG" - }, "source": [ "一旦训练完成,我们就可以评估我们的模型,得到它在验证集上的perplexity,如下所示:" - ] + ], + "metadata": { + "id": "PAFX3mCwWdCG" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "g1A7eBP3WdCG" - }, - "outputs": [], "source": [ "import math\n", "eval_results = trainer.evaluate()\n", "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")" - ] + ], + "outputs": [], + "metadata": { + "id": "g1A7eBP3WdCG" + } }, { "cell_type": "markdown", - "metadata": { - "id": "lJzcfmsVWdCG" - }, "source": [ "## 掩蔽语言模型(Mask Language Modeling,MLM)" - ] + ], + "metadata": { + "id": "lJzcfmsVWdCG" + } }, { "cell_type": "markdown", - "metadata": { - "id": "sr5UHTPjWdCG" - }, "source": [ "掩蔽语言模型(MLM)我们将使用相同的数据集预处理和以前一样用一个额外的步骤:\n", "\n", "我们将随机\"MASK\"一些字符(使用\"[MASK]\"进行替换)以及调整标签为只包含在\"[MASK]\"位置处的标签(因为我们不需要预测没有被\"MASK\"的字符)。\n", "\n", "在本例中,我们将使用[`distilroberta-base`](https://huggingface.co/distilroberta-base)模型。您同样也可以选择[这里](https://huggingface.co/models?filter=causal-lm)列出的任何一个checkpoint:" - ] + ], + "metadata": { + "id": "sr5UHTPjWdCG" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "2X4qJPeXWdCG" - }, - "outputs": [], "source": [ "model_checkpoint = \"distilroberta-base\"" - ] + ], + "outputs": [], + "metadata": { + "id": "2X4qJPeXWdCG" + } }, { "cell_type": "markdown", - "metadata": { - "id": "GMDHywzqWdCH" - }, "source": [ "我们可以像之前一样应用相同的分词器函数,我们只需要更新我们的分词器来使用刚刚选择的checkpoint:" - ] + ], + "metadata": { + "id": "GMDHywzqWdCH" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "sgIqOa4uWdCH" - }, - "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", "tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])" - ] + ], + "outputs": [], + "metadata": { + "id": "sgIqOa4uWdCH" + } }, { "cell_type": "markdown", - "metadata": { - "id": "O4BALsgJWdCH" - }, "source": [ "像之前一样,我们把文本分组在一起,并把它们分成长度为`block_size`的样本。如果您的数据集由单独的句子组成,则可以跳过这一步。" - ] + ], + "metadata": { + "id": "O4BALsgJWdCH" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "d4jJo5X2WdCH" - }, - "outputs": [], "source": [ "lm_datasets = tokenized_datasets.map(\n", " group_texts,\n", @@ -845,34 +1000,35 @@ " batch_size=1000,\n", " num_proc=4,\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "d4jJo5X2WdCH" + } }, { "cell_type": "markdown", - "metadata": { - "id": "2wbjmPZQWdCI" - }, "source": [ "剩下的和我们之前的做法非常相似,只有两个例外。首先我们使用一个适合掩蔽语言模型的模型:" - ] + ], + "metadata": { + "id": "2wbjmPZQWdCI" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "LkJuOG4oWdCI" - }, - "outputs": [], "source": [ "from transformers import AutoModelForMaskedLM\n", "model = 
AutoModelForMaskedLM.from_pretrained(model_checkpoint)" - ] + ], + "outputs": [], + "metadata": { + "id": "LkJuOG4oWdCI" + } }, { "cell_type": "markdown", - "metadata": { - "id": "MVKOscEdWdCI" - }, "source": [ "其次,我们使用一个特殊的data_collator。data_collator是一个函数,负责获取样本并将它们批处理成张量。\n", "\n", @@ -881,36 +1037,35 @@ "我们可以将其作为预处理步骤(`tokenizer`)进行处理,但在每个阶段,字符总是以相同的方式被掩盖。通过在data_collator中执行这一步,我们可以确保每次检查数据时都以新的方式完成随机掩蔽。\n", "\n", "为了实现掩蔽,`Transformers`为掩蔽语言模型提供了一个`DataCollatorForLanguageModeling`。我们可以调整掩蔽的概率:" - ] + ], + "metadata": { + "id": "MVKOscEdWdCI" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "A-k8wJK7WdCI" - }, - "outputs": [], "source": [ "from transformers import DataCollatorForLanguageModeling\n", "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)" - ] + ], + "outputs": [], + "metadata": { + "id": "A-k8wJK7WdCI" + } }, { "cell_type": "markdown", - "metadata": { - "id": "m83JcPGyWdCI" - }, "source": [ "然后我们要把所有的东西交给trainer,然后开始训练:" - ] + ], + "metadata": { + "id": "m83JcPGyWdCI" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "I12S2ZQxWdCJ" - }, - "outputs": [], "source": [ "trainer = Trainer(\n", " model=model,\n", @@ -919,54 +1074,58 @@ " eval_dataset=lm_datasets[\"validation\"][:100],\n", " data_collator=data_collator,\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "I12S2ZQxWdCJ" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "u_4PHx1CWdCJ" - }, - "outputs": [], "source": [ "trainer.train()" - ] + ], + "outputs": [], + "metadata": { + "id": "u_4PHx1CWdCJ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "MDKbrOmzWdCJ" - }, "source": [ "像以前一样,我们可以在验证集上评估我们的模型。\n", "\n", "与CLM目标相比,困惑度要低得多,因为对于MLM目标,我们只需要对隐藏的令牌(在这里占总数的15%)进行预测,同时可以访问其余的令牌。\n", "\n", "因此,对于模型来说,这是一项更容易的任务。" - ] + ], + "metadata": { + "id": "MDKbrOmzWdCJ" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "60hUa-W5WdCJ" - }, - "outputs": [], "source": [ "eval_results = trainer.evaluate()\n", "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")" - ] + ], + "outputs": [], + "metadata": { + "id": "60hUa-W5WdCJ" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "uPa5UTWWWdCK" - }, - "outputs": [], "source": [ "不要忘记将你的模型[上传](https://huggingface.co/transformers/model_sharing.html)到[🤗 模型中心](https://huggingface.co/models)。" - ] + ], + "outputs": [], + "metadata": { + "id": "uPa5UTWWWdCK" + } } ], "metadata": { @@ -979,12 +1138,20 @@ "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b" }, "kernelspec": { - "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)", - "name": "python3" + "name": "python3", + "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)" }, "language_info": { "name": "python", - "version": "" + "version": "3.8.10", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" }, "widgets": { "application/vnd.jupyter.widget-state+json": { @@ -1237,5 +1404,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 2 } \ No newline at end of file diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.md b/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.md index 7fb7450..ea1aadd 100644 --- a/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.md +++ b/docs/篇章4-使用Transformers解决NLP任务/4.5-生成任务-语言模型.md @@ -41,6 +41,19 @@ from datasets import 
load_dataset datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') ``` + Downloading: 8.33kB [00:00, 1.49MB/s] + Downloading: 5.83kB [00:00, 1.77MB/s] + + + Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.91 MiB, post-processed: Unknown size, total: 17.41 MiB) to /Users/niepig/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20... + + + Downloading: 100%|██████████| 4.72M/4.72M [00:02<00:00, 1.91MB/s] + + + Dataset wikitext downloaded and prepared to /Users/niepig/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20. Subsequent calls will reuse this data. + + 如果碰到以下错误: ![request Error](images/request_error.png) @@ -114,11 +127,11 @@ show_random_elements(datasets["train"]) 0 - MD 194D is the designation for an unnamed 0 @.@ 02 @-@ mile ( 0 @.@ 032 km ) connector between MD 194 and MD 853E , the old alignment that parallels the northbound direction of the modern highway south of Angell Road . \n + Plum cakes made with fresh plums came with other migrants from other traditions in which plum cake is prepared using plum as a primary ingredient . In some versions , the plums may become jam @-@ like inside the cake after cooking , or be prepared using plum jam . Plum cake prepared with plums is also a part of Ashkenazi Jewish cuisine , and is referred to as Pflaumenkuchen or Zwetschgenkuchen . Other plum @-@ based cakes are found in French , Italian and Polish cooking . \n 1 - My sense , as though of hemlock I had drunk , \n + = = = Language = = = \n 2 @@ -126,27 +139,27 @@ show_random_elements(datasets["train"]) 3 - A mimed stage show , Thunderbirds : F.A.B. , has toured internationally and popularised a staccato style of movement known colloquially as the " Thunderbirds walk " . The production has periodically been revived as Thunderbirds : F.A.B. – The Next Generation . \n + 4 - + The town 's population not only recovered but grew ; the 1906 census of the Canadian Prairies listed the population at 1 @,@ 178 . A new study commissioned by the Dominion government determined that the cracks in the mountain continued to grow and that the risk of another slide remained . Consequently , parts of Frank closest to the mountain were dismantled or relocated to safer areas . \n 5 - + The Litigators is a 2011 legal thriller novel by John Grisham , his 25th fiction novel overall . The Litigators is about a two @-@ partner Chicago law firm attempting to strike it rich in a class action lawsuit over a cholesterol reduction drug by a major pharmaceutical drug company . The protagonist is a Harvard Law School grad big law firm burnout who stumbles upon the boutique and joins it only to find himself litigating against his old law firm in this case . The book is regarded as more humorous than most of Grisham 's prior novels . \n 6 - In his 1998 autobiography For the Love of the Game , Jordan wrote that he had been preparing for retirement as early as the summer of 1992 . The added exhaustion due to the Dream Team run in the 1992 Olympics solidified Jordan 's feelings about the game and his ever @-@ growing celebrity status . Jordan 's announcement sent shock waves throughout the NBA and appeared on the front pages of newspapers around the world . \n + 7 - Research on new wildlife collars may be able to reduce human @-@ animal conflicts by predicting when and where predatory animals hunt . 
This can not only save human lives and the lives of their pets and livestock but also save these large predatory mammals that are important to the balance of ecosystems . \n + On December 7 , 2006 , Headquarters Marine Corps released a message stating that 2nd Battalion 9th Marines would be reactivated during 2007 as part of the continuing Global War on Terror . 2nd Battalion 9th Marines was re @-@ activated on July 13 , 2007 and replaced the Anti @-@ Terrorism Battalion ( ATBn ) . In September 2008 , Marines and Sailors from 2 / 9 deployed to Al Anbar Province in support of Operation Iraqi Freedom . They were based in the city of Ramadi and returned in April 2009 without any Marines or Sailors killed in action . July 2010 Marines and Sailors from 2 / 9 deployed to Marjah , Helmand Province , Afghanistan in support of Operation Enduring Freedom . In December 2010 Echo Company from 2 / 9 were attached to 3 / 5 in Sangin , Afghanistan where they earned the notorious nickname of " Green Hats . " They returned February 2011 . They redeployed back to Marjah December 2011 and returned July 2012 . Echo and Weapons companies deployed once more to Afghanistan from January through April 2013 , participating in combat operations out of Camp Leatherneck . On April 1 , 2015 the battalion was deactivated in a ceremony at Camp Lejeune . \n 8 - " Love Me Like You " ( Christmas Mix ) – 3 : 29 \n + ( i ) = Indoor \n 9 @@ -188,6 +201,12 @@ from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) ``` + Downloading: 100%|██████████| 762/762 [00:00<00:00, 358kB/s] + Downloading: 100%|██████████| 1.04M/1.04M [00:04<00:00, 235kB/s] + Downloading: 100%|██████████| 456k/456k [00:02<00:00, 217kB/s] + Downloading: 100%|██████████| 1.36M/1.36M [00:05<00:00, 252kB/s] + + 我们现在可以对所有的文本调用分词器,该操作可以简单地使用来自Datasets库的map方法实现。首先,我们定义一个在文本上调用标记器的函数: @@ -204,6 +223,62 @@ def tokenize_function(examples): tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"]) ``` + #0: 0%| | 0/2 [00:00 + ----> 1 trainer.train() + + + ~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, **kwargs) + 1032 self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) + 1033 + -> 1034 for step, inputs in enumerate(epoch_iterator): + 1035 + 1036 # Skip past any already trained steps if resuming training + + + ~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self) + 519 if self._sampler_iter is None: + 520 self._reset() + --> 521 data = self._next_data() + 522 self._num_yielded += 1 + 523 if self._dataset_kind == _DatasetKind.Iterable and \ + + + ~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self) + 559 def _next_data(self): + 560 index = self._next_index() # may raise StopIteration + --> 561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration + 562 if self._pin_memory: + 563 data = _utils.pin_memory.pin_memory(data) + + + ~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index) + 42 def fetch(self, possibly_batched_index): + 43 if self.auto_collation: + ---> 44 data = [self.dataset[idx] for idx in possibly_batched_index] + 45 else: + 46 data = self.dataset[possibly_batched_index] 
+ + + ~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in (.0) + 42 def fetch(self, possibly_batched_index): + 43 if self.auto_collation: + ---> 44 data = [self.dataset[idx] for idx in possibly_batched_index] + 45 else: + 46 data = self.dataset[possibly_batched_index] + + + KeyError: 1 + + 一旦训练完成,我们就可以评估我们的模型,得到它在验证集上的perplexity,如下所示: diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.6-生成任务-机器翻译.ipynb b/docs/篇章4-使用Transformers解决NLP任务/4.6-生成任务-机器翻译.ipynb index 9424c5d..41f7a7e 100644 --- a/docs/篇章4-使用Transformers解决NLP任务/4.6-生成任务-机器翻译.ipynb +++ b/docs/篇章4-使用Transformers解决NLP任务/4.6-生成任务-机器翻译.ipynb @@ -2,50 +2,94 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "X4cRE8IbIrIV" - }, "source": [ "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n", "\n", "建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。\n", "如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。" - ] + ], + "metadata": { + "id": "X4cRE8IbIrIV" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, + "source": [ + "! pip install datasets transformers \"sacrebleu>=1.4.12,<2.0.0\" sentencepiece" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: datasets in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (1.6.2)\n", + "Requirement already satisfied: transformers in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (4.4.2)\n", + "Collecting sacrebleu<2.0.0,>=1.4.12\n", + " Downloading sacrebleu-1.5.1-py3-none-any.whl (54 kB)\n", + "\u001b[K |████████████████████████████████| 54 kB 235 kB/s \n", + "\u001b[?25hCollecting sentencepiece\n", + " Downloading sentencepiece-0.1.96-cp38-cp38-macosx_10_6_x86_64.whl (1.1 MB)\n", + "\u001b[K |████████████████████████████████| 1.1 MB 438 kB/s \n", + "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.21.1)\n", + "Requirement already satisfied: multiprocess in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.70.12.2)\n", + "Requirement already satisfied: fsspec in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2021.7.0)\n", + "Requirement already satisfied: huggingface-hub<0.1.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.0.15)\n", + "Requirement already satisfied: pandas in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (1.3.1)\n", + "Requirement already satisfied: dill in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (0.3.4)\n", + "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (4.49.0)\n", + "Requirement already satisfied: pyarrow>=1.0.0<4.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (5.0.0)\n", + "Requirement already satisfied: xxhash in 
/Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: requests>=2.19.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (2.26.0)\n", + "Requirement already satisfied: packaging in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from datasets) (20.9)\n", + "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.10.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (2021.8.3)\n", + "Requirement already satisfied: sacremoses in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (0.0.45)\n", + "Requirement already satisfied: filelock in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from transformers) (3.0.12)\n", + "Collecting portalocker==2.0.0\n", + " Downloading portalocker-2.0.0-py2.py3-none-any.whl (11 kB)\n", + "Requirement already satisfied: typing-extensions in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from huggingface-hub<0.1.0->datasets) (3.10.0.0)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from packaging->datasets) (2.4.7)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (1.26.6)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2021.5.30)\n", + "Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from requests>=2.19.0->datasets) (3.2)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2017.3 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from pandas->datasets) (2021.1)\n", + "Requirement already satisfied: six>=1.5 in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: click in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (8.0.1)\n", + "Requirement already satisfied: joblib in /Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages (from sacremoses->transformers) (1.0.1)\n", + "Installing collected packages: portalocker, sentencepiece, sacrebleu\n", + "Successfully installed portalocker-2.0.0 sacrebleu-1.5.1 sentencepiece-0.1.96\n", + "\u001b[33mWARNING: You are using pip version 21.2.3; however, version 
21.2.4 is available.\n", + "You should consider upgrading via the '/Users/niepig/Desktop/zhihu/learn-nlp-with-transformers/venv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], "metadata": { "id": "MOsHUjgdIrIW" - }, - "outputs": [], - "source": [ - "! pip install datasets transformers sacrebleu sentencepiece" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "HFASsisvIrIb" - }, "source": [ "如果您正在本地打开这个notebook,请确保您认真阅读并安装了transformer-quick-start-zh的readme文件中的所有依赖库。您也可以在[这里](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)找到本notebook的多GPU分布式训练版本。" - ] + ], + "metadata": { + "id": "HFASsisvIrIb" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rEJBSTyZIrIb" - }, "source": [ "# 微调transformer模型解决翻译任务" - ] + ], + "metadata": { + "id": "rEJBSTyZIrIb" + } }, { "cell_type": "markdown", - "metadata": { - "id": "kTCFado4IrIc" - }, "source": [ "在这个notebook中,我们将展示如何使用[🤗 Transformers](https://github.com/huggingface/transformers)代码库中的模型来解决自然语言处理中的翻译任务。我们将会使用[WMT dataset](http://www.statmt.org/wmt16/)数据集。这是翻译任务最常用的数据集之一。\n", "\n", @@ -54,85 +98,125 @@ "![Widget inference on a translation task](https://github.com/huggingface/notebooks/blob/master/examples/images/translation.png?raw=1)\n", "\n", "对于翻译任务,我们将展示如何使用简单的加载数据集,同时针对相应的仍无使用transformer中的Trainer接口对模型进行微调。" - ] + ], + "metadata": { + "id": "kTCFado4IrIc" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rJvQjiUqPjhM" - }, - "outputs": [], + "execution_count": 2, "source": [ "model_checkpoint = \"Helsinki-NLP/opus-mt-en-ro\" \n", "# 选择一个模型checkpoint" - ] + ], + "outputs": [], + "metadata": { + "id": "rJvQjiUqPjhM" + } }, { "cell_type": "markdown", - "metadata": { - "id": "4RRkXuteIrIh" - }, "source": [ "只要预训练的transformer模型包含seq2seq结构的head层,那么本notebook理论上可以使用各种各样的transformer模型[模型面板](https://huggingface.co/models),解决任何翻译任务。\n", "\n", "本文我们使用已经训练好的[`Helsinki-NLP/opus-mt-en-ro`](https://huggingface.co/Helsinki-NLP/opus-mt-en-ro) checkpoint来做翻译任务。 " - ] + ], + "metadata": { + "id": "4RRkXuteIrIh" + } }, { "cell_type": "markdown", - "metadata": { - "id": "whPRbBNbIrIl" - }, "source": [ "## 加载数据" - ] + ], + "metadata": { + "id": "whPRbBNbIrIl" + } }, { "cell_type": "markdown", - "metadata": { - "id": "W7QYTpxXIrIl" - }, "source": [ "\n", "我们将会使用🤗 Datasets库来加载数据和对应的评测方式。数据加载和评测方式加载只需要简单使用load_dataset和load_metric即可。我们使用WMT数据集中的English/Romanian双语翻译。\n" - ] + ], + "metadata": { + "id": "W7QYTpxXIrIl" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IreSlFmlIrIm" - }, - "outputs": [], + "execution_count": 3, "source": [ "from datasets import load_dataset, load_metric\n", "\n", "raw_datasets = load_dataset(\"wmt16\", \"ro-en\")\n", "metric = load_metric(\"sacrebleu\")" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 2.81kB [00:00, 523kB/s] \n", + "Downloading: 3.19kB [00:00, 758kB/s] \n", + "Downloading: 41.0kB [00:00, 11.0MB/s] \n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading and preparing dataset wmt16/ro-en (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a...\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 100%|██████████| 225M/225M [00:18<00:00, 12.2MB/s]\n", + "Downloading: 100%|██████████| 23.5M/23.5M 
[00:16<00:00, 1.44MB/s]\n", + "Downloading: 100%|██████████| 38.7M/38.7M [00:03<00:00, 9.82MB/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Dataset wmt16 downloaded and prepared to /Users/niepig/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/0d9fb3e814712c785176ad8cdb9f465fbe6479000ee6546725db30ad8a8b5f8a. Subsequent calls will reuse this data.\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 5.40kB [00:00, 2.08MB/s] \n" + ] + } + ], + "metadata": { + "id": "IreSlFmlIrIm" + } }, { "cell_type": "markdown", - "metadata": { - "id": "RzfPtOMoIrIu" - }, "source": [ "这个datasets对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 对于训练集、验证集和测试集,只需要使用对应的key(train,validation,test)即可得到相应的数据。" - ] + ], + "metadata": { + "id": "RzfPtOMoIrIu" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "GWiVUF0jIrIv", - "outputId": "3151a9fc-7239-4471-a8f0-548dd68d5a89" - }, + "execution_count": 4, + "source": [ + "raw_datasets" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "DatasetDict({\n", @@ -151,72 +235,67 @@ "})" ] }, - "execution_count": 4, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" + "metadata": {}, + "execution_count": 4 } ], - "source": [ - "raw_datasets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "u3EtYfeHIrIz" - }, - "source": [ - "给定一个数据切分的key(train、validation或者test)和下标即可查看数据。" - ] - }, - { - "cell_type": "code", - "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "X6HrpprwIrIz", - "outputId": "69f3873e-2d1f-4614-e43e-9e654277245c" - }, + "id": "GWiVUF0jIrIv", + "outputId": "3151a9fc-7239-4471-a8f0-548dd68d5a89" + } + }, + { + "cell_type": "markdown", + "source": [ + "给定一个数据切分的key(train、validation或者test)和下标即可查看数据。" + ], + "metadata": { + "id": "u3EtYfeHIrIz" + } + }, + { + "cell_type": "code", + "execution_count": 5, + "source": [ + "raw_datasets[\"train\"][0]\n", + "# 我们可以看到一句英语en对应一句罗马尼亚语言ro" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'translation': {'en': 'Membership of Parliament: see Minutes',\n", " 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}" ] }, - "execution_count": 5, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" + "metadata": {}, + "execution_count": 5 } ], - "source": [ - "raw_datasets[\"train\"][0]\n", - "# 我们可以看到一句英语en对应一句罗马尼亚语言ro" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "X6HrpprwIrIz", + "outputId": "69f3873e-2d1f-4614-e43e-9e654277245c" + } }, { "cell_type": "markdown", - "metadata": { - "id": "WHUmphG3IrI3" - }, "source": [ "为了能够进一步理解数据长什么样子,下面的函数将从数据集里随机选择几个例子进行展示。" - ] + ], + "metadata": { + "id": "WHUmphG3IrI3" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "i3j8APAoIrI3" - }, - "outputs": [], + "execution_count": 6, "source": [ "import datasets\n", "import random\n", @@ -237,22 +316,25 @@ " if isinstance(typ, datasets.ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" - ] + ], + "outputs": [], + "metadata": { + "id": "i3j8APAoIrI3" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 255 - }, - "id": "SZy5tRB_IrI7", - "outputId": 
"93e16172-d927-457d-fcab-04dcb4d2ef29" - }, + "execution_count": 7, + "source": [ + "show_random_elements(raw_datasets[\"train\"])" + ], "outputs": [ { + "output_type": "display_data", "data": { + "text/plain": [ + "" + ], "text/html": [ "\n", " \n", @@ -264,62 +346,58 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
0{'en': 'The Bulgarian gymnastics team won the gold medal at the traditional Grand Prix series competition in Thiais, France, which wrapped up on Sunday (March 30th).', 'ro': 'Echipa bulgară de gimnastică a câştigat medalia de aur la tradiţionala competiţie Grand Prix din Thiais, Franţa, care s-a încheiat duminică (30 martie).'}{'en': 'I do not believe that this is the right course.', 'ro': 'Nu cred că acesta este varianta corectă.'}
1{'en': 'Being on that committee, however, you will know that this was a very hot topic in negotiations between Norway and some Member States.', 'ro': 'Totuşi, făcând parte din această comisie, ştiţi că acesta a fost un subiect foarte aprins în negocierile dintre Norvegia şi unele state membre.'}{'en': 'A total of 104 new jobs were created at the European Chemicals Agency, which mainly supervises our REACH projects.', 'ro': 'Un total de 104 noi locuri de muncă au fost create la Agenția Europeană pentru Produse Chimice, care, în special, supraveghează proiectele noastre REACH.'}
2{'en': 'The overwhelming vote shows just this.', 'ro': 'Ceea ce demonstrează şi votul favorabil.'}{'en': 'In view of the above, will the Council say what stage discussions for Turkish participation in joint Frontex operations have reached?', 'ro': 'Care este stadiul negocierilor referitoare la participarea Turciei la operațiunile comune din cadrul Frontex?'}
3{'en': '[Photo illustration by Catherine Gurgenidze for Southeast European Times]', 'ro': '[Ilustraţii foto de Catherine Gurgenidze pentru Southeast European Times]'}{'en': 'We now fear that if the scope of this directive is expanded, the directive will suffer exactly the same fate as the last attempt at introducing 'Made in' origin marking - in other words, that it will once again be blocked by the Council.', 'ro': 'Acum ne temem că, dacă sfera de aplicare a directivei va fi extinsă, aceasta va avea exact aceeaşi soartă ca ultima încercare de introducere a marcajului de origine \"Made in”, cu alte cuvinte, că va fi din nou blocată la Consiliu.'}
4{'en': '(HU) Mr President, today the specific text of the agreement between the Hungarian Government and the European Commission has been formulated.', 'ro': '(HU) Domnule președinte, textul concret al acordului dintre guvernul ungar și Comisia Europeană a fost formulat astăzi.'}{'en': 'The country dropped nine slots to 85th, with a score of 6.58.', 'ro': 'Ţara a coborât nouă poziţii, pe locul 85, cu un scor de 6,58.'}
" - ], - "text/plain": [ - "" ] }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" + "metadata": {} } ], - "source": [ - "show_random_elements(raw_datasets[\"train\"])" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "id": "SZy5tRB_IrI7", + "outputId": "93e16172-d927-457d-fcab-04dcb4d2ef29" + } }, { "cell_type": "markdown", - "metadata": { - "id": "lnjDIuQ3IrI-" - }, "source": [ "metric是[`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric)类的一个实例,查看metric和使用的例子:" - ] + ], + "metadata": { + "id": "lnjDIuQ3IrI-" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5o4rUteaIrI_", - "outputId": "4814f907-6225-4af0-ee63-376699dc79ee" - }, + "execution_count": 8, + "source": [ + "metric" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "Metric(name: \"sacrebleu\", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: \"\"\"\n", @@ -355,76 +433,72 @@ "\"\"\", stored examples: 0)" ] }, - "execution_count": 8, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" + "metadata": {}, + "execution_count": 8 } ], - "source": [ - "metric" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5o4rUteaIrI_", + "outputId": "4814f907-6225-4af0-ee63-376699dc79ee" + } }, { "cell_type": "markdown", - "metadata": { - "id": "jAWdqcUBIrJC" - }, "source": [ "我们使用`compute`方法来对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:" - ] + ], + "metadata": { + "id": "jAWdqcUBIrJC" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, + "source": [ + "fake_preds = [\"hello there\", \"general kenobi\"]\n", + "fake_labels = [[\"hello there\"], [\"general kenobi\"]]\n", + "metric.compute(predictions=fake_preds, references=fake_labels)" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'score': 0.0,\n", + " 'counts': [4, 2, 0, 0],\n", + " 'totals': [4, 2, 0, 0],\n", + " 'precisions': [100.0, 100.0, 0.0, 0.0],\n", + " 'bp': 1.0,\n", + " 'sys_len': 4,\n", + " 'ref_len': 4}" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6XN1Rq0aIrJC", "outputId": "d130ad50-c6ca-42bc-8b14-31021feb620d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'bp': 1.0,\n", - " 'counts': [4, 2, 0, 0],\n", - " 'precisions': [100.0, 100.0, 0.0, 0.0],\n", - " 'ref_len': 4,\n", - " 'score': 0.0,\n", - " 'sys_len': 4,\n", - " 'totals': [4, 2, 0, 0]}" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "fake_preds = [\"hello there\", \"general kenobi\"]\n", - "fake_labels = [[\"hello there\"], [\"general kenobi\"]]\n", - "metric.compute(predictions=fake_preds, references=fake_labels)" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "n9qywopnIrJH" - }, "source": [ "## 数据预处理" - ] + ], + "metadata": { + "id": "n9qywopnIrJH" + } }, { "cell_type": "markdown", - "metadata": { - "id": "YVx71GdAIrJH" - }, "source": [ "在将数据喂入模型之前,我们需要对数据进行预处理。预处理的工具叫Tokenizer。Tokenizer首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。\n", "\n", @@ -435,156 +509,150 @@ "\n", "\n", "这个被下载的tokens 
vocabulary会被缓存起来,从而再次使用的时候不会重新下载。" - ] + ], + "metadata": { + "id": "YVx71GdAIrJH" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eXNLu_-nIrJI" - }, - "outputs": [], + "execution_count": 10, "source": [ "from transformers import AutoTokenizer\n", "# 需要安装`sentencepiece`: pip install sentencepiece\n", " \n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 100%|██████████| 1.13k/1.13k [00:00<00:00, 466kB/s]\n", + "Downloading: 100%|██████████| 789k/789k [00:00<00:00, 882kB/s]\n", + "Downloading: 100%|██████████| 817k/817k [00:00<00:00, 902kB/s]\n", + "Downloading: 100%|██████████| 1.39M/1.39M [00:01<00:00, 1.24MB/s]\n", + "Downloading: 100%|██████████| 42.0/42.0 [00:00<00:00, 14.6kB/s]\n" + ] + } + ], + "metadata": { + "id": "eXNLu_-nIrJI" + } }, { "cell_type": "markdown", - "metadata": { - "id": "GLRyc5J9PjhS" - }, "source": [ "以我们使用的mBART模型为例,我们需要正确设置source语言和target语言。如果您要翻译的是其他双语语料,请查看[这里](https://huggingface.co/facebook/mbart-large-cc25)。我们可以检查source和target语言的设置:\n" - ] + ], + "metadata": { + "id": "GLRyc5J9PjhS" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kmXG36baPjhS" - }, - "outputs": [], + "execution_count": 11, "source": [ "if \"mbart\" in model_checkpoint:\n", " tokenizer.src_lang = \"en-XX\"\n", " tokenizer.tgt_lang = \"ro-RO\"" - ] + ], + "outputs": [], + "metadata": { + "id": "kmXG36baPjhS" + } }, { "cell_type": "markdown", + "source": [], "metadata": { "id": "Vl6IidfdIrJK" - }, - "source": [] + } }, { "cell_type": "markdown", - "metadata": { - "id": "rowT4iCLIrJK" - }, "source": [ "tokenizer既可以对单个文本进行预处理,也可以对一对文本进行预处理,tokenizer预处理后得到的数据满足预训练模型输入格式" - ] + ], + "metadata": { + "id": "rowT4iCLIrJK" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, + "source": [ + "tokenizer(\"Hello, this one sentence!\")" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'input_ids': [125, 778, 3, 63, 141, 9191, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a5hBlsrHIrJL", "outputId": "072ee20c-db1d-4ba1-a98a-119405ea9552" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': [125, 778, 3, 63, 141, 9191, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}" - ] - }, - "execution_count": 12, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenizer(\"Hello, this one sentence!\")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "qo_0B1M2IrJM" - }, "source": [ "上面看到的token IDs也就是input_ids一般来说随着预训练模型名字的不同而有所不同。原因是不同的预训练模型在预训练的时候设定了不同的规则。但只要tokenizer和model的名字一致,那么tokenizer预处理的输入格式就会满足model需求的。关于预处理更多内容参考[这个教程](https://huggingface.co/transformers/preprocessing.html)\n", "\n", "除了可以tokenize一句话,我们也可以tokenize一个list的句子。" - ] + ], + "metadata": { + "id": "qo_0B1M2IrJM" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, + "source": [ + "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 0], [187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}" + ] + }, + "metadata": {}, + "execution_count": 13 + } + ], 
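Because these token IDs are model-specific, it can help to map them back to tokens and text when inspecting the encoder input. The following is only a small sketch, assuming the `tokenizer` loaded with `AutoTokenizer` above is still in scope; it uses `convert_ids_to_tokens` and `decode` to make the numbers in `input_ids` interpretable:

```python
# Minimal sketch; `tokenizer` is assumed to be the AutoTokenizer loaded above.
encoded = tokenizer("Hello, this one sentence!")
# Map each id back to its subword token, then back to a readable string.
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))
```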
"metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LkLffVlKPjhT", "outputId": "f144d050-fc84-4a1a-9fc2-25281b681441" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 0], [187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "-uVqYJrePjhT" - }, "source": [ "注意:为了给模型准备好翻译的targets,我们使用`as_target_tokenizer`来控制targets所对应的特殊token:" - ] + ], + "metadata": { + "id": "-uVqYJrePjhT" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "DgCW0X0FPjhT", - "outputId": "352c44ab-f025-4cf6-98d1-786f6f07111a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input_ids': [10334, 1204, 3, 15, 8915, 27, 452, 59, 29579, 581, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n", - "tokens: ['▁Hel', 'lo', ',', '▁', 'this', '▁o', 'ne', '▁se', 'nten', 'ce', '!', '']\n" - ] - } - ], + "execution_count": 14, "source": [ "with tokenizer.as_target_tokenizer():\n", " print(tokenizer(\"Hello, this one sentence!\"))\n", @@ -592,47 +660,60 @@ " tokens = tokenizer.convert_ids_to_tokens(model_input['input_ids'])\n", " # 打印看一下special toke\n", " print('tokens: {}'.format(tokens))" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'input_ids': [10334, 1204, 3, 15, 8915, 27, 452, 59, 29579, 581, 23, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n", + "tokens: ['▁Hel', 'lo', ',', '▁', 'this', '▁o', 'ne', '▁se', 'nten', 'ce', '!', '']\n" + ] + } + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DgCW0X0FPjhT", + "outputId": "352c44ab-f025-4cf6-98d1-786f6f07111a" + } }, { "cell_type": "markdown", - "metadata": { - "id": "2C0hcmp9IrJQ" - }, "source": [ "如果您使用的是T5预训练模型的checkpoints,需要对特殊的前缀进行检查。T5使用特殊的前缀来告诉模型具体要做的任务,具体前缀例子如下:\n" - ] + ], + "metadata": { + "id": "2C0hcmp9IrJQ" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xS1JJSdmPjhU" - }, - "outputs": [], + "execution_count": 15, "source": [ "if model_checkpoint in [\"t5-small\", \"t5-base\", \"t5-larg\", \"t5-3b\", \"t5-11b\"]:\n", " prefix = \"translate English to Romanian: \"\n", "else:\n", " prefix = \"\"" - ] + ], + "outputs": [], + "metadata": { + "id": "xS1JJSdmPjhU" + } }, { "cell_type": "markdown", - "metadata": { - "id": "CezpZ8gFPjhU" - }, "source": [ "现在我们可以把所有内容放在一起组成我们的预处理函数了。我们对样本进行预处理的时候,我们还会`truncation=True`这个参数来确保我们超长的句子被截断。默认情况下,对与比较短的句子我们会自动padding。" - ] + ], + "metadata": { + "id": "CezpZ8gFPjhU" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vc0BSBLIIrJQ" - }, - "outputs": [], + "execution_count": 16, "source": [ "max_input_length = 128\n", "max_target_length = 128\n", @@ -650,131 +731,147 @@ "\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs" - ] + ], + "outputs": [], + "metadata": { + "id": "vc0BSBLIIrJQ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "0lm8ozrJIrJR" - }, "source": [ "以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。" - ] + ], + "metadata": { + "id": "0lm8ozrJIrJR" + } }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 17, + "source": [ + "preprocess_function(raw_datasets['train'][:2])" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'input_ids': [[393, 4462, 14, 1137, 53, 216, 28636, 0], [24385, 14, 28636, 14, 4646, 4622, 53, 216, 28636, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[42140, 494, 1750, 53, 8, 59, 903, 3543, 9, 15202, 0], [36199, 6612, 9, 15202, 122, 568, 35788, 21549, 53, 8, 59, 903, 3543, 9, 15202, 0]]}" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-b70jh26IrJS", "outputId": "89b26088-d2d2-4312-81d8-b0f5e62dd6a2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': [[393, 4462, 14, 1137, 53, 216, 28636, 0], [24385, 14, 28636, 14, 4646, 4622, 53, 216, 28636, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[42140, 494, 1750, 53, 8, 59, 903, 3543, 9, 15202, 0], [36199, 6612, 9, 15202, 122, 568, 35788, 21549, 53, 8, 59, 903, 3543, 9, 15202, 0]]}" - ] - }, - "execution_count": 22, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "preprocess_function(raw_datasets['train'][:2])" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "zS-6iXTkIrJT" - }, "source": [ "接下来对数据集datasets里面的所有样本进行预处理,处理的方式是使用map函数,将预处理函数prepare_train_features应用到(map)所有样本上。" - ] + ], + "metadata": { + "id": "zS-6iXTkIrJT" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DDtsaJeVIrJT" - }, - "outputs": [], + "execution_count": 18, "source": [ "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 611/611 [02:32<00:00, 3.99ba/s]\n", + "100%|██████████| 2/2 [00:00<00:00, 3.76ba/s]\n", + "100%|██████████| 2/2 [00:00<00:00, 3.89ba/s]\n" + ] + } + ], + "metadata": { + "id": "DDtsaJeVIrJT" + } }, { "cell_type": "markdown", - "metadata": { - "id": "voWiw8C7IrJV" - }, "source": [ "更好的是,返回的结果会自动被缓存,避免下次处理的时候重新计算(但是也要注意,如果输入有改动,可能会被缓存影响!)。datasets库函数会对输入的参数进行检测,判断是否有变化,如果没有变化就使用缓存数据,如果有变化就重新处理。但如果输入参数不变,想改变输入的时候,最好清理调这个缓存。清理的方式是使用`load_from_cache_file=False`参数。另外,上面使用到的`batched=True`这个参数是tokenizer的特点,以为这会使用多线程同时并行对输入进行处理。" - ] + ], + "metadata": { + "id": "voWiw8C7IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "545PP3o8IrJV" - }, "source": [ "## 微调transformer模型" - ] + ], + "metadata": { + "id": "545PP3o8IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "FBiW8UpKIrJW" - }, "source": [ "既然数据已经准备好了,现在我们需要下载并加载我们的预训练模型,然后微调预训练模型。既然我们是做seq2seq任务,那么我们需要一个能解决这个任务的模型类。我们使用`AutoModelForSeq2SeqLM`这个类。和tokenizer相似,`from_pretrained`方法同样可以帮助我们下载并加载模型,同时也会对模型进行缓存,就不会重复下载模型啦。" - ] + ], + "metadata": { + "id": "FBiW8UpKIrJW" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TlqNaB8jIrJW" - }, - "outputs": [], + "execution_count": 19, "source": [ "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Downloading: 100%|██████████| 301M/301M [00:19<00:00, 15.1MB/s]\n" + ] + } + ], + "metadata": { + "id": "TlqNaB8jIrJW" + } }, { "cell_type": "markdown", - "metadata": { - 
"id": "CczA5lJlIrJX" - }, "source": [ "由于我们微调的任务是机器翻译,而我们加载的是预训练的seq2seq模型,所以不会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数(比如:预训练语言模型的神经网络head被扔掉了,同时随机初始化了机器翻译的神经网络head)。" - ] + ], + "metadata": { + "id": "CczA5lJlIrJX" + } }, { "cell_type": "markdown", - "metadata": { - "id": "_N8urzhyIrJY" - }, "source": [ "\n", "为了能够得到一个`Seq2SeqTrainer`训练工具,我们还需要3个要素,其中最重要的是训练的设定/参数[`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments)。这个训练设定包含了能够定义训练过程的所有属性" - ] + ], + "metadata": { + "id": "_N8urzhyIrJY" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Bliy8zgjIrJY" - }, - "outputs": [], + "execution_count": 20, "source": [ "batch_size = 16\n", "args = Seq2SeqTrainingArguments(\n", @@ -789,13 +886,14 @@ " predict_with_generate=True,\n", " fp16=False,\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "Bliy8zgjIrJY" + } }, { "cell_type": "markdown", - "metadata": { - "id": "km3pGVdTIrJc" - }, "source": [ "上面evaluation_strategy = \"epoch\"参数告诉训练代码:我们每个epcoh会做一次验证评估。\n", "\n", @@ -804,35 +902,34 @@ "由于我们的数据集比较大,同时`Seq2SeqTrainer`会不断保存模型,所以我们需要告诉它至多保存`save_total_limit=3`个模型。\n", "\n", "最后我们需要一个数据收集器data collator,将我们处理好的输入喂给模型。" - ] + ], + "metadata": { + "id": "km3pGVdTIrJc" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ZMdgZaOoPjhX" - }, - "outputs": [], + "execution_count": 21, "source": [ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)" - ] + ], + "outputs": [], + "metadata": { + "id": "ZMdgZaOoPjhX" + } }, { "cell_type": "markdown", - "metadata": { - "id": "7sZOdRlRIrJd" - }, "source": [ "设置好`Seq2SeqTrainer`还剩最后一件事情,那就是我们需要定义好评估方法。我们使用`metric`来完成评估。将模型预测送入评估之前,我们也会做一些数据后处理:" - ] + ], + "metadata": { + "id": "7sZOdRlRIrJd" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UmvbnJ9JIrJd" - }, - "outputs": [], + "execution_count": 22, "source": [ "import numpy as np\n", "\n", @@ -862,24 +959,24 @@ " result[\"gen_len\"] = np.mean(prediction_lens)\n", " result = {k: round(v, 4) for k, v in result.items()}\n", " return result" - ] + ], + "outputs": [], + "metadata": { + "id": "UmvbnJ9JIrJd" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rXuFTAzDIrJe" - }, "source": [ "最后将所有的参数/数据/模型传给`Seq2SeqTrainer`即可" - ] + ], + "metadata": { + "id": "rXuFTAzDIrJe" + } }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "imY1oC3SIrJf" - }, - "outputs": [], + "execution_count": 23, "source": [ "trainer = Seq2SeqTrainer(\n", " model,\n", @@ -890,46 +987,115 @@ " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "imY1oC3SIrJf" + } }, { "cell_type": "markdown", - "metadata": { - "id": "CdzABDVcIrJg" - }, "source": [ "调用`train`方法进行微调训练。" - ] + ], + "metadata": { + "id": "CdzABDVcIrJg" + } }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, + "source": [ + "trainer.train()" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 1%|▏ | 500/38145 [1:05:10<91:20:54, 8.74s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'loss': 0.8588, 'learning_rate': 1.973784244330843e-05, 'epoch': 0.01}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 3%|▎ | 1000/38145 [2:09:32<73:56:07, 7.17s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'loss': 0.8343, 'learning_rate': 1.947568488661686e-05, 'epoch': 0.03}\n" + 
] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 4%|▍ | 1500/38145 [3:03:53<57:17:10, 5.63s/it]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'loss': 0.8246, 'learning_rate': 1.9213527329925285e-05, 'epoch': 0.04}\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + " 5%|▌ | 1980/38145 [3:52:54<67:46:36, 6.75s/it]" + ] + }, + { + "output_type": "error", + "ename": "KeyboardInterrupt", + "evalue": "", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/2k/x3py0v857kgcwqvvl00xxhxw0000gn/T/ipykernel_15169/4032920361.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, resume_from_checkpoint, trial, **kwargs)\u001b[0m\n\u001b[1;32m 1079\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1080\u001b[0m \u001b[0;31m# Revert to normal clipping otherwise, handling Apex or full precision\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1081\u001b[0;31m torch.nn.utils.clip_grad_norm_(\n\u001b[0m\u001b[1;32m 1082\u001b[0m \u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaster_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_apex\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1083\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py\u001b[0m in \u001b[0;36mclip_grad_norm_\u001b[0;34m(parameters, max_norm, norm_type, error_if_nonfinite)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mtotal_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnorms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0mtotal_norm\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0merror_if_nonfinite\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/nn/utils/clip_grad.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mtotal_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnorms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnorms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m \u001b[0mtotal_norm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m 
\u001b[0;32mor\u001b[0m \u001b[0mtotal_norm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0merror_if_nonfinite\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Desktop/zhihu/learn-nlp-with-transformers/venv/lib/python3.8/site-packages/torch/functional.py\u001b[0m in \u001b[0;36mnorm\u001b[0;34m(input, p, dim, keepdim, out, dtype)\u001b[0m\n\u001b[1;32m 1310\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1311\u001b[0m \u001b[0m_dim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# noqa: C416 TODO: rewrite as list(range(m))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1312\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_VF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_dim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mkeepdim\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1313\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;31m# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "metadata": { "id": "uNx5pyRlIrJh", "scrolled": false - }, - "outputs": [], - "source": [ - "trainer.train()" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "JXOyGJtqPjhZ" - }, "source": [ "最后别忘了,查看如何上传模型 ,上传模型到](https://huggingface.co/transformers/model_sharing.html) 到[🤗 Model Hub](https://huggingface.co/models)。随后您就可以像这个notebook一开始一样,直接用模型名字就能使用您的模型啦。\n" - ] + ], + "metadata": { + "id": "JXOyGJtqPjhZ" + } }, { "cell_type": "code", "execution_count": null, + "source": [], + "outputs": [], "metadata": { "id": "Jeq1Cq2yPjhZ" - }, - "outputs": [], - "source": [] + } } ], "metadata": { @@ -941,14 +1107,22 @@ "hash": "3bfce0b4c492a35815b5705a19fe374a7eea0baaa08b34d90450caf1fe9ce20b" }, "kernelspec": { - "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)", - "name": "python3" + "name": "python3", + "display_name": "Python 3.8.10 64-bit ('venv': virtualenv)" }, "language_info": { "name": "python", - "version": "" + "version": "3.8.10", + "mimetype": "text/x-python", + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "pygments_lexer": "ipython3", + "nbconvert_exporter": "python", + "file_extension": ".py" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 2 } \ No newline at end of file
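Once the fine-tuned model has been uploaded to the Model Hub, it can be loaded by name just like the pretrained checkpoint at the beginning of this notebook. The sketch below is purely illustrative: `your-username/marian-finetuned-en-to-ro` is a placeholder model id, not a real repository.

```python
# Hypothetical usage sketch; the model id below is a placeholder, not a real checkpoint.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "your-username/marian-finetuned-en-to-ro"  # placeholder id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

inputs = tokenizer(["This is a test sentence."], return_tensors="pt")
outputs = model.generate(**inputs, max_length=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```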