diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.3-问答任务-抽取式问答.ipynb b/docs/篇章4-使用Transformers解决NLP任务/4.3-问答任务-抽取式问答.ipynb index 1dde23e..2bb9071 100644 --- a/docs/篇章4-使用Transformers解决NLP任务/4.3-问答任务-抽取式问答.ipynb +++ b/docs/篇章4-使用Transformers解决NLP任务/4.3-问答任务-抽取式问答.ipynb @@ -2,148 +2,171 @@ "cells": [ { "cell_type": "markdown", + "source": [ + "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\r\n", + "\r\n", + "建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。\r\n", + "如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。" + ], "metadata": { "id": "X4cRE8IbIrIV" - }, - "source": [ - "本文涉及的jupter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n", - "\n", - "建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。\n", - "如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。" - ] + } }, { "cell_type": "code", "execution_count": 1, - "metadata": { - "id": "MOsHUjgdIrIW" - }, - "outputs": [], "source": [ "# !pip install datasets transformers" - ] + ], + "outputs": [], + "metadata": { + "id": "MOsHUjgdIrIW" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rEJBSTyZIrIb" - }, "source": [ "# 在机器问答任务上微调transformer模型" - ] + ], + "metadata": { + "id": "rEJBSTyZIrIb" + } }, { "cell_type": "markdown", - "metadata": { - "id": "nm8eowt6YBU6" - }, "source": [ "在这个notebook中,我们将学习到如何微调[🤗 Transformers](https://github.com/huggingface/transformers)的transformer模型来解决机器问答任务。本文主要解决的是抽取式问答任务:给定一个问题和一段文本,从这段文本中找出能回答该问题的文本片段(span)。通过使用`Trainer` API和dataset包,我们将轻松加载数据集,然后微调transformers。下图给出了一个简单的例子\n", "![Widget inference representing the QA task](images/question_answering.png)\n", "\n", "**Note:** 注意:本文的问答任务是从文本中抽取答案,并不是直接生成答案!" 
- ] + ], + "metadata": { + "id": "nm8eowt6YBU6" + } }, { "cell_type": "markdown", - "metadata": { - "id": "4RRkXuteIrIh" - }, "source": [ "本notebook设计的例子可以用来解决任何和SQUAD 1和SQUAD 2类似的抽取式问答任务,并且可以使用[模型库Model Hub](https://huggingface.co/models)的任何模型checkpoint,只要这些模型包含了一个token classification head 和 一个fast tokenizer。关于模型和fast tokenizer的对应关系见:[这个表格](https://huggingface.co/transformers/index.html#bigtable)。\n", "\n", "\n", "如果您的数据集和本notebook有所不同,英国只需要微调的调整就可以直接使用本notebook。当然,根据您的硬件设备(电脑内存、显卡大小),您需要合理的调整batch size大小,避免out-of-memory的错误。\n", "Set those three parameters, then the rest of the notebook should run smoothly:" - ] + ], + "metadata": { + "id": "4RRkXuteIrIh" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "# squad_v2等于True或者False分别代表使用SQUAD v1 或者 SQUAD v2。\r\n", + "# 如果您使用的是其他数据集,那么True代表的是:模型可以回答“不可回答”问题,也就是部分问题不给出答案,而False则代表所有问题必须回答。\r\n", + "squad_v2 = False\r\n", + "model_checkpoint = \"distilbert-base-uncased\"\r\n", + "batch_size = 16" + ], + "outputs": [], "metadata": { "id": "zVvslsfMIrIh" - }, - "outputs": [], - "source": [ - "# squad_v2等于True或者False分别代表使用SQUAD v1 或者 SQUAD v2。\n", - "# 如果您使用的是其他数据集,那么True代表的是:模型可以回答“不可回答”问题,也就是部分问题不给出答案,而False则代表所有问题必须回答。\n", - "squad_v2 = False\n", - "model_checkpoint = \"distilbert-base-uncased\"\n", - "batch_size = 16" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "whPRbBNbIrIl" - }, "source": [ "## 加载数据集" - ] + ], + "metadata": { + "id": "whPRbBNbIrIl" + } }, { "cell_type": "markdown", + "source": [ + "我们将会使用[🤗 Datasets](https://github.com/huggingface/datasets) 库来下载数据,并且得到我们需要的评测指标(和benchmark基准进行比较)。\r\n", + "\r\n", + "使用函数`load_dataset`和`load_metric`即可简单完成这两项任务。" + ], "metadata": { "id": "W7QYTpxXIrIl" - }, - "source": [ - "我们将会使用[🤗 Datasets](https://github.com/huggingface/datasets) 库来下载数据,并且得到我们需要的评测指标(和benchmark基准进行比较)。\n", - "\n", - "使用函数`load_dataset`和`load_metric`即可简单完成这两项任务。" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "IreSlFmlIrIm" - }, - "outputs": [], "source": [ "from datasets import load_dataset, load_metric" - ] + ], + "outputs": [], + "metadata": { + "id": "IreSlFmlIrIm" + } }, { "cell_type": "markdown", + "source": [ + "举个例子,我们将会在这个notebook中使用[SQUAD 数据集](https://rajpurkar.github.io/SQuAD-explorer/)。同样,本notebook也适配所有dataset库中提供的所有问答数据集。\r\n", + "\r\n", + "如果您使用的是自己的数据集(json或者csv格式),请查看[Datasets 文档](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files)学习如何加载自定义的数据集。可能需要调整每列使用的名字。\r\n" + ], "metadata": { "id": "CKx2zKs5IrIq" - }, - "source": [ - "举个例子,我们将会在这个notebook中使用[SQUAD 数据集](https://rajpurkar.github.io/SQuAD-explorer/)。同样,本notebook也适配所有dataset库中提供的所有问答数据集。\n", - "\n", - "如果您使用的是自己的数据集(json或者csv格式),请查看[Datasets 文档](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files)学习如何加载自定义的数据集。可能需要调整每列使用的名字。\n" - ] + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "# 下载数据(确保有网络)\r\n", + "datasets = load_dataset(\"squad_v2\" if squad_v2 else \"squad\")" + ], + "outputs": [], "metadata": { "id": "s_AY1ATSIrIq" - }, - "outputs": [], - "source": [ - "# 下载数据(确保有网络)\n", - "datasets = load_dataset(\"squad_v2\" if squad_v2 else \"squad\")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "RzfPtOMoIrIu" - }, "source": [ - "这个`datasets`对象是[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)结构,训练、验证、测试分别对应这dict的一个key。" - ] + 
"除此之外,你也可以从我们提供的[链接](https://gas.graviti.cn/dataset/datawhale/SQuAD)下载数据并解压,将解压后的2个json文件复制到到`docs/篇章4-使用Transformers解决NLP任务/datasets/squad`目录下,然后用下面的代码进行加载。" + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, + "source": [ + "import os\r\n", + "\r\n", + "data_path = './dataset/squad/'\r\n", + "path = os.path.join(data_path, 'squad.py')\r\n", + "cache_dir = os.path.join(data_path, 'cache')\r\n", + "data_files = {\"train\": os.path.join(data_path, \"train-v1.1.json\"), \"validation\": os.path.join(data_path, \"dev-v1.1.json\")}\r\n", + "datasets = load_dataset(path, data_files=data_files, cache_dir=cache_dir)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "这个`datasets`对象是[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)结构,训练、验证、测试分别对应这dict的一个key。" + ], "metadata": { - "id": "GWiVUF0jIrIv", - "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" - }, + "id": "RzfPtOMoIrIu" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# 查看以下datasets及其属性\r\n", + "datasets" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "DatasetDict({\n", @@ -158,38 +181,40 @@ "})" ] }, - "execution_count": 9, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 9 } ], - "source": [ - "# 查看以下datasets及其属性\n", - "datasets" - ] + "metadata": { + "id": "GWiVUF0jIrIv", + "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" + } }, { "cell_type": "markdown", - "metadata": { - "id": "5_Tr9XWDYBVA" - }, "source": [ "无论是训练集、验证集还是测试集,对于每一个问答数据样本都会有“context\", \"question\"和“answers”三个key。\n", "\n", "我们可以使用一个下标来选择一个样本。" - ] + ], + "metadata": { + "id": "5_Tr9XWDYBVA" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "X6HrpprwIrIz", - "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" - }, + "source": [ + "datasets[\"train\"][0]\r\n", + "# answers代表答案\r\n", + "# context代表文本片段\r\n", + "# question代表问题" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},\n", @@ -199,73 +224,72 @@ " 'title': 'University_of_Notre_Dame'}" ] }, - "execution_count": 12, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 12 } ], - "source": [ - "datasets[\"train\"][0]\n", - "# answers代表答案\n", - "# context代表文本片段\n", - "# question代表问题" - ] + "metadata": { + "id": "X6HrpprwIrIz", + "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" + } }, { "cell_type": "markdown", - "metadata": { - "id": "s5GXOp9PYBVA" - }, "source": [ "注意answers的标注。answers除了给出了文本片段里的答案文本之外,还给出了该answer所在位置(以character开始计算,上面的例子是第515位)。\n", "\n", "为了能够进一步理解数据长什么样子,下面的函数将从数据集里随机选择几个例子进行展示。" - ] + ], + "metadata": { + "id": "s5GXOp9PYBVA" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "from datasets import ClassLabel, Sequence\r\n", + "import random\r\n", + "import pandas as pd\r\n", + "from IPython.display import display, HTML\r\n", + "\r\n", + "def show_random_elements(dataset, num_examples=10):\r\n", + " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\r\n", + " picks = []\r\n", + " for _ in range(num_examples):\r\n", + " pick = random.randint(0, len(dataset)-1)\r\n", + " while pick in picks:\r\n", + " pick = random.randint(0, len(dataset)-1)\r\n", + " picks.append(pick)\r\n", + " \r\n", + " df = pd.DataFrame(dataset[picks])\r\n", + " for column, typ in 
dataset.features.items():\r\n", + " if isinstance(typ, ClassLabel):\r\n", + " df[column] = df[column].transform(lambda i: typ.names[i])\r\n", + " elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\r\n", + " df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\r\n", + " display(HTML(df.to_html()))" + ], + "outputs": [], "metadata": { "id": "i3j8APAoIrI3" - }, - "outputs": [], - "source": [ - "from datasets import ClassLabel, Sequence\n", - "import random\n", - "import pandas as pd\n", - "from IPython.display import display, HTML\n", - "\n", - "def show_random_elements(dataset, num_examples=10):\n", - " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", - " picks = []\n", - " for _ in range(num_examples):\n", - " pick = random.randint(0, len(dataset)-1)\n", - " while pick in picks:\n", - " pick = random.randint(0, len(dataset)-1)\n", - " picks.append(pick)\n", - " \n", - " df = pd.DataFrame(dataset[picks])\n", - " for column, typ in dataset.features.items():\n", - " if isinstance(typ, ClassLabel):\n", - " df[column] = df[column].transform(lambda i: typ.names[i])\n", - " elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n", - " df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])\n", - " display(HTML(df.to_html()))" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "SZy5tRB_IrI7", - "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13", - "scrolled": true - }, + "source": [ + "show_random_elements(datasets[\"train\"], num_examples=2)" + ], "outputs": [ { + "output_type": "display_data", "data": { + "text/plain": [ + "" + ], "text/html": [ "\n", " \n", @@ -297,35 +321,30 @@ " \n", " \n", "
" - ], - "text/plain": [ - "" ] }, "metadata": { "tags": [] - }, - "output_type": "display_data" + } } ], - "source": [ - "show_random_elements(datasets[\"train\"], num_examples=2)" - ] + "metadata": { + "id": "SZy5tRB_IrI7", + "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13", + "scrolled": true + } }, { "cell_type": "markdown", - "metadata": { - "id": "n9qywopnIrJH" - }, "source": [ "## Preprocessing the training data" - ] + ], + "metadata": { + "id": "n9qywopnIrJH" + } }, { "cell_type": "markdown", - "metadata": { - "id": "YVx71GdAIrJH" - }, "source": [ "在将数据喂入模型之前,我们需要对数据进行预处理。预处理的工具叫`Tokenizer`。`Tokenizer`首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。\n", "\n", @@ -335,326 +354,330 @@ "- 使用指定的模型checkpoint对应的tokenizer的时候,我们也下载了模型需要的词表库vocabulary,准确来说是tokens vocabulary。\n", "\n", "这个被下载的tokens vocabulary会被缓存起来,从而再次使用的时候不会重新下载。" - ] + ], + "metadata": { + "id": "YVx71GdAIrJH" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "from transformers import AutoTokenizer\r\n", + " \r\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" + ], + "outputs": [], "metadata": { "id": "eXNLu_-nIrJI" - }, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer\n", - " \n", - "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "Vl6IidfdIrJK" - }, "source": [ "以下代码要求tokenizer必须是transformers.PreTrainedTokenizerFast类型,因为我们在预处理的时候需要用到fast tokenizer的一些特殊特性(比如多线程快速tokenizer)。\n", "\n", "几乎所有模型对应的tokenizer都有对应的fast tokenizer。我们可以在[模型tokenizer对应表](https://huggingface.co/transformers/index.html#bigtable)里查看所有预训练模型对应的tokenizer所拥有的特点。\n" - ] + ], + "metadata": { + "id": "Vl6IidfdIrJK" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import transformers\r\n", + "assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)" + ], + "outputs": [], "metadata": { "id": "w0UUlLKzYBVD" - }, - "outputs": [], - "source": [ - "import transformers\n", - "assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "UdHO2nryYBVE", - "outputId": "af9ba744-7aa9-44a8-b0c8-2b60b21bf786" - }, + "source": [ + "# 如果我们想要看到tokenizer预处理之后的文本格式,我们仅使用tokenizer的tokenize方法,add special tokens意思是增加预训练模型所要求的特俗token。\r\n", + "print(\"单个文本tokenize: {}\".format(tokenizer.tokenize(\"What is your name?\"), add_special_tokens=True))\r\n", + "print(\"2个文本tokenize: {}\".format(tokenizer.tokenize(\"My name is Sylvain.\", add_special_tokens=True)))\r\n", + "# 预训练模型输入格式要求的输入为token IDs,还需要attetnion mask。可以使用下面的方法得到预训练模型格式所要求的输入。" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "单个文本tokenize: ['what', 'is', 'your', 'name', '?']\n", "2个文本tokenize: ['[CLS]', 'my', 'name', 'is', 'sy', '##lva', '##in', '.', '[SEP]']\n" ] } ], - "source": [ - "# 如果我们想要看到tokenizer预处理之后的文本格式,我们仅使用tokenizer的tokenize方法,add special tokens意思是增加预训练模型所要求的特俗token。\n", - "print(\"单个文本tokenize: {}\".format(tokenizer.tokenize(\"What is your name?\"), add_special_tokens=True))\n", - "print(\"2个文本tokenize: {}\".format(tokenizer.tokenize(\"My name is Sylvain.\", add_special_tokens=True)))\n", - "# 预训练模型输入格式要求的输入为token IDs,还需要attetnion mask。可以使用下面的方法得到预训练模型格式所要求的输入。" - ] + "metadata": { + "id": "UdHO2nryYBVE", + "outputId": "af9ba744-7aa9-44a8-b0c8-2b60b21bf786" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rowT4iCLIrJK" - }, "source": [ 
"tokenizer既可以对单个文本进行预处理,也可以对一对文本进行预处理,tokenizer预处理后得到的数据满足预训练模型输入格式" - ] + ], + "metadata": { + "id": "rowT4iCLIrJK" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "_E0gsIhoYBVF", - "outputId": "e810df73-25ff-4fd9-edff-417bbe9679b4" - }, + "source": [ + "# 对单个文本进行预处理\r\n", + "tokenizer(\"What is your name?\")" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}" ] }, - "execution_count": 32, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 32 } ], - "source": [ - "# 对单个文本进行预处理\n", - "tokenizer(\"What is your name?\")" - ] + "metadata": { + "id": "_E0gsIhoYBVF", + "outputId": "e810df73-25ff-4fd9-edff-417bbe9679b4" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "a5hBlsrHIrJL", - "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90" - }, + "source": [ + "# 对2个文本进行预处理,可以看到tokenizer在开始添加了101 token ID,中间用102token ID区分两段文本,末尾用102结尾。这些规则都是预训练模型是所设计的。\r\n", + "tokenizer(\"What is your name?\", \"My name is Sylvain.\")" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}" ] }, - "execution_count": 33, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 33 } ], - "source": [ - "# 对2个文本进行预处理,可以看到tokenizer在开始添加了101 token ID,中间用102token ID区分两段文本,末尾用102结尾。这些规则都是预训练模型是所设计的。\n", - "tokenizer(\"What is your name?\", \"My name is Sylvain.\")" - ] + "metadata": { + "id": "a5hBlsrHIrJL", + "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90" + } }, { "cell_type": "markdown", - "metadata": { - "id": "Q0tivQQwYBVG" - }, "source": [ "上面看到的token IDs也就是input_ids一般来说随着预训练模型名字的不同而有所不同。原因是不同的预训练模型在预训练的时候设定了不同的规则。但只要tokenizer和model的名字一致,那么tokenizer预处理的输入格式就会满足model需求的。关于预处理更多内容参考[这个教程](https://huggingface.co/transformers/preprocessing.html)\n", "\n", "现在我们还需要思考预训练机器问答模型们是如何处理非常长的文本的。一般来说预训练模型输入有最大长度要求,所以我们通常将超长的输入进行截断。但是,如果我们将问答数据三元组中的超长context截断,那么我们可能丢掉答案(因为我们是从context中抽取出一个小片段作为答案)。为了解决这个问题,下面的代码找到一个超过长度的例子,然后向您演示如何进行处理。我们把超长的输入切片为多个较短的输入,每个输入都要满足模型最大长度输入要求。由于答案可能存在与切片的地方,因此我们需要允许相邻切片之间有交集,代码中通过`doc_stride`参数控制。\n", "\n", "机器问答预训练模型通常将question和context拼接之后作为输入,然后让模型从context里寻找答案。" - ] + ], + "metadata": { + "id": "Q0tivQQwYBVG" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "max_length = 384 # 输入feature的最大长度,question和context拼接之后\r\n", + "doc_stride = 128 # 2个切片之间的重合token数量。" + ], + "outputs": [], "metadata": { "id": "qDHb6I5aYBVH" - }, - "outputs": [], - "source": [ - "max_length = 384 # 输入feature的最大长度,question和context拼接之后\n", - "doc_stride = 128 # 2个切片之间的重合token数量。" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "OVKI8CGNYBVH" - }, "source": [ "for循环遍历数据集,寻找一个超长样本,本notebook例子模型所要求的最大输入是384(经常使用的还有512)" - ] + ], + "metadata": { + "id": "OVKI8CGNYBVH" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "for i, example in enumerate(datasets[\"train\"]):\r\n", + " if len(tokenizer(example[\"question\"], example[\"context\"])[\"input_ids\"]) > 384:\r\n", + " break\r\n", + "example = datasets[\"train\"][i]" + ], + "outputs": [], "metadata": { "id": "7Ma0M9v_YBVH" - }, - "outputs": [], - "source": [ - "for i, example in enumerate(datasets[\"train\"]):\n", - " if len(tokenizer(example[\"question\"], 
example[\"context\"])[\"input_ids\"]) > 384:\n", - " break\n", - "example = datasets[\"train\"][i]" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "A84Fzhp9YBVI" - }, "source": [ "如果不截断的化,那么输入的长度是396" - ] + ], + "metadata": { + "id": "A84Fzhp9YBVI" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "WySoZrRwYBVI", - "outputId": "53943502-8f19-4a8f-bf98-18fe0cf096a4" - }, + "source": [ + "len(tokenizer(example[\"question\"], example[\"context\"])[\"input_ids\"])" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "396" ] }, - "execution_count": 14, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 14 } ], - "source": [ - "len(tokenizer(example[\"question\"], example[\"context\"])[\"input_ids\"])" - ] + "metadata": { + "id": "WySoZrRwYBVI", + "outputId": "53943502-8f19-4a8f-bf98-18fe0cf096a4" + } }, { "cell_type": "markdown", - "metadata": { - "id": "RaygYboUYBVI" - }, "source": [ "现在如果我们截断成最大长度384,将会丢失超长部分的信息" - ] + ], + "metadata": { + "id": "RaygYboUYBVI" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "MXb6ChXWYBVJ", - "outputId": "fd940c60-69d9-44fc-d8ea-35c58bf54422" - }, + "source": [ + "len(tokenizer(example[\"question\"], example[\"context\"], max_length=max_length, truncation=\"only_second\")[\"input_ids\"])" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "384" ] }, - "execution_count": 37, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 37 } ], - "source": [ - "len(tokenizer(example[\"question\"], example[\"context\"], max_length=max_length, truncation=\"only_second\")[\"input_ids\"])" - ] + "metadata": { + "id": "MXb6ChXWYBVJ", + "outputId": "fd940c60-69d9-44fc-d8ea-35c58bf54422" + } }, { "cell_type": "markdown", - "metadata": { - "id": "738aDe4aYBVJ" - }, "source": [ "注意,一般来说,我们只对context进行切片,不会对问题进行切片,由于context是拼接在question后面的,对应着第2个文本,所以使用`only_second`控制.tokenizer使用`doc_stride`控制切片之间的重合长度。" - ] + ], + "metadata": { + "id": "738aDe4aYBVJ" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "tokenized_example = tokenizer(\r\n", + " example[\"question\"],\r\n", + " example[\"context\"],\r\n", + " max_length=max_length,\r\n", + " truncation=\"only_second\",\r\n", + " return_overflowing_tokens=True,\r\n", + " stride=doc_stride\r\n", + ")" + ], + "outputs": [], "metadata": { "id": "bLvoOEUIYBVJ" - }, - "outputs": [], - "source": [ - "tokenized_example = tokenizer(\n", - " example[\"question\"],\n", - " example[\"context\"],\n", - " max_length=max_length,\n", - " truncation=\"only_second\",\n", - " return_overflowing_tokens=True,\n", - " stride=doc_stride\n", - ")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "nOA0Y8dGYBVK" - }, "source": [ "由于对超长输入进行了切片,我们得到了多个输入,这些输入input_ids对应的长度是" - ] + ], + "metadata": { + "id": "nOA0Y8dGYBVK" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "PEZwY2IzYBVK", - "outputId": "d4eb8bd4-81fa-4aa5-aae2-bf90bd7b2bc1" - }, + "source": [ + "[len(x) for x in tokenized_example[\"input_ids\"]]" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "[384, 157]" ] }, - "execution_count": 40, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 40 } ], - "source": [ - "[len(x) for x in tokenized_example[\"input_ids\"]]" - ] + "metadata": { + "id": "PEZwY2IzYBVK", + "outputId": "d4eb8bd4-81fa-4aa5-aae2-bf90bd7b2bc1" + } }, { "cell_type": 
"markdown", - "metadata": { - "id": "XuVGMM2BYBVL" - }, "source": [ "我们可以将预处理后的token IDs,input_ids还原为文本格式:" - ] + ], + "metadata": { + "id": "XuVGMM2BYBVL" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "cBs66sJXYBVL", - "outputId": "dca34307-56e2-45b3-bc21-e6a656c2f8d7" - }, + "source": [ + "for i, x in enumerate(tokenized_example[\"input_ids\"][:2]):\n", + " print(\"切片: {}\".format(i))\n", + " print(tokenizer.decode(x))" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "切片: 0\n", "[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fifteenth at notre dame, has achieved a 332 - 165 record. in 2009 they were invited to the nit, where they advanced to the semifinals but were beaten by penn state who went on and beat baylor in the championship. the 2010 – 11 team concluded its regular season ranked number seven in the country, with a record of 25 – 5, brey's fifth straight 20 - win season, and a second - place finish in the big east. during the 2014 - 15 season, the team went 32 - 6 and won the acc conference tournament, later advancing to the elite 8, where the fighting irish lost on a missed buzzer - beater against then undefeated kentucky. led by nba draft picks jerian grant and pat connaughton, the fighting irish beat the eventual national champion duke blue devils twice during the season. 
the 32 wins were [SEP]\n", @@ -663,38 +686,23 @@ ] } ], - "source": [ - "for i, x in enumerate(tokenized_example[\"input_ids\"][:2]):\n", - " print(\"切片: {}\".format(i))\n", - " print(tokenizer.decode(x))" - ] + "metadata": { + "id": "cBs66sJXYBVL", + "outputId": "dca34307-56e2-45b3-bc21-e6a656c2f8d7" + } }, { "cell_type": "markdown", - "metadata": { - "id": "F9IYlTYgYBVL" - }, "source": [ "由于我们对超长文本进行了切片,我们需要重新寻找答案所在位置(相对于每一片context开头的相对位置)。机器问答模型将使用答案的位置(答案的起始位置和结束位置,start和end)作为训练标签(而不是答案的token IDS)。所以切片需要和原始输入有一个对应关系,每个token在切片后context的位置和原始超长context里位置的对应关系。在tokenizer里可以使用`return_offsets_mapping`参数得到这个对应关系的map:" - ] + ], + "metadata": { + "id": "F9IYlTYgYBVL" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "W8PK9eS5YBVM", - "outputId": "d9bcf780-2654-43dc-afd3-4b023a9e0229" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373), (374, 379), (379, 380), (381, 384), (385, 389), (390, 393), (394, 406), (407, 408), (409, 415), (416, 418)]\n", - "[0, 0]\n" - ] - } - ], "source": [ "tokenized_example = tokenizer(\n", " example[\"question\"],\n", @@ -707,94 +715,95 @@ ")\n", "# 打印切片前后位置下标的对应关系\n", "print(tokenized_example[\"offset_mapping\"][0][:100])" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373), (374, 379), (379, 380), (381, 384), (385, 389), (390, 393), (394, 406), (407, 408), (409, 415), (416, 418)]\n", + "[0, 0]\n" + ] + } + ], + "metadata": { + "id": "W8PK9eS5YBVM", + "outputId": "d9bcf780-2654-43dc-afd3-4b023a9e0229" + } }, { 
"cell_type": "markdown", - "metadata": { - "id": "wGeKX0oxYBVM" - }, "source": [ "上面打印的是tokenized_example第0片的前100个tokens在原始context片里的位置。注意第一个token是`[CLS]`设定为(0, 0)是因为这个token不属于qeustion或者answer的一部分。第2个token对应的起始和结束位置是0和3。我们可以根据切片后的token id转化对应的token;然后使用`offset_mapping`参数映射回切片前的token位置,找到原始位置的tokens。由于question拼接在context前面,所以直接从question里根据下标找就行了。" - ] + ], + "metadata": { + "id": "wGeKX0oxYBVM" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "gv9vD3r6YBVN", - "outputId": "204f24b3-5feb-4106-f9b4-eb52b5e62fee" - }, + "source": [ + "first_token_id = tokenized_example[\"input_ids\"][0][1]\n", + "offsets = tokenized_example[\"offset_mapping\"][0][1]\n", + "print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example[\"question\"][offsets[0]:offsets[1]])" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "how How\n" ] } ], - "source": [ - "first_token_id = tokenized_example[\"input_ids\"][0][1]\n", - "offsets = tokenized_example[\"offset_mapping\"][0][1]\n", - "print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example[\"question\"][offsets[0]:offsets[1]])" - ] + "metadata": { + "id": "gv9vD3r6YBVN", + "outputId": "204f24b3-5feb-4106-f9b4-eb52b5e62fee" + } }, { "cell_type": "markdown", - "metadata": { - "id": "4SpfrDNzYBVN" - }, "source": [ "因此,我们得到了切片前后的位置对应关系。我们还需要使用`sequence_ids`参数来区分question和context。" - ] + ], + "metadata": { + "id": "4SpfrDNzYBVN" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "gRQPZUQPYBVN", - "outputId": "94e9a2d8-fc4e-472f-9ddb-61be9641e38d" - }, + "source": [ + "sequence_ids = tokenized_example.sequence_ids()\n", + "print(sequence_ids)" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]\n" ] } ], - "source": [ - "sequence_ids = tokenized_example.sequence_ids()\n", - "print(sequence_ids)" - ] + "metadata": { + "id": "gRQPZUQPYBVN", + "outputId": "94e9a2d8-fc4e-472f-9ddb-61be9641e38d" + } }, { "cell_type": "markdown", - "metadata": { - "id": "jCtYEPCsYBVO" - }, "source": [ " `None`对应了special tokens,然后0或者1分表代表第1个文本和第2个文本,由于我们qeustin第1个传入,context第2个传入,所以分别对应question和context。最终我们可以找到标注的答案在预处理之后的features里的位置:" - ] + ], + "metadata": { + "id": "jCtYEPCsYBVO" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "ZK2Modc7YBVO", - "outputId": 
"ad67bfcb-ad7f-4df1-9696-eda61d49e499" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "start_position: 23, end_position: 26\n" - ] - } - ], "source": [ "answers = example[\"answers\"]\n", "start_char = answers[\"answer_start\"][0]\n", @@ -824,75 +833,84 @@ " print(\"start_position: {}, end_position: {}\".format(start_position, end_position))\n", "else:\n", " print(\"The answer is not in this feature.\")" - ] + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "start_position: 23, end_position: 26\n" + ] + } + ], + "metadata": { + "id": "ZK2Modc7YBVO", + "outputId": "ad67bfcb-ad7f-4df1-9696-eda61d49e499" + } }, { "cell_type": "markdown", - "metadata": { - "id": "d5H2Rv7PYBVP" - }, "source": [ "我们需要对答案的位置进行验证,验证方式是:使用答案所在位置下标,取到对应的token ID,然后转化为文本,然后和原始答案进行但对比。" - ] + ], + "metadata": { + "id": "d5H2Rv7PYBVP" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "mdmegUoSYBVP", - "outputId": "b5979cb1-9fbb-4863-cccc-ad0c5d5ce760" - }, + "source": [ + "print(tokenizer.decode(tokenized_example[\"input_ids\"][0][start_position: end_position+1]))\n", + "print(answers[\"text\"][0])" + ], "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "over 1, 600\n", "over 1,600\n" ] } ], - "source": [ - "print(tokenizer.decode(tokenized_example[\"input_ids\"][0][start_position: end_position+1]))\n", - "print(answers[\"text\"][0])" - ] + "metadata": { + "id": "mdmegUoSYBVP", + "outputId": "b5979cb1-9fbb-4863-cccc-ad0c5d5ce760" + } }, { "cell_type": "markdown", - "metadata": { - "id": "9HKReju5YBVQ" - }, "source": [ "有时候question拼接context,而有时候是context拼接question,不同的模型有不同的要求,因此我们需要使用`padding_side`参数来指定。" - ] + ], + "metadata": { + "id": "9HKReju5YBVQ" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "DryiuiaMYBVQ" - }, - "outputs": [], "source": [ "pad_on_right = tokenizer.padding_side == \"right\" #context在右边" - ] + ], + "outputs": [], + "metadata": { + "id": "DryiuiaMYBVQ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "1BDIIhJmYBVR" - }, "source": [ "现在,把所有步骤合并到一起。对于context中无答案的情况,我们直接将标注的答案起始位置和结束位置放置在CLS的下标处。如果`allow_impossible_answers`这个参数是`False`的化,那这些无答案的样本都会被扔掉。为了简洁起见,我们先扔掉把。\n" - ] + ], + "metadata": { + "id": "1BDIIhJmYBVR" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "mObBCXnrYBVR" - }, - "outputs": [], "source": [ "def prepare_train_features(examples):\n", " # 既要对examples进行truncation(截断)和padding(补全)还要还要保留所有信息,所以要用的切片的方法。\n", @@ -965,87 +983,92 @@ " tokenized_examples[\"end_positions\"].append(token_end_index + 1)\n", "\n", " return tokenized_examples" - ] + ], + "outputs": [], + "metadata": { + "id": "mObBCXnrYBVR" + } }, { "cell_type": "markdown", - "metadata": { - "id": "0lm8ozrJIrJR" - }, "source": [ "以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。" - ] + ], + "metadata": { + "id": "0lm8ozrJIrJR" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "-b70jh26IrJS" - }, - "outputs": [], "source": [ "features = prepare_train_features(datasets['train'][:5])\n", "# 处理5个样本" - ] + ], + "outputs": [], + "metadata": { + "id": "-b70jh26IrJS" + } }, { "cell_type": "markdown", - "metadata": { - "id": "zS-6iXTkIrJT" - }, "source": [ "接下来对数据集datasets里面的所有样本进行预处理,处理的方式是使用`map`函数,将预处理函数`prepare_train_features`应用到(map)所有样本上。" - ] + ], + "metadata": { + "id": "zS-6iXTkIrJT" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": 
"DDtsaJeVIrJT" - }, - "outputs": [], "source": [ "tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets[\"train\"].column_names)" - ] + ], + "outputs": [], + "metadata": { + "id": "DDtsaJeVIrJT" + } }, { "cell_type": "markdown", - "metadata": { - "id": "voWiw8C7IrJV" - }, "source": [ "更好的是,返回的结果会自动被缓存,避免下次处理的时候重新计算(但是也要注意,如果输入有改动,可能会被缓存影响!)。datasets库函数会对输入的参数进行检测,判断是否有变化,如果没有变化就使用缓存数据,如果有变化就重新处理。但如果输入参数不变,想改变输入的时候,最好清理调这个缓存。清理的方式是使用`load_from_cache_file=False`参数。另外,上面使用到的`batched=True`这个参数是tokenizer的特点,以为这会使用多线程同时并行对输入进行处理。" - ] + ], + "metadata": { + "id": "voWiw8C7IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "545PP3o8IrJV" - }, "source": [ "## Fine-tuning微调模型" - ] + ], + "metadata": { + "id": "545PP3o8IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "FBiW8UpKIrJW" - }, "source": [ "目前,我们已经预处理好了训练/微调需要的数据,现在我们下载预训练的模型。由于我们要做的是机器问答任务,于是我们使用这个类`AutoModelForQuestionAnswering`。和tokenizer相似,model也是使用`from_pretrained`方法进行加载。\n" - ] + ], + "metadata": { + "id": "FBiW8UpKIrJW" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "TlqNaB8jIrJW", - "outputId": "84916cf3-6e6c-47f3-d081-032ec30a4132" - }, + "source": [ + "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\r\n", + "\r\n", + "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)" + ], "outputs": [ { - "name": "stderr", "output_type": "stream", + "name": "stderr", "text": [ "Downloading: 100%|██████████| 268M/268M [00:46<00:00, 5.79MB/s]\n", "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n", @@ -1056,241 +1079,240 @@ ] } ], - "source": [ - "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n", - "\n", - "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)" - ] + "metadata": { + "id": "TlqNaB8jIrJW", + "outputId": "84916cf3-6e6c-47f3-d081-032ec30a4132" + } }, { "cell_type": "markdown", - "metadata": { - "id": "CczA5lJlIrJX" - }, "source": [ "由于我们微调的任务是机器问答任务,而我们加载的是预训练的语言模型,那么上面会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数(预训练语言模型的神经网络head被扔掉了,同时随机初始化了机器问答的神经网络head)。\n", "\n", "正因为有这些随机初始化的参数,所以我们要在新的数据集上重新fine-tune我们的模型。" - ] + ], + "metadata": { + "id": "CczA5lJlIrJX" + } }, { "cell_type": "markdown", - "metadata": { - "id": "_N8urzhyIrJY" - }, "source": [ "为了能够得到一个`Trainer`训练工具,我们还需要3个要素,其中最重要的是训练的设定/参数[`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments)。这个训练设定包含了能够定义训练过程的所有属性。同时它需要一个文件夹的名字。这个文件夹会被用来保存模型和其他模型配置。" - ] + ], + "metadata": { + "id": "_N8urzhyIrJY" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "args = TrainingArguments(\r\n", + " f\"test-squad\",\r\n", + " evaluation_strategy = \"epoch\",\r\n", + " learning_rate=2e-5, #学习率\r\n", + " per_device_train_batch_size=batch_size,\r\n", + " per_device_eval_batch_size=batch_size,\r\n", + " num_train_epochs=3, # 训练的论次\r\n", + " weight_decay=0.01,\r\n", + ")" + ], + "outputs": [], "metadata": { "id": "Bliy8zgjIrJY" - }, - "outputs": [], - "source": [ - "args = TrainingArguments(\n", - " f\"test-squad\",\n", - " evaluation_strategy = \"epoch\",\n", - " learning_rate=2e-5, #学习率\n", - " per_device_train_batch_size=batch_size,\n", - " per_device_eval_batch_size=batch_size,\n", - " 
num_train_epochs=3, # 训练的论次\n", - "    weight_decay=0.01,\n", - ")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "km3pGVdTIrJc" - }, "source": [ "上面`evaluation_strategy = \"epoch\"`参数告诉训练代码:我们每个epoch会做一次验证评估。\n", "\n", "上面`batch_size`在这个notebook之前定义好了。" - ] + ], + "metadata": { + "id": "km3pGVdTIrJc" + } }, { "cell_type": "markdown", - "metadata": { - "id": "9G0oFbpTYBVX" - }, "source": [ "我们使用一个default_data_collator将预处理好的数据喂给模型。" - ] + ], + "metadata": { + "id": "9G0oFbpTYBVX" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "from transformers import default_data_collator\r\n", + "\r\n", + "data_collator = default_data_collator" + ], + "outputs": [], "metadata": { "id": "4nTsgJKRYBVX" - }, - "outputs": [], - "source": [ - "from transformers import default_data_collator\n", - "\n", - "data_collator = default_data_collator" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "rXuFTAzDIrJe" - }, "source": [ "训练的时候,我们将只会计算loss。根据评测指标评估模型将会放在下一节。\n", "\n", "只需要把模型,训练参数,数据,之前使用的tokenizer,和数据投递工具default_data_collator传入Trainer即可。" - ] + ], + "metadata": { + "id": "rXuFTAzDIrJe" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "trainer = Trainer(\r\n", + "    model,\r\n", + "    args,\r\n", + "    train_dataset=tokenized_datasets[\"train\"],\r\n", + "    eval_dataset=tokenized_datasets[\"validation\"],\r\n", + "    data_collator=data_collator,\r\n", + "    tokenizer=tokenizer,\r\n", + ")" + ], + "outputs": [], "metadata": { "id": "imY1oC3SIrJf" - }, - "outputs": [], - "source": [ - "trainer = Trainer(\n", - "    model,\n", - "    args,\n", - "    train_dataset=tokenized_datasets[\"train\"],\n", - "    eval_dataset=tokenized_datasets[\"validation\"],\n", - "    data_collator=data_collator,\n", - "    tokenizer=tokenizer,\n", - ")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "CdzABDVcIrJg" - }, "source": [ "调用`train`方法开始训练" - ] + ], + "metadata": { + "id": "CdzABDVcIrJg" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "fWZRjmU6r-RP" - }, - "outputs": [], "source": [ "trainer.train()" - ] + ], + "outputs": [], + "metadata": { + "id": "fWZRjmU6r-RP" + } }, { "cell_type": "markdown", - "metadata": { - "id": "d_G79pAyYBVY" - }, "source": [ "由于训练时间很长,如果是在本地mac训练,每个epoch大约需要2小时,所以每次训练完保存一下模型。" - ] + ], + "metadata": { + "id": "d_G79pAyYBVY" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "qYP7NNo3YBVZ" - }, - "outputs": [], "source": [ "trainer.save_model(\"test-squad-trained\")" - ] + ], + "outputs": [], + "metadata": { + "id": "qYP7NNo3YBVZ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "RwFvugZTYBVZ" - }, "source": [ "## Evaluation评估" - ] + ], + "metadata": { + "id": "RwFvugZTYBVZ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "3fcMB_uGYBVZ" - }, "source": [ "模型评估会稍微有点复杂,我们需要将模型的输出后处理成我们需要的文本格式。模型本身预测的是answer所在start/end位置的logits。如果我们评估时喂入模型的是一个batch,那么输出如下:" - ] + ], + "metadata": { + "id": "3fcMB_uGYBVZ" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import torch\r\n", + "\r\n", + "for batch in trainer.get_eval_dataloader():\r\n", + "    break\r\n", + "batch = {k: v.to(trainer.args.device) for k, v in batch.items()}\r\n", + "with torch.no_grad():\r\n", + "    output = trainer.model(**batch)\r\n", + "output.keys()" + ], + "outputs": [], "metadata": { "id": "XzkOFb3uYBVZ" - }, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "for batch in trainer.get_eval_dataloader():\n", - "    break\n", - "batch = {k: v.to(trainer.args.device) for k, v in 
batch.items()}\n", - "with torch.no_grad():\n", - "    output = trainer.model(**batch)\n", - "output.keys()" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "G4McnKTSYBVa" - }, "source": [ "模型的输出是一个像dict的数据结构,包含了loss(因为提供了label,所以有loss),answer start和end的logits。我们在输出预测结果的时候不需要看loss,直接看logits就好了。" - ] + ], + "metadata": { + "id": "G4McnKTSYBVa" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "yNYbBFcvYBVa", - "outputId": "7f752049-0965-4484-96b2-d2326c6f5dcf" - }, + "source": [ + "output.start_logits.shape, output.end_logits.shape" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([16, 384]), torch.Size([16, 384]))" ] }, - "execution_count": 35, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 35 } ], - "source": [ - "output.start_logits.shape, output.end_logits.shape" - ] + "metadata": { + "id": "yNYbBFcvYBVa", + "outputId": "7f752049-0965-4484-96b2-d2326c6f5dcf" + } }, { "cell_type": "markdown", - "metadata": { - "id": "2qQ4y1cwYBVa" - }, "source": [ "每个feature里的每个token都会有一个logit。预测answer最简单的方法就是选择start的logits里最大的下标作为answer起始位置,end的logits里最大下标作为answer的结束位置。" - ] + ], + "metadata": { + "id": "2qQ4y1cwYBVa" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "sTZn5YtxYBVb", - "outputId": "8818a3c5-1eda-4926-9d37-af2a0124b59f" - }, + "source": [ + "output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1)" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "(tensor([ 46, 57, 78, 43, 118, 15, 72, 35, 15, 34, 73, 41, 80, 91,\n", @@ -1299,73 +1321,70 @@ " 158, 35], device='cuda:0'))" ] }, - "execution_count": 36, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 36 } ], - "source": [ - "output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1)" - ] + "metadata": { + "id": "sTZn5YtxYBVb", + "outputId": "8818a3c5-1eda-4926-9d37-af2a0124b59f" + } }, { "cell_type": "markdown", - "metadata": { - "id": "po1DsUQdYBVb" - }, "source": [ "以上策略大部分情况下都是不错的。但是,如果我们的输入告诉我们找不到答案:比如start的位置比end的位置下标大,或者start和end的位置指向了question。\n", "\n", "这个时候,简单的方法是我们继续选择第2好的预测作为我们的答案,实在不行看第3好的预测,以此类推。\n", "\n", "由于上面的方法不太容易找到可行的答案,我们需要思考更合理的方法。我们将start和end的logits相加得到新的打分,然后去看最好的`n_best_size`个start和end对。从`n_best_size`个start和end对里推出相应的答案,然后检查答案是否有效,最后将他们按照打分进行排序,选择得分最高的作为答案。" - ] + ], + "metadata": { + "id": "po1DsUQdYBVb" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "ZBWVTvBTYBVb" - }, - "outputs": [], "source": [ "n_best_size = 20" - ] + ], + "outputs": [], + "metadata": { + "id": "ZBWVTvBTYBVb" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import numpy as np\r\n", + "\r\n", + "start_logits = output.start_logits[0].cpu().numpy()\r\n", + "end_logits = output.end_logits[0].cpu().numpy()\r\n", + "# 收集最佳的start和end logits的位置:\r\n", + "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + "valid_answers = []\r\n", + "for start_index in start_indexes:\r\n", + "    for end_index in end_indexes:\r\n", + "        if start_index <= end_index: # 如果start小于end,那么是合理的\r\n", + "            valid_answers.append(\r\n", + "                {\r\n", + "                    \"score\": start_logits[start_index] + end_logits[end_index],\r\n", + "                    \"text\": \"\" # 后续需要根据token的下标将答案找出来\r\n", + "                }\r\n", + "            )" + ], + "outputs": [], "metadata": { "id": "h9BOWPEoYBVb" - }, - "outputs": [], - "source": [ - 
"import numpy as np\n", - "\n", - "start_logits = output.start_logits[0].cpu().numpy()\n", - "end_logits = output.end_logits[0].cpu().numpy()\n", - "# 收集最佳的start和end logits的位置:\n", - "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - "valid_answers = []\n", - "for start_index in start_indexes:\n", - " for end_index in end_indexes:\n", - " if start_index <= end_index: # 如果start小雨end,那么合理的\n", - " valid_answers.append(\n", - " {\n", - " \"score\": start_logits[start_index] + end_logits[end_index],\n", - " \"text\": \"\" # 后续需要根据token的下标将答案找出来\n", - " }\n", - " )" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "Be7maAWsYBVc" - }, "source": [ "随后我们对根据`score`对`valid_answers`进行排序,找到最好的那一个。最后还剩一步是:检查start和end位置对应的文本是否在context里面而不是在question里面。\n", "\n", @@ -1374,70 +1393,105 @@ "- offset mapping: 将每个切片的tokens的位置映射会原始文本基于character的下标位置。\n", "\n", "所以我们又重新处理了以下validation验证集。和处理训练的时候的`prepare_train_features`稍有不同。\n" - ] + ], + "metadata": { + "id": "Be7maAWsYBVc" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "\r\n", + "def prepare_validation_features(examples):\r\n", + " # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results\r\n", + " # in one example possible giving several features when a context is long, each of those features having a\r\n", + " # context that overlaps a bit the context of the previous feature.\r\n", + " tokenized_examples = tokenizer(\r\n", + " examples[\"question\" if pad_on_right else \"context\"],\r\n", + " examples[\"context\" if pad_on_right else \"question\"],\r\n", + " truncation=\"only_second\" if pad_on_right else \"only_first\",\r\n", + " max_length=max_length,\r\n", + " stride=doc_stride,\r\n", + " return_overflowing_tokens=True,\r\n", + " return_offsets_mapping=True,\r\n", + " padding=\"max_length\",\r\n", + " )\r\n", + "\r\n", + " # Since one example might give us several features if it has a long context, we need a map from a feature to\r\n", + " # its corresponding example. 
This key gives us just that.\r\n", + " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\r\n", + "\r\n", + " # We keep the example_id that gave us this feature and we will store the offset mappings.\r\n", + " tokenized_examples[\"example_id\"] = []\r\n", + "\r\n", + " for i in range(len(tokenized_examples[\"input_ids\"])):\r\n", + " # Grab the sequence corresponding to that example (to know what is the context and what is the question).\r\n", + " sequence_ids = tokenized_examples.sequence_ids(i)\r\n", + " context_index = 1 if pad_on_right else 0\r\n", + "\r\n", + " # One example can give several spans, this is the index of the example containing this span of text.\r\n", + " sample_index = sample_mapping[i]\r\n", + " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\r\n", + "\r\n", + " # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token\r\n", + " # position is part of the context or not.\r\n", + " tokenized_examples[\"offset_mapping\"][i] = [\r\n", + " (o if sequence_ids[k] == context_index else None)\r\n", + " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\r\n", + " ]\r\n", + "\r\n", + " return tokenized_examples" + ], + "outputs": [], "metadata": { "id": "AJpleD_oYBVc" - }, - "outputs": [], - "source": [ - "\n", - "def prepare_validation_features(examples):\n", - " # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results\n", - " # in one example possible giving several features when a context is long, each of those features having a\n", - " # context that overlaps a bit the context of the previous feature.\n", - " tokenized_examples = tokenizer(\n", - " examples[\"question\" if pad_on_right else \"context\"],\n", - " examples[\"context\" if pad_on_right else \"question\"],\n", - " truncation=\"only_second\" if pad_on_right else \"only_first\",\n", - " max_length=max_length,\n", - " stride=doc_stride,\n", - " return_overflowing_tokens=True,\n", - " return_offsets_mapping=True,\n", - " padding=\"max_length\",\n", - " )\n", - "\n", - " # Since one example might give us several features if it has a long context, we need a map from a feature to\n", - " # its corresponding example. 
This key gives us just that.\n", - " sample_mapping = tokenized_examples.pop(\"overflow_to_sample_mapping\")\n", - "\n", - " # We keep the example_id that gave us this feature and we will store the offset mappings.\n", - " tokenized_examples[\"example_id\"] = []\n", - "\n", - " for i in range(len(tokenized_examples[\"input_ids\"])):\n", - " # Grab the sequence corresponding to that example (to know what is the context and what is the question).\n", - " sequence_ids = tokenized_examples.sequence_ids(i)\n", - " context_index = 1 if pad_on_right else 0\n", - "\n", - " # One example can give several spans, this is the index of the example containing this span of text.\n", - " sample_index = sample_mapping[i]\n", - " tokenized_examples[\"example_id\"].append(examples[\"id\"][sample_index])\n", - "\n", - " # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token\n", - " # position is part of the context or not.\n", - " tokenized_examples[\"offset_mapping\"][i] = [\n", - " (o if sequence_ids[k] == context_index else None)\n", - " for k, o in enumerate(tokenized_examples[\"offset_mapping\"][i])\n", - " ]\n", - "\n", - " return tokenized_examples" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "rmaTICcJYBVd" - }, "source": [ "和之前一样将`prepare_validation_features`函数应用到每个验证集合的样本上。" - ] + ], + "metadata": { + "id": "rmaTICcJYBVd" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "validation_features = datasets[\"validation\"].map(\r\n", + " prepare_validation_features,\r\n", + " batched=True,\r\n", + " remove_columns=datasets[\"validation\"].column_names\r\n", + ")" + ], + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))" + ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "32ba04d6240149f49eb48c8d8b7f9aae", + "version_major": 2, + "version_minor": 0 + } + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n" + ] + } + ], "metadata": { "colab": { "referenced_widgets": [ @@ -1446,109 +1500,113 @@ }, "id": "xDfua4clYBVd", "outputId": "4789e3b2-52f0-4ca0-9d01-8c2b8e10a167" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "32ba04d6240149f49eb48c8d8b7f9aae", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "validation_features = datasets[\"validation\"].map(\n", - " prepare_validation_features,\n", - " batched=True,\n", - " remove_columns=datasets[\"validation\"].column_names\n", - ")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "dNECLWJLYBVe" - }, "source": [ "使用`Trainer.predict`方法获得所有预测结果" - ] + ], + "metadata": { + "id": "dNECLWJLYBVe" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "HTbFJ3FhYBVe" - }, - "outputs": [], "source": [ "raw_predictions = trainer.predict(validation_features)" - ] + ], + "outputs": [], + "metadata": { + "id": "HTbFJ3FhYBVe" + } }, { "cell_type": "markdown", - "metadata": { - "id": "QlTxqJN3YBVe" - }, "source": [ "这个 `Trainer` *隐藏了* 一些模型训练时候没有使用的属性(这里是 `example_id`和`offset_mapping`,后处理的时候会用到),所以我们需要把这些设置回来:" - ] + ], + "metadata": { + "id": "QlTxqJN3YBVe" + } }, { 
"cell_type": "code", "execution_count": null, - "metadata": { - "id": "nyQsRyqGYBVe" - }, - "outputs": [], "source": [ "validation_features.set_format(type=validation_features.format[\"type\"], columns=list(validation_features.features.keys()))" - ] + ], + "outputs": [], + "metadata": { + "id": "nyQsRyqGYBVe" + } }, { "cell_type": "markdown", - "metadata": { - "id": "Ac0BahcRYBVf" - }, "source": [ "当一个token位置对应着question部分时候,`prepare_validation_features`函数将offset mappings设定为`None`,所以我们根据offset mapping很容易可以鉴定token是否在context里面啦。我们同样也根绝扔掉了特别长的答案。" - ] + ], + "metadata": { + "id": "Ac0BahcRYBVf" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "dcqAk9GgYBVf" - }, - "outputs": [], "source": [ "max_answer_length = 30" - ] + ], + "outputs": [], + "metadata": { + "id": "dcqAk9GgYBVf" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "5wBRLoJoYBVf", - "outputId": "6e889849-7d40-4003-8d8d-546b8c6eb6a1" - }, + "source": [ + "start_logits = output.start_logits[0].cpu().numpy()\r\n", + "end_logits = output.end_logits[0].cpu().numpy()\r\n", + "offset_mapping = validation_features[0][\"offset_mapping\"]\r\n", + "# The first feature comes from the first example. For the more general case, we will need to be match the example_id to\r\n", + "# an example index\r\n", + "context = datasets[\"validation\"][0][\"context\"]\r\n", + "\r\n", + "# Gather the indices the best start/end logits:\r\n", + "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + "valid_answers = []\r\n", + "for start_index in start_indexes:\r\n", + " for end_index in end_indexes:\r\n", + " # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond\r\n", + " # to part of the input_ids that are not in the context.\r\n", + " if (\r\n", + " start_index >= len(offset_mapping)\r\n", + " or end_index >= len(offset_mapping)\r\n", + " or offset_mapping[start_index] is None\r\n", + " or offset_mapping[end_index] is None\r\n", + " ):\r\n", + " continue\r\n", + " # Don't consider answers with a length that is either < 0 or > max_answer_length.\r\n", + " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\r\n", + " continue\r\n", + " if start_index <= end_index: # We need to refine that test to check the answer is inside the context\r\n", + " start_char = offset_mapping[start_index][0]\r\n", + " end_char = offset_mapping[end_index][1]\r\n", + " valid_answers.append(\r\n", + " {\r\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\r\n", + " \"text\": context[start_char: end_char]\r\n", + " }\r\n", + " )\r\n", + "\r\n", + "valid_answers = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[:n_best_size]\r\n", + "valid_answers" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "[{'score': 16.706663, 'text': 'Denver Broncos'},\n", @@ -1587,246 +1645,236 @@ " 'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24'}]" ] }, - "execution_count": 44, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 44 } ], - "source": [ - "start_logits = output.start_logits[0].cpu().numpy()\n", - "end_logits = output.end_logits[0].cpu().numpy()\n", - "offset_mapping = validation_features[0][\"offset_mapping\"]\n", - "# The first feature comes from the first example. 
For the more general case, we will need to be match the example_id to\n", - "# an example index\n", - "context = datasets[\"validation\"][0][\"context\"]\n", - "\n", - "# Gather the indices the best start/end logits:\n", - "start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - "end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - "valid_answers = []\n", - "for start_index in start_indexes:\n", - " for end_index in end_indexes:\n", - " # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond\n", - " # to part of the input_ids that are not in the context.\n", - " if (\n", - " start_index >= len(offset_mapping)\n", - " or end_index >= len(offset_mapping)\n", - " or offset_mapping[start_index] is None\n", - " or offset_mapping[end_index] is None\n", - " ):\n", - " continue\n", - " # Don't consider answers with a length that is either < 0 or > max_answer_length.\n", - " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n", - " continue\n", - " if start_index <= end_index: # We need to refine that test to check the answer is inside the context\n", - " start_char = offset_mapping[start_index][0]\n", - " end_char = offset_mapping[end_index][1]\n", - " valid_answers.append(\n", - " {\n", - " \"score\": start_logits[start_index] + end_logits[end_index],\n", - " \"text\": context[start_char: end_char]\n", - " }\n", - " )\n", - "\n", - "valid_answers = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[:n_best_size]\n", - "valid_answers" - ] + "metadata": { + "id": "5wBRLoJoYBVf", + "outputId": "6e889849-7d40-4003-8d8d-546b8c6eb6a1" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "nLHf5tsBYBVg" - }, - "outputs": [], "source": [ "将预测答案和真实答案进行比较即可:" - ] + ], + "outputs": [], + "metadata": { + "id": "nLHf5tsBYBVg" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "MPqG0pyiYBVg", - "outputId": "fa1a7b51-09fb-4ce4-e858-456205dbdb31" - }, + "source": [ + "datasets[\"validation\"][0][\"answers\"]" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'answer_start': [177, 177, 177],\n", " 'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']}" ] }, - "execution_count": 45, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 45 } ], - "source": [ - "datasets[\"validation\"][0][\"answers\"]" - ] + "metadata": { + "id": "MPqG0pyiYBVg", + "outputId": "fa1a7b51-09fb-4ce4-e858-456205dbdb31" + } }, { "cell_type": "markdown", - "metadata": { - "id": "-Vl2BpcVYBVh" - }, "source": [ "可以看到模型做对了!\n", "\n", "如同上面的例子所言,由于第1个feature一定是来自于第1个example,所以相对容易。对于其他的fearures来说,我们需要一个features和examples的一个映射map。同样,由于一个example可能被切片成多个features,所以我们也需要将所有features里的答案全部手机起来。以下的代码就将exmaple的下标和features的下标进行map映射。" - ] + ], + "metadata": { + "id": "-Vl2BpcVYBVh" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import collections\r\n", + "\r\n", + "examples = datasets[\"validation\"]\r\n", + "features = validation_features\r\n", + "\r\n", + "example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\r\n", + "features_per_example = collections.defaultdict(list)\r\n", + "for i, feature in enumerate(features):\r\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)" + ], + "outputs": [], "metadata": { "id": "VgOFWhElYBVh" - }, - "outputs": [], - "source": [ - "import collections\n", - "\n", - "examples = 
datasets[\"validation\"]\n", - "features = validation_features\n", - "\n", - "example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", - "features_per_example = collections.defaultdict(list)\n", - "for i, feature in enumerate(features):\n", - " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "nXJ-THjzYBVh" - }, "source": [ "对于后处理过程基本上已经全部完成了。最后一点事情是如何解决无答案的情况(squad_v2=True的时候)。以上的代码都只考虑了context里面的asnwers,所以我们同样需要将无答案的预测得分进行搜集(无答案的预测对应的CLSt oken的start和end)。如果一个example样本又多个features,那么我们还需要在多个features里预测是不是都无答案。所以无答案的最终得分是所有features的无答案得分最小的那个。\n", "\n", "只要无答案的最终得分高于其他所有答案的得分,那么该问题就是无答案。\n", "\n", "把所有事情都合并起来:" - ] + ], + "metadata": { + "id": "nXJ-THjzYBVh" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "from tqdm.auto import tqdm\r\n", + "\r\n", + "def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):\r\n", + " all_start_logits, all_end_logits = raw_predictions\r\n", + " # Build a map example to its corresponding features.\r\n", + " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\r\n", + " features_per_example = collections.defaultdict(list)\r\n", + " for i, feature in enumerate(features):\r\n", + " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\r\n", + "\r\n", + " # The dictionaries we have to fill.\r\n", + " predictions = collections.OrderedDict()\r\n", + "\r\n", + " # Logging.\r\n", + " print(f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\")\r\n", + "\r\n", + " # Let's loop over all the examples!\r\n", + " for example_index, example in enumerate(tqdm(examples)):\r\n", + " # Those are the indices of the features associated to the current example.\r\n", + " feature_indices = features_per_example[example_index]\r\n", + "\r\n", + " min_null_score = None # Only used if squad_v2 is True.\r\n", + " valid_answers = []\r\n", + " \r\n", + " context = example[\"context\"]\r\n", + " # Looping through all the features associated to the current example.\r\n", + " for feature_index in feature_indices:\r\n", + " # We grab the predictions of the model for this feature.\r\n", + " start_logits = all_start_logits[feature_index]\r\n", + " end_logits = all_end_logits[feature_index]\r\n", + " # This is what will allow us to map some the positions in our logits to span of texts in the original\r\n", + " # context.\r\n", + " offset_mapping = features[feature_index][\"offset_mapping\"]\r\n", + "\r\n", + " # Update minimum null prediction.\r\n", + " cls_index = features[feature_index][\"input_ids\"].index(tokenizer.cls_token_id)\r\n", + " feature_null_score = start_logits[cls_index] + end_logits[cls_index]\r\n", + " if min_null_score is None or min_null_score < feature_null_score:\r\n", + " min_null_score = feature_null_score\r\n", + "\r\n", + " # Go through all possibilities for the `n_best_size` greater start and end logits.\r\n", + " start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\r\n", + " for start_index in start_indexes:\r\n", + " for end_index in end_indexes:\r\n", + " # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond\r\n", + " # to part of the input_ids that are not in the context.\r\n", + " if (\r\n", + " start_index >= len(offset_mapping)\r\n", + " or end_index >= 
len(offset_mapping)\r\n", + " or offset_mapping[start_index] is None\r\n", + " or offset_mapping[end_index] is None\r\n", + " ):\r\n", + " continue\r\n", + " # Don't consider answers with a length that is either < 0 or > max_answer_length.\r\n", + " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\r\n", + " continue\r\n", + "\r\n", + " start_char = offset_mapping[start_index][0]\r\n", + " end_char = offset_mapping[end_index][1]\r\n", + " valid_answers.append(\r\n", + " {\r\n", + " \"score\": start_logits[start_index] + end_logits[end_index],\r\n", + " \"text\": context[start_char: end_char]\r\n", + " }\r\n", + " )\r\n", + " \r\n", + " if len(valid_answers) > 0:\r\n", + " best_answer = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[0]\r\n", + " else:\r\n", + " # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid\r\n", + " # failure.\r\n", + " best_answer = {\"text\": \"\", \"score\": 0.0}\r\n", + " \r\n", + " # Let's pick our final answer: the best one or the null answer (only for squad_v2)\r\n", + " if not squad_v2:\r\n", + " predictions[example[\"id\"]] = best_answer[\"text\"]\r\n", + " else:\r\n", + " answer = best_answer[\"text\"] if best_answer[\"score\"] > min_null_score else \"\"\r\n", + " predictions[example[\"id\"]] = answer\r\n", + "\r\n", + " return predictions" + ], + "outputs": [], "metadata": { "id": "00SHF2PzYBVh" - }, - "outputs": [], - "source": [ - "from tqdm.auto import tqdm\n", - "\n", - "def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):\n", - " all_start_logits, all_end_logits = raw_predictions\n", - " # Build a map example to its corresponding features.\n", - " example_id_to_index = {k: i for i, k in enumerate(examples[\"id\"])}\n", - " features_per_example = collections.defaultdict(list)\n", - " for i, feature in enumerate(features):\n", - " features_per_example[example_id_to_index[feature[\"example_id\"]]].append(i)\n", - "\n", - " # The dictionaries we have to fill.\n", - " predictions = collections.OrderedDict()\n", - "\n", - " # Logging.\n", - " print(f\"Post-processing {len(examples)} example predictions split into {len(features)} features.\")\n", - "\n", - " # Let's loop over all the examples!\n", - " for example_index, example in enumerate(tqdm(examples)):\n", - " # Those are the indices of the features associated to the current example.\n", - " feature_indices = features_per_example[example_index]\n", - "\n", - " min_null_score = None # Only used if squad_v2 is True.\n", - " valid_answers = []\n", - " \n", - " context = example[\"context\"]\n", - " # Looping through all the features associated to the current example.\n", - " for feature_index in feature_indices:\n", - " # We grab the predictions of the model for this feature.\n", - " start_logits = all_start_logits[feature_index]\n", - " end_logits = all_end_logits[feature_index]\n", - " # This is what will allow us to map some the positions in our logits to span of texts in the original\n", - " # context.\n", - " offset_mapping = features[feature_index][\"offset_mapping\"]\n", - "\n", - " # Update minimum null prediction.\n", - " cls_index = features[feature_index][\"input_ids\"].index(tokenizer.cls_token_id)\n", - " feature_null_score = start_logits[cls_index] + end_logits[cls_index]\n", - " if min_null_score is None or min_null_score < feature_null_score:\n", - " min_null_score = feature_null_score\n", - "\n", - " # Go through all possibilities 
for the `n_best_size` greater start and end logits.\n", - " start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - " end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()\n", - " for start_index in start_indexes:\n", - " for end_index in end_indexes:\n", - " # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond\n", - " # to part of the input_ids that are not in the context.\n", - " if (\n", - " start_index >= len(offset_mapping)\n", - " or end_index >= len(offset_mapping)\n", - " or offset_mapping[start_index] is None\n", - " or offset_mapping[end_index] is None\n", - " ):\n", - " continue\n", - " # Don't consider answers with a length that is either < 0 or > max_answer_length.\n", - " if end_index < start_index or end_index - start_index + 1 > max_answer_length:\n", - " continue\n", - "\n", - " start_char = offset_mapping[start_index][0]\n", - " end_char = offset_mapping[end_index][1]\n", - " valid_answers.append(\n", - " {\n", - " \"score\": start_logits[start_index] + end_logits[end_index],\n", - " \"text\": context[start_char: end_char]\n", - " }\n", - " )\n", - " \n", - " if len(valid_answers) > 0:\n", - " best_answer = sorted(valid_answers, key=lambda x: x[\"score\"], reverse=True)[0]\n", - " else:\n", - " # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid\n", - " # failure.\n", - " best_answer = {\"text\": \"\", \"score\": 0.0}\n", - " \n", - " # Let's pick our final answer: the best one or the null answer (only for squad_v2)\n", - " if not squad_v2:\n", - " predictions[example[\"id\"]] = best_answer[\"text\"]\n", - " else:\n", - " answer = best_answer[\"text\"] if best_answer[\"score\"] > min_null_score else \"\"\n", - " predictions[example[\"id\"]] = answer\n", - "\n", - " return predictions" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "sfpxu9j3YBV-" - }, "source": [ "将后处理函数应用到原始预测上:" - ] - }, - { - "cell_type": "code", - "execution_count": null, + ], "metadata": { - "id": "LcQe1dCnYBV-" - }, - "outputs": [], - "source": [] + "id": "sfpxu9j3YBV-" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "final_predictions = postprocess_qa_predictions(datasets[\"validation\"], validation_features, raw_predictions.predictions)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Post-processing 10570 example predictions split into 10784 features.\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))" + ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "347ebed36d3541388e4e821372e91aa4", + "version_major": 2, + "version_minor": 0 + } + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n" + ] + } + ], "metadata": { "colab": { "referenced_widgets": [ @@ -1835,119 +1883,101 @@ }, "id": "Df4vY9d1YBV_", "outputId": "026516bc-d0e4-439a-b77a-daf738f61aa1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Post-processing 10570 example predictions split into 10784 features.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "347ebed36d3541388e4e821372e91aa4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=10570.0), HTML(value='')))" - ] - }, - "metadata": { - "tags": [] - }, 
- "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "final_predictions = postprocess_qa_predictions(datasets[\"validation\"], validation_features, raw_predictions.predictions)" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "AoRhXyiPYBV_" - }, "source": [ "然后我们加载评测指标:" - ] + ], + "metadata": { + "id": "AoRhXyiPYBV_" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "LYNJLOAbYBV_" - }, - "outputs": [], "source": [ "metric = load_metric(\"squad_v2\" if squad_v2 else \"squad\")" - ] + ], + "outputs": [], + "metadata": { + "id": "LYNJLOAbYBV_" + } }, { "cell_type": "markdown", - "metadata": { - "id": "HckepCtJYBWA" - }, "source": [ - "然后我们基于预测和标注对评测指标进行计算。为了合理的比较,我们需要将预测和标注的格式。对于squad2来说,评测指标还需要`no_answer_probability`参数(由于已经无答案直接设置成了空字符串,所以这里直接将这个参数设置为0.0)" - ] + "同理,也可以使用我们提供的本地脚本来加载:" + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, + "source": [ + "metric_path = './dataset/metrics/squad.py'\r\n", + "metric = load_metric(metric_path)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "然后我们基于预测和标注对评测指标进行计算。为了合理的比较,我们需要将预测和标注的格式。对于squad2来说,评测指标还需要`no_answer_probability`参数(由于已经无答案直接设置成了空字符串,所以这里直接将这个参数设置为0.0)" + ], "metadata": { - "id": "k4y-LM_cYBWA", - "outputId": "a122acf7-203c-4eb3-d26b-99b05b78a2df" - }, + "id": "HckepCtJYBWA" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "if squad_v2:\r\n", + " formatted_predictions = [{\"id\": k, \"prediction_text\": v, \"no_answer_probability\": 0.0} for k, v in predictions.items()]\r\n", + "else:\r\n", + " formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in final_predictions.items()]\r\n", + "references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\r\n", + "metric.compute(predictions=formatted_predictions, references=references)" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'exact_match': 76.74550614947965, 'f1': 85.13412652023338}" ] }, - "execution_count": 50, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 50 } ], - "source": [ - "if squad_v2:\n", - " formatted_predictions = [{\"id\": k, \"prediction_text\": v, \"no_answer_probability\": 0.0} for k, v in predictions.items()]\n", - "else:\n", - " formatted_predictions = [{\"id\": k, \"prediction_text\": v} for k, v in final_predictions.items()]\n", - "references = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in datasets[\"validation\"]]\n", - "metric.compute(predictions=formatted_predictions, references=references)" - ] + "metadata": { + "id": "k4y-LM_cYBWA", + "outputId": "a122acf7-203c-4eb3-d26b-99b05b78a2df" + } }, { "cell_type": "markdown", - "metadata": { - "id": "exnxfrEKYBWA" - }, "source": [ "最后别忘了,[查看如何上传模型](https://huggingface.co/transformers/model_sharing.html) ,上传模型到[🤗 Model Hub](https://huggingface.co/models)。随后您就可以像这个notebook一开始一样,直接用名字就能使用您的模型啦。" - ] + ], + "metadata": { + "id": "exnxfrEKYBWA" + } }, { "cell_type": "code", "execution_count": null, + "source": [], + "outputs": [], "metadata": { "id": "uAYJ1DnfYBWA" - }, - "outputs": [], - "source": [] + } } ], "metadata": { @@ -1976,5 +2006,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 2 } \ No newline at end of file diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb b/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb index 
diff --git a/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb b/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb index 8127a0c..0e3b97b 100644 --- a/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb +++ b/docs/篇章4-使用Transformers解决NLP任务/4.7-生成任务-摘要生成.ipynb @@ -2,128 +2,153 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "X4cRE8IbIrIV" - }, "source": [ "本文涉及的jupyter notebook在[篇章4代码库中](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1)。\n", "\n", "建议直接使用google colab notebook打开本教程,可以快速下载相关数据集和模型。\n", "如果您正在google的colab中打开这个notebook,您可能需要安装Transformers和🤗Datasets库。将以下命令取消注释即可安装。" - ] + ], + "metadata": { + "id": "X4cRE8IbIrIV" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "MOsHUjgdIrIW" - }, - "outputs": [], "source": [ "! pip install datasets transformers rouge-score nltk" - ] + ], + "outputs": [], + "metadata": { + "id": "MOsHUjgdIrIW" + } }, { "cell_type": "markdown", - "metadata": { - "id": "HFASsisvIrIb" - }, "source": [ "分布式训练请查看 [这里](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)." - ] + ], + "metadata": { + "id": "HFASsisvIrIb" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rEJBSTyZIrIb" - }, "source": [ "# 微调transformer模型解决摘要生成任务" - ] + ], + "metadata": { + "id": "rEJBSTyZIrIb" + } }, { "cell_type": "markdown", - "metadata": { - "id": "kTCFado4IrIc" - }, "source": [ "在本notebook中,我们将展示如何微调 [🤗 Transformers](https://github.com/huggingface/transformers)中的预训练模型来解决摘要生成任务。我们使用[XSum dataset](https://arxiv.org/pdf/1808.08745.pdf)数据集。这个数据集包含了BBC的文章和一句对应的摘要。下面是一个例子:\n", "\n", "![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/master/examples/images/summarization.png?raw=1)\n", "\n", "对于摘要生成任务,我们将展示如何简单地加载数据集,并针对相应的任务使用transformer中的Trainer接口对模型进行微调。" - ] + ], + "metadata": { + "id": "kTCFado4IrIc" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "iO8WOo1hdJT0" - }, - "outputs": [], "source": [ "model_checkpoint = \"t5-small\"" - ] + ], + "outputs": [], + "metadata": { + "id": "iO8WOo1hdJT0" + } }, { "cell_type": "markdown", - "metadata": { - "id": "4RRkXuteIrIh" - }, "source": [ "只要预训练的transformer模型包含seq2seq结构的head层,那么本notebook理论上可以使用各种各样的transformer模型,解决任何摘要生成任务。这里,我们使用[`t5-small`](https://huggingface.co/t5-small)模型checkpoint。\n" - ] + ], + "metadata": { + "id": "4RRkXuteIrIh" + } }, { "cell_type": "markdown", - "metadata": { - "id": "whPRbBNbIrIl" - }, "source": [ "## 加载数据" - ] + ], + "metadata": { + "id": "whPRbBNbIrIl" + } }, { "cell_type": "markdown", - "metadata": { - "id": "W7QYTpxXIrIl" - }, "source": [ "我们将会使用[🤗 Datasets](https://github.com/huggingface/datasets)库来加载数据和对应的评测方式。数据加载和评测方式加载只需要简单使用load_dataset和load_metric即可。\n" - ] + ], + "metadata": { + "id": "W7QYTpxXIrIl" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "from datasets import load_dataset, load_metric\r\n", + "\r\n", + "raw_datasets = load_dataset(\"xsum\")\r\n", + "metric = load_metric(\"rouge\")" + ], + "outputs": [], "metadata": { "id": "IreSlFmlIrIm" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset, load_metric\n", - "\n", - "raw_datasets = load_dataset(\"xsum\")\n", - "metric = load_metric(\"rouge\")" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "RzfPtOMoIrIu" - }, "source": [ - "这个datasets对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 
对于训练集、验证集和测试集,只需要使用对应的key(train,validation,test)即可得到相应的数据。" - ] + "除此之外,你也可以从我们提供的[链接](https://gas.graviti.cn/dataset/datawhale/Xsum)下载数据并解压,将解压后的json文件和`bbc-summary-data`文件夹复制到`docs/篇章4-使用Transformers解决NLP任务/datasets/xsum`目录下,然后用下面的代码进行加载。同理,也可以使用我们提供的脚本加载评测方式`rouge`。" + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, + "source": [ + "import os\r\n", + "\r\n", + "data_path = './dataset/xsum/'\r\n", + "path = os.path.join(data_path, 'xsum.py')\r\n", + "cache_dir = os.path.join(data_path, 'cache')\r\n", + "data_files = {'data': data_path}\r\n", + "dataset = load_dataset(path, data_files=data_files, cache_dir=cache_dir)\r\n", + "\r\n", + "metric_path = './dataset/metrics/rouge.py'\r\n", + "metric = load_metric(metric_path)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "这个datasets对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 对于训练集、验证集和测试集,只需要使用对应的key(train,validation,test)即可得到相应的数据。" + ], "metadata": { - "id": "GWiVUF0jIrIv", - "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" - }, + "id": "RzfPtOMoIrIu" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "raw_datasets" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "DatasetDict({\n", @@ -142,35 +167,35 @@ "})" ] }, - "execution_count": 4, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 4 } ], - "source": [ - "raw_datasets" - ] + "metadata": { + "id": "GWiVUF0jIrIv", + "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" + } }, { "cell_type": "markdown", - "metadata": { - "id": "u3EtYfeHIrIz" - }, "source": [ "给定一个数据切分的key(train、validation或者test)和下标即可查看数据:" - ] + ], + "metadata": { + "id": "u3EtYfeHIrIz" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "X6HrpprwIrIz", - "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" - }, + "source": [ + "raw_datasets[\"train\"][0]" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'document': 'Recent reports have linked some France-based players with returns to Wales.\\n\"I\\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\\'d rather spend that money on keeping players in Wales,\" said Davies.\\nThe WRU provides £2m to the fund and £1.3m comes from the regions.\\nFormer Wales and British and Irish Lions fly-half Davies became WRU chairman on Tuesday 21 October, succeeding deposed David Pickering following governing body elections.\\nHe is now serving a notice period to leave his role as Newport Gwent Dragons chief executive after being voted on to the WRU board in September.\\nDavies was among the leading figures among Dragons, Ospreys, Scarlets and Cardiff Blues officials who were embroiled in a protracted dispute with the WRU that ended in a £60m deal in August this year.\\nIn the wake of that deal being done, Davies said the £3.3m should be spent on ensuring current Wales-based stars remain there.\\nIn recent weeks, Racing Metro flanker Dan Lydiate was linked with returning to Wales.\\nLikewise the Paris club\\'s scrum-half Mike Phillips and centre Jamie Roberts were also touted for possible returns.\\nWales coach Warren Gatland has said: \"We haven\\'t instigated contact with the players.\\n\"But we are aware that one or two of them are keen to return to Wales sooner rather than later.\"\\nSpeaking to Scrum V on BBC Radio Wales, Davies re-iterated his stance, saying keeping 
players such as Scarlets full-back Liam Williams and Ospreys flanker Justin Tipuric in Wales should take precedence.\\n\"It\\'s obviously a limited amount of money [available]. The union are contributing 60% of that contract and the regions are putting £1.3m in.\\n\"So it\\'s a total pot of just over £3m and if you look at the sorts of salaries that the... guys... have been tempted to go overseas for [are] significant amounts of money.\\n\"So if we were to bring the players back, we\\'d probably get five or six players.\\n\"And I\\'ve always felt - and this is with my rugby hat on now; this is not region or WRU - I\\'d rather spend that money on keeping players in Wales.\\n\"There are players coming out of contract, perhaps in the next year or so… you\\'re looking at your Liam Williams\\' of the world; Justin Tipuric for example - we need to keep these guys in Wales.\\n\"We actually want them there. They are the ones who are going to impress the young kids, for example.\\n\"They are the sort of heroes that our young kids want to emulate.\\n\"So I would start off [by saying] with the limited pot of money, we have to retain players in Wales.\\n\"Now, if that can be done and there\\'s some spare monies available at the end, yes, let\\'s look to bring players back.\\n\"But it\\'s a cruel world, isn\\'t it?\\n\"It\\'s fine to take the buck and go, but great if you can get them back as well, provided there\\'s enough money.\"\\nBritish and Irish Lions centre Roberts has insisted he will see out his Racing Metro contract.\\nHe and Phillips also earlier dismissed the idea of leaving Paris.\\nRoberts also admitted being hurt by comments in French Newspaper L\\'Equipe attributed to Racing Coach Laurent Labit questioning their effectiveness.\\nCentre Roberts and flanker Lydiate joined Racing ahead of the 2013-14 season while scrum-half Phillips moved there in December 2013 after being dismissed for disciplinary reasons by former club Bayonne.',\n", @@ -178,33 +203,29 @@ " 'summary': 'New Welsh Rugby Union chairman Gareth Davies believes a joint £3.3m WRU-regions fund should be used to retain home-based talent such as Liam Williams, not bring back exiled stars.'}" ] }, - "execution_count": 5, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 5 } ], - "source": [ - "raw_datasets[\"train\"][0]" - ] + "metadata": { + "id": "X6HrpprwIrIz", + "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" + } }, { "cell_type": "markdown", - "metadata": { - "id": "WHUmphG3IrI3" - }, "source": [ "为了能够进一步理解数据长什么样子,下面的函数将从数据集里随机选择几个例子进行展示。" - ] + ], + "metadata": { + "id": "WHUmphG3IrI3" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "i3j8APAoIrI3" - }, - "outputs": [], "source": [ "import datasets\n", "import random\n", @@ -225,22 +246,25 @@ " if isinstance(typ, datasets.ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" - ] + ], + "outputs": [], + "metadata": { + "id": "i3j8APAoIrI3" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 791 - }, - "id": "SZy5tRB_IrI7", - "outputId": "17194536-108f-4dc2-f5c2-dccf93c0164e" - }, + "source": [ + "show_random_elements(raw_datasets[\"train\"])" + ], "outputs": [ { + "output_type": "display_data", "data": { + "text/plain": [ + "" + ], "text/html": [ "\n", " \n", @@ -266,42 +290,40 @@ " \n", " \n", "
" - ], - "text/plain": [ - "" ] }, "metadata": { "tags": [] - }, - "output_type": "display_data" + } } ], - "source": [ - "show_random_elements(raw_datasets[\"train\"])" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 791 + }, + "id": "SZy5tRB_IrI7", + "outputId": "17194536-108f-4dc2-f5c2-dccf93c0164e" + } }, { "cell_type": "markdown", - "metadata": { - "id": "lnjDIuQ3IrI-" - }, "source": [ "metric是[`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric)类的一个实例,查看metric和使用的例子:" - ] + ], + "metadata": { + "id": "lnjDIuQ3IrI-" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5o4rUteaIrI_", - "outputId": "63ba0150-f748-4cc0-fda1-dada85bcdf79" - }, + "source": [ + "metric" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "Metric(name: \"rouge\", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: \"\"\"\n", @@ -340,38 +362,40 @@ "\"\"\", stored examples: 0)" ] }, - "execution_count": 8, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 8 } ], - "source": [ - "metric" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jAWdqcUBIrJC" - }, - "source": [ - "我们使用`compute`方法来对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:" - ] - }, - { - "cell_type": "code", - "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "6XN1Rq0aIrJC", - "outputId": "56fc8845-3d36-4490-b281-54096bf82ad4" - }, + "id": "5o4rUteaIrI_", + "outputId": "63ba0150-f748-4cc0-fda1-dada85bcdf79" + } + }, + { + "cell_type": "markdown", + "source": [ + "我们使用`compute`方法来对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:" + ], + "metadata": { + "id": "jAWdqcUBIrJC" + } + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "fake_preds = [\"hello there\", \"general kenobi\"]\n", + "fake_labels = [\"hello there\", \"general kenobi\"]\n", + "metric.compute(predictions=fake_preds, references=fake_labels)" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),\n", @@ -380,33 +404,31 @@ " 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}" ] }, - "execution_count": 9, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 9 } ], - "source": [ - "fake_preds = [\"hello there\", \"general kenobi\"]\n", - "fake_labels = [\"hello there\", \"general kenobi\"]\n", - "metric.compute(predictions=fake_preds, references=fake_labels)" - ] + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6XN1Rq0aIrJC", + "outputId": "56fc8845-3d36-4490-b281-54096bf82ad4" + } }, { "cell_type": "markdown", - "metadata": { - "id": "n9qywopnIrJH" - }, "source": [ "## 数据预处理" - ] + ], + "metadata": { + "id": "n9qywopnIrJH" + } }, { "cell_type": "markdown", - "metadata": { - "id": "YVx71GdAIrJH" - }, "source": [ "在将数据喂入模型之前,我们需要对数据进行预处理。预处理的工具叫Tokenizer。Tokenizer首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。\n", "\n", @@ -417,175 +439,174 @@ 
"\n", "\n", "这个被下载的tokens vocabulary会被缓存起来,从而再次使用的时候不会重新下载。" - ] + ], + "metadata": { + "id": "YVx71GdAIrJH" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "eXNLu_-nIrJI" - }, - "outputs": [], "source": [ "from transformers import AutoTokenizer\n", " \n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" - ] + ], + "outputs": [], + "metadata": { + "id": "eXNLu_-nIrJI" + } }, { "cell_type": "markdown", - "metadata": { - "id": "Vl6IidfdIrJK" - }, "source": [ "By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library." - ] + ], + "metadata": { + "id": "Vl6IidfdIrJK" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rowT4iCLIrJK" - }, "source": [ "tokenizer既可以对单个文本进行预处理,也可以对一对文本进行预处理,tokenizer预处理后得到的数据满足预训练模型输入格式" - ] + ], + "metadata": { + "id": "rowT4iCLIrJK" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "a5hBlsrHIrJL", - "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90" - }, + "source": [ + "tokenizer(\"Hello, this one sentence!\")" + ], "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}" ] }, - "execution_count": 11, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 11 } ], - "source": [ - "tokenizer(\"Hello, this one sentence!\")" - ] + "metadata": { + "id": "a5hBlsrHIrJL", + "outputId": "acdaa98a-a8cd-4a20-89b8-cc26437bbe90" + } }, { "cell_type": "markdown", - "metadata": { - "id": "qo_0B1M2IrJM" - }, "source": [ "上面看到的token IDs也就是input_ids一般来说随着预训练模型名字的不同而有所不同。原因是不同的预训练模型在预训练的时候设定了不同的规则。但只要tokenizer和model的名字一致,那么tokenizer预处理的输入格式就会满足model需求的。关于预处理更多内容参考[这个教程](https://huggingface.co/transformers/preprocessing.html)\n", "\n", "除了可以tokenize一句话,我们也可以tokenize一个list的句子。" - ] + ], + "metadata": { + "id": "qo_0B1M2IrJM" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 11 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0YFUpXROdJT6", "outputId": "3659a781-f0d5-48db-9331-34df0452a5a8" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}" - ] - }, - "execution_count": 11, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"])" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "skFBG0iBdJT7" - }, "source": [ "注意:为了给模型准备好翻译的targets,我们使用`as_target_tokenizer`来控制targets所对应的特殊token:" - ] + ], + "metadata": { + "id": "skFBG0iBdJT7" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "with tokenizer.as_target_tokenizer():\n", + " print(tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"]))" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 
1, 1]]}\n" + ] + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "E2VZkR_XdJT7", "outputId": "12732226-b243-400e-fdfc-7b137c884438" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}\n" - ] - } - ], - "source": [ - "with tokenizer.as_target_tokenizer():\n", - " print(tokenizer([\"Hello, this one sentence!\", \"This is another sentence.\"]))" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "2C0hcmp9IrJQ" - }, "source": [ "如果您使用的是T5预训练模型的checkpoints,需要对特殊的前缀进行检查。T5使用特殊的前缀来告诉模型具体要做的任务,具体前缀例子如下:\n" - ] + ], + "metadata": { + "id": "2C0hcmp9IrJQ" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "LFntjGh8dJT7" - }, - "outputs": [], "source": [ "if model_checkpoint in [\"t5-small\", \"t5-base\", \"t5-larg\", \"t5-3b\", \"t5-11b\"]:\n", " prefix = \"summarize: \"\n", "else:\n", " prefix = \"\"" - ] + ], + "outputs": [], + "metadata": { + "id": "LFntjGh8dJT7" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rZMGxnr5dJT8" - }, "source": [ "现在我们可以把所有内容放在一起组成我们的预处理函数了。我们对样本进行预处理的时候,我们还会`truncation=True`这个参数来确保我们超长的句子被截断。默认情况下,对与比较短的句子我们会自动padding。" - ] + ], + "metadata": { + "id": "rZMGxnr5dJT8" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "vc0BSBLIIrJQ" - }, - "outputs": [], "source": [ "max_input_length = 1024\n", "max_target_length = 128\n", @@ -600,131 +621,131 @@ "\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs" - ] + ], + "outputs": [], + "metadata": { + "id": "vc0BSBLIIrJQ" + } }, { "cell_type": "markdown", - "metadata": { - "id": "0lm8ozrJIrJR" - }, "source": [ "以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。" - ] + ], + "metadata": { + "id": "0lm8ozrJIrJR" + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "preprocess_function(raw_datasets['train'][:2])" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8, 7178, 13, 24, 1154, 271, 612, 6, 28571, 243, 8, 3996, 19660, 51, 225, 36, 1869, 30, 3, 5833, 750, 10256, 18, 390, 4811, 2367, 132, 5, 86, 1100, 1274, 6, 16046, 10730, 24397, 49, 2744, 31914, 17, 15, 47, 5229, 28, 7646, 12, 10256, 5, 3, 21322, 8, 1919, 1886, 31, 7, 14667, 440, 18, 17114, 4794, 16202, 7, 11, 2050, 17845, 2715, 7, 130, 92, 2633, 26, 21, 487, 5146, 5, 10256, 3763, 16700, 2776, 17, 40, 232, 65, 
243, 10, 96, 1326, 43, 29, 31, 17, 16, 7, 2880, 920, 574, 28, 8, 1508, 5, 96, 11836, 62, 33, 2718, 24, 80, 42, 192, 13, 135, 33, 9805, 12, 1205, 12, 10256, 14159, 1066, 145, 865, 535, 14734, 12, 4712, 2781, 584, 30, 9938, 5061, 10256, 6, 28571, 3, 60, 18, 155, 15, 4094, 112, 3, 8389, 6, 2145, 2627, 1508, 224, 38, 14586, 7677, 423, 18, 1549, 1414, 265, 6060, 11, 411, 7, 2026, 63, 7, 24397, 49, 12446, 2262, 3791, 447, 16, 10256, 225, 240, 20799, 1433, 5, 96, 196, 17, 31, 7, 6865, 3, 9, 1643, 866, 13, 540, 784, 28843, 4275, 37, 7021, 33, 12932, 15436, 13, 24, 1696, 11, 8, 6266, 33, 3, 3131, 3996, 13606, 51, 16, 5, 96, 5231, 34, 31, 7, 3, 9, 792, 815, 13, 131, 147, 23395, 51, 11, 3, 99, 25, 320, 44, 8, 10549, 13, 21331, 24, 8, 233, 3413, 233, 43, 118, 3, 22765, 12, 281, 10055, 21, 784, 355, 908, 1516, 6201, 13, 540, 5, 96, 5231, 3, 99, 62, 130, 12, 830, 8, 1508, 223, 6, 62, 31, 26, 1077, 129, 874, 42, 1296, 1508, 5, 96, 7175, 27, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 5, 96, 7238, 33, 1508, 1107, 91, 13, 1696, 6, 2361, 16, 8, 416, 215, 42, 78, 233, 25, 31, 60, 479, 44, 39, 1414, 265, 6060, 31, 13, 8, 296, 117, 12446, 2262, 3791, 447, 21, 677, 3, 18, 62, 174, 12, 453, 175, 3413, 16, 10256, 5, 96, 1326, 700, 241, 135, 132, 5, 328, 33, 8, 2102, 113, 33, 352, 12, 18514, 8, 1021, 1082, 6, 21, 677, 5, 96, 10273, 33, 8, 1843, 13, 17736, 24, 69, 1021, 1082, 241, 12, 29953, 5, 96, 5231, 27, 133, 456, 326, 784, 969, 2145, 908, 28, 8, 1643, 815, 13, 540, 6, 62, 43, 12, 7365, 1508, 16, 10256, 5, 96, 17527, 6, 3, 99, 24, 54, 36, 612, 11, 132, 31, 7, 128, 8179, 3, 26413, 347, 44, 8, 414, 6, 4273, 6, 752, 31, 7, 320, 12, 830, 1508, 223, 5, 96, 11836, 34, 31, 7, 3, 9, 23958, 296, 6, 19, 29, 31, 17, 34, 58, 96, 196, 17, 31, 7, 1399, 12, 240, 8, 3, 13863, 11, 281, 6, 68, 248, 3, 99, 25, 54, 129, 135, 223, 38, 168, 6, 937, 132, 31, 7, 631, 540, 535, 2390, 11, 7262, 10371, 7, 2050, 2715, 7, 65, 16, 15777, 3, 88, 56, 217, 91, 112, 16046, 10730, 1696, 5, 216, 11, 16202, 7, 92, 2283, 19664, 8, 800, 13, 3140, 1919, 5, 2715, 7, 92, 10246, 271, 4781, 57, 2622, 16, 2379, 29494, 301, 31, 427, 23067, 15, 3, 20923, 12, 16046, 9493, 9906, 17, 325, 2360, 822, 53, 70, 9570, 5, 2969, 2715, 7, 11, 24397, 49, 31914, 17, 15, 3311, 16046, 2177, 13, 8, 2038, 11590, 774, 298, 14667, 440, 18, 17114, 16202, 7, 2301, 132, 16, 1882, 2038, 227, 271, 19664, 21, 3, 15471, 2081, 57, 1798, 1886, 2474, 5993, 5, 1], [21603, 10, 6788, 19758, 7, 2273, 130, 718, 91, 12, 1154, 28, 3, 9, 6220, 2642, 44, 8, 6036, 30, 8, 368, 3540, 986, 7, 2409, 30, 1701, 706, 5, 2409, 7, 130, 16645, 326, 11, 2117, 12355, 1054, 38, 3, 9, 6478, 16813, 47, 4006, 91, 5, 37, 12787, 6, 261, 57, 1932, 27874, 5220, 31195, 3230, 6, 43, 118, 7774, 3, 9, 381, 13, 648, 5, 1377, 1310, 6, 17602, 6417, 6032, 130, 4006, 91, 30, 8, 6036, 30, 12096, 8348, 16, 1186, 11, 932, 5, 37, 6032, 1553, 826, 3, 9, 27874, 896, 2063, 2902, 16, 1882, 1673, 18395, 53, 8, 7070, 13, 8, 7021, 5692, 44, 8, 896, 2501, 5, 1193, 1778, 29, 53, 8, 1251, 3534, 9, 226, 6, 11529, 283, 4569, 4409, 5225, 8692, 243, 10, 96, 196, 17, 19, 3, 9, 2261, 5415, 21, 8, 415, 616, 6, 34, 4110, 2261, 17879, 6, 34, 10762, 151, 31, 7, 1342, 44, 1020, 6, 34, 54, 1709, 3583, 364, 7232, 8, 616, 5, 96, 19494, 62, 174, 151, 28, 251, 12, 698, 24, 28, 8, 2095, 16, 455, 21, 135, 12, 103, 70, 613, 11, 830, 175, 151, 12, 4831, 535, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[368, 22982, 24895, 3545, 13404, 3121, 15, 189, 28571, 7228, 3, 9, 4494, 3996, 19660, 51, 549, 8503, 18, 18145, 7, 3069, 225, 36, 261, 12, 7365, 234, 18, 390, 3683, 224, 38, 1414, 265, 6060, 6, 59, 830, 223, 1215, 699, 26, 4811, 5, 1], [71, 21641, 2642, 646, 1067, 46, 11529, 3450, 828, 16, 5727, 27874, 65, 118, 10126, 3, 9, 3534, 9, 226, 5, 1]]}" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 15 + } + ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-b70jh26IrJS", "outputId": 
"f5a50c2b-0106-41bb-ffb2-3a7c445308f6" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'input_ids': [[21603, 10, 17716, 2279, 43, 5229, 128, 1410, 18, 390, 1508, 28, 5146, 12, 10256, 5, 96, 196, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 976, 243, 28571, 5, 37, 549, 8503, 795, 17586, 51, 12, 8, 3069, 11, 3996, 13606, 51, 639, 45, 8, 6266, 5, 18263, 10256, 11, 2390, 11, 7262, 10371, 7, 3971, 18, 17114, 28571, 1632, 549, 8503, 13404, 30, 2818, 1401, 1797, 6, 7229, 53, 20, 12151, 1955, 8356, 49, 53, 826, 3, 19585, 643, 9768, 5, 216, 19, 230, 3122, 3, 9, 2103, 1059, 12, 1175, 112, 1075, 38, 24260, 350, 16103, 10282, 7, 5752, 4297, 227, 271, 3, 11060, 30, 12, 8, 549, 8503, 1476, 16, 1600, 5, 28571, 47, 859, 8, 1374, 5638, 859, 10282, 7, 6, 411, 7, 2026, 63, 7, 6, 14586, 7677, 11, 26911, 2419, 7, 4298, 113, 130, 10960, 52, 26786, 16, 3, 9, 813, 11674, 11044, 28, 8, 549, 8503, 24, 3492, 16, 3, 9, 3996, 3328, 51, 1154, 16, 1660, 48, 215, 5, 86, 8, 7178, 13, 24, 1154, 271, 612, 6, 28571, 243, 8, 3996, 19660, 51, 225, 36, 1869, 30, 3, 5833, 750, 10256, 18, 390, 4811, 2367, 132, 5, 86, 1100, 1274, 6, 16046, 10730, 24397, 49, 2744, 31914, 17, 15, 47, 5229, 28, 7646, 12, 10256, 5, 3, 21322, 8, 1919, 1886, 31, 7, 14667, 440, 18, 17114, 4794, 16202, 7, 11, 2050, 17845, 2715, 7, 130, 92, 2633, 26, 21, 487, 5146, 5, 10256, 3763, 16700, 2776, 17, 40, 232, 65, 243, 10, 96, 1326, 43, 29, 31, 17, 16, 7, 2880, 920, 574, 28, 8, 1508, 5, 96, 11836, 62, 33, 2718, 24, 80, 42, 192, 13, 135, 33, 9805, 12, 1205, 12, 10256, 14159, 1066, 145, 865, 535, 14734, 12, 4712, 2781, 584, 30, 9938, 5061, 10256, 6, 28571, 3, 60, 18, 155, 15, 4094, 112, 3, 8389, 6, 2145, 2627, 1508, 224, 38, 14586, 7677, 423, 18, 1549, 1414, 265, 6060, 11, 411, 7, 2026, 63, 7, 24397, 49, 12446, 2262, 3791, 447, 16, 10256, 225, 240, 20799, 1433, 5, 96, 196, 17, 31, 7, 6865, 3, 9, 1643, 866, 13, 540, 784, 28843, 4275, 37, 7021, 33, 12932, 15436, 13, 24, 1696, 11, 8, 6266, 33, 3, 3131, 3996, 13606, 51, 16, 5, 96, 5231, 34, 31, 7, 3, 9, 792, 815, 13, 131, 147, 23395, 51, 11, 3, 99, 25, 320, 44, 8, 10549, 13, 21331, 24, 8, 233, 3413, 233, 43, 118, 3, 22765, 12, 281, 10055, 21, 784, 355, 908, 1516, 6201, 13, 540, 5, 96, 5231, 3, 99, 62, 130, 12, 830, 8, 1508, 223, 6, 62, 31, 26, 1077, 129, 874, 42, 1296, 1508, 5, 96, 7175, 27, 31, 162, 373, 1800, 3, 18, 11, 48, 19, 28, 82, 22209, 3, 547, 30, 230, 117, 48, 19, 59, 1719, 42, 549, 8503, 3, 18, 27, 31, 26, 1066, 1492, 24, 540, 30, 2627, 1508, 16, 10256, 5, 96, 7238, 33, 1508, 1107, 91, 13, 1696, 6, 2361, 16, 8, 416, 215, 42, 78, 233, 25, 31, 60, 479, 44, 39, 1414, 265, 6060, 31, 13, 8, 296, 117, 12446, 2262, 3791, 447, 21, 677, 3, 18, 62, 174, 12, 453, 175, 3413, 16, 10256, 5, 96, 1326, 700, 241, 135, 132, 5, 328, 33, 8, 2102, 113, 33, 352, 12, 18514, 8, 1021, 1082, 6, 21, 677, 5, 96, 10273, 33, 8, 1843, 13, 17736, 24, 69, 1021, 1082, 241, 12, 29953, 5, 96, 5231, 27, 133, 456, 326, 784, 969, 2145, 908, 28, 8, 1643, 815, 13, 540, 6, 62, 43, 12, 7365, 1508, 16, 10256, 5, 96, 17527, 6, 3, 99, 24, 54, 36, 612, 11, 132, 31, 7, 128, 8179, 3, 26413, 347, 44, 8, 414, 6, 4273, 6, 752, 31, 7, 320, 12, 830, 1508, 223, 5, 96, 11836, 34, 31, 7, 3, 9, 23958, 296, 6, 19, 29, 31, 17, 34, 58, 96, 196, 17, 31, 7, 1399, 12, 240, 8, 3, 13863, 11, 281, 6, 68, 248, 3, 99, 25, 54, 129, 135, 223, 38, 168, 6, 937, 132, 31, 7, 631, 540, 535, 2390, 11, 7262, 10371, 7, 2050, 2715, 7, 65, 16, 
15777, 3, 88, 56, 217, 91, 112, 16046, 10730, 1696, 5, 216, 11, 16202, 7, 92, 2283, 19664, 8, 800, 13, 3140, 1919, 5, 2715, 7, 92, 10246, 271, 4781, 57, 2622, 16, 2379, 29494, 301, 31, 427, 23067, 15, 3, 20923, 12, 16046, 9493, 9906, 17, 325, 2360, 822, 53, 70, 9570, 5, 2969, 2715, 7, 11, 24397, 49, 31914, 17, 15, 3311, 16046, 2177, 13, 8, 2038, 11590, 774, 298, 14667, 440, 18, 17114, 16202, 7, 2301, 132, 16, 1882, 2038, 227, 271, 19664, 21, 3, 15471, 2081, 57, 1798, 1886, 2474, 5993, 5, 1], [21603, 10, 6788, 19758, 7, 2273, 130, 718, 91, 12, 1154, 28, 3, 9, 6220, 2642, 44, 8, 6036, 30, 8, 368, 3540, 986, 7, 2409, 30, 1701, 706, 5, 2409, 7, 130, 16645, 326, 11, 2117, 12355, 1054, 38, 3, 9, 6478, 16813, 47, 4006, 91, 5, 37, 12787, 6, 261, 57, 1932, 27874, 5220, 31195, 3230, 6, 43, 118, 7774, 3, 9, 381, 13, 648, 5, 1377, 1310, 6, 17602, 6417, 6032, 130, 4006, 91, 30, 8, 6036, 30, 12096, 8348, 16, 1186, 11, 932, 5, 37, 6032, 1553, 826, 3, 9, 27874, 896, 2063, 2902, 16, 1882, 1673, 18395, 53, 8, 7070, 13, 8, 7021, 5692, 44, 8, 896, 2501, 5, 1193, 1778, 29, 53, 8, 1251, 3534, 9, 226, 6, 11529, 283, 4569, 4409, 5225, 8692, 243, 10, 96, 196, 17, 19, 3, 9, 2261, 5415, 21, 8, 415, 616, 6, 34, 4110, 2261, 17879, 6, 34, 10762, 151, 31, 7, 1342, 44, 1020, 6, 34, 54, 1709, 3583, 364, 7232, 8, 616, 5, 96, 19494, 62, 174, 151, 28, 251, 12, 698, 24, 28, 8, 2095, 16, 455, 21, 135, 12, 103, 70, 613, 11, 830, 175, 151, 12, 4831, 535, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[368, 22982, 24895, 3545, 13404, 3121, 15, 189, 28571, 7228, 3, 9, 4494, 3996, 19660, 51, 549, 8503, 18, 18145, 7, 3069, 225, 36, 261, 12, 7365, 234, 18, 390, 3683, 224, 38, 1414, 265, 6060, 6, 59, 830, 223, 1215, 699, 26, 4811, 5, 1], [71, 21641, 2642, 646, 1067, 46, 11529, 3450, 828, 16, 5727, 27874, 65, 118, 10126, 3, 9, 3534, 9, 226, 5, 1]]}" - ] - }, - "execution_count": 15, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "preprocess_function(raw_datasets['train'][:2])" - ] + } }, { "cell_type": "markdown", - "metadata": { - "id": "zS-6iXTkIrJT" - }, "source": [ "接下来对数据集datasets里面的所有样本进行预处理,处理的方式是使用map函数,将预处理函数prepare_train_features应用到(map)所有样本上。" - ] + ], + "metadata": { + "id": "zS-6iXTkIrJT" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "DDtsaJeVIrJT" - }, - "outputs": [], "source": [ "tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)" - ] + ], + "outputs": [], + "metadata": { + "id": "DDtsaJeVIrJT" + } }, { "cell_type": "markdown", - "metadata": { - "id": "voWiw8C7IrJV" - }, "source": [ "更好的是,返回的结果会自动被缓存,避免下次处理的时候重新计算(但是也要注意,如果输入有改动,可能会被缓存影响!)。datasets库函数会对输入的参数进行检测,判断是否有变化,如果没有变化就使用缓存数据,如果有变化就重新处理。但如果输入参数不变,想改变输入的时候,最好清理调这个缓存。清理的方式是使用`load_from_cache_file=False`参数。另外,上面使用到的`batched=True`这个参数是tokenizer的特点,以为这会使用多线程同时并行对输入进行处理。" - ] + ], + "metadata": { + "id": "voWiw8C7IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "545PP3o8IrJV" - }, "source": [ "## 微调模型" - ] + ], + "metadata": { + "id": "545PP3o8IrJV" + } }, { "cell_type": "markdown", - "metadata": { - "id": "FBiW8UpKIrJW" - }, "source": [ "既然数据已经准备好了,现在我们需要下载并加载我们的预训练模型,然后微调预训练模型。既然我们是做seq2seq任务,那么我们需要一个能解决这个任务的模型类。我们使用`AutoModelForSeq2SeqLM`这个类。和tokenizer相似,`from_pretrained`方法同样可以帮助我们下载并加载模型,同时也会对模型进行缓存,就不会重复下载模型啦。" - ] + ], + "metadata": { + "id": "FBiW8UpKIrJW" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "TlqNaB8jIrJW" - }, - "outputs": [], "source": [ "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)" - ] + ], + "outputs": [], + "metadata": { + "id": "TlqNaB8jIrJW" + } }, { "cell_type": "markdown", - "metadata": { - "id": "CczA5lJlIrJX" - }, "source": [ "由于我们微调的任务是seq2seq任务,而我们加载的是预训练的seq2seq模型,所以不会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数(比如:预训练语言模型的神经网络head被扔掉了,同时随机初始化了机器翻译的神经网络head)。" - ] + ], + "metadata": { + "id": "CczA5lJlIrJX" + } }, { "cell_type": "markdown", - "metadata": { - "id": "_N8urzhyIrJY" - 
}, "source": [ "\n", "为了能够得到一个`Seq2SeqTrainer`训练工具,我们还需要3个要素,其中最重要的是训练的设定/参数[`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments)。这个训练设定包含了能够定义训练过程的所有属性" - ] + ], + "metadata": { + "id": "_N8urzhyIrJY" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Bliy8zgjIrJY" - }, - "outputs": [], "source": [ "batch_size = 16\n", "args = Seq2SeqTrainingArguments(\n", @@ -739,13 +760,14 @@ " predict_with_generate=True,\n", " fp16=True,\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "Bliy8zgjIrJY" + } }, { "cell_type": "markdown", - "metadata": { - "id": "km3pGVdTIrJc" - }, "source": [ "上面evaluation_strategy = \"epoch\"参数告诉训练代码:我们每个epcoh会做一次验证评估。\n", "\n", @@ -754,35 +776,34 @@ "由于我们的数据集比较大,同时`Seq2SeqTrainer`会不断保存模型,所以我们需要告诉它至多保存`save_total_limit=3`个模型。\n", "\n", "最后我们需要一个数据收集器data collator,将我们处理好的输入喂给模型。" - ] + ], + "metadata": { + "id": "km3pGVdTIrJc" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "osCjNg41dJT_" - }, - "outputs": [], "source": [ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)" - ] + ], + "outputs": [], + "metadata": { + "id": "osCjNg41dJT_" + } }, { "cell_type": "markdown", - "metadata": { - "id": "7sZOdRlRIrJd" - }, "source": [ "设置好`Seq2SeqTrainer`还剩最后一件事情,那就是我们需要定义好评估方法。我们使用`metric`来完成评估。将模型预测送入评估之前,我们也会做一些数据后处理:" - ] + ], + "metadata": { + "id": "7sZOdRlRIrJd" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "UmvbnJ9JIrJd" - }, - "outputs": [], "source": [ "import nltk\n", "import numpy as np\n", @@ -807,24 +828,24 @@ " result[\"gen_len\"] = np.mean(prediction_lens)\n", " \n", " return {k: round(v, 4) for k, v in result.items()}" - ] + ], + "outputs": [], + "metadata": { + "id": "UmvbnJ9JIrJd" + } }, { "cell_type": "markdown", - "metadata": { - "id": "rXuFTAzDIrJe" - }, "source": [ "最后将所有的参数/数据/模型传给`Seq2SeqTrainer`即可" - ] + ], + "metadata": { + "id": "rXuFTAzDIrJe" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "imY1oC3SIrJf" - }, - "outputs": [], "source": [ "trainer = Seq2SeqTrainer(\n", " model,\n", @@ -835,28 +856,34 @@ " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics\n", ")" - ] + ], + "outputs": [], + "metadata": { + "id": "imY1oC3SIrJf" + } }, { "cell_type": "markdown", - "metadata": { - "id": "CdzABDVcIrJg" - }, "source": [ "调用`train`方法进行微调训练。" - ] + ], + "metadata": { + "id": "CdzABDVcIrJg" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "uNx5pyRlIrJh", - "outputId": "077e661e-d36c-469b-89b8-7ff7f73541ec", - "scrolled": false - }, + "source": [ + "trainer.train()" + ], "outputs": [ { + "output_type": "display_data", "data": { + "text/plain": [ + "" + ], "text/html": [ "\n", "
\n", @@ -903,50 +930,48 @@ " \n", " \n", "

" - ], - "text/plain": [ - "" ] }, "metadata": { "tags": [] - }, - "output_type": "display_data" + } }, { + "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=12753, training_loss=2.7692033505520146, metrics={'train_runtime': 4909.3835, 'train_samples_per_second': 2.598, 'total_flos': 7.774481450954342e+16, 'epoch': 1.0, 'init_mem_cpu_alloc_delta': 335248, 'init_mem_gpu_alloc_delta': 242026496, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 2637782, 'train_mem_gpu_alloc_delta': 728138240, 'train_mem_cpu_peaked_delta': 138226182, 'train_mem_gpu_peaked_delta': 14677017088})" ] }, - "execution_count": 23, "metadata": { "tags": [] }, - "output_type": "execute_result" + "execution_count": 23 } ], - "source": [ - "trainer.train()" - ] + "metadata": { + "id": "uNx5pyRlIrJh", + "outputId": "077e661e-d36c-469b-89b8-7ff7f73541ec", + "scrolled": false + } }, { "cell_type": "markdown", - "metadata": { - "id": "ha0-YirPdJUB" - }, "source": [ "最后别忘了,查看如何上传模型 ,上传模型到](https://huggingface.co/transformers/model_sharing.html) 到[🤗 Model Hub](https://huggingface.co/models)。随后您就可以像这个notebook一开始一样,直接用模型名字就能使用您的模型啦。\n" - ] + ], + "metadata": { + "id": "ha0-YirPdJUB" + } }, { "cell_type": "code", "execution_count": null, + "source": [], + "outputs": [], "metadata": { "id": "NC7aCpkBdJUB" - }, - "outputs": [], - "source": [] + } } ], "metadata": { @@ -973,5 +998,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 2 } \ No newline at end of file