add squad, xsum
commit 530ff994ea (parent f4ffb9339d)

@@ -0,0 +1,92 @@
""" Official evaluation script for v1.1 of the SQuAD dataset. """
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
import sys
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_answer(s):
|
||||||
|
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||||
|
|
||||||
|
def remove_articles(text):
|
||||||
|
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||||
|
|
||||||
|
def white_space_fix(text):
|
||||||
|
return " ".join(text.split())
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
exclude = set(string.punctuation)
|
||||||
|
return "".join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
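
# Illustrative check of the normalization above (a comment added for clarity,
# not part of the original script): normalize_answer("The  Cat!") -> "cat"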


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
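
# Worked example (illustrative, not part of the original script): prediction
# "the cat sat" vs. ground truth "the cat" normalizes to ["cat", "sat"] and
# ["cat"]; overlap = 1, precision = 1/2, recall = 1/1,
# so F1 = 2 * (0.5 * 1.0) / (0.5 + 1.0) = 2/3.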


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]
                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {"exact_match": exact_match, "f1": f1}


if __name__ == "__main__":
    expected_version = "1.1"
    parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if dataset_json["version"] != expected_version:
            print(
                "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
                file=sys.stderr,
            )
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions)))
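
A minimal in-process sketch of evaluate() on a hand-built example (the id and texts are made up for illustration):

dataset = [{"paragraphs": [{"qas": [{"id": "q1", "answers": [{"text": "1976"}]}]}]}]
predictions = {"q1": "1976"}
print(evaluate(dataset, predictions))  # {'exact_match': 100.0, 'f1': 100.0}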
@@ -0,0 +1,131 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ROUGE metric from Google Research github repo. """

# The dependencies in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
import absl  # Here to have a nice missing dependency error message early on
import nltk  # Here to have a nice missing dependency error message early on
import numpy  # Here to have a nice missing dependency error message early on
import six  # Here to have a nice missing dependency error message early on
from rouge_score import rouge_scorer, scoring

import datasets


_CITATION = """\
@inproceedings{lin-2004-rouge,
    title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
    author = "Lin, Chin-Yew",
    booktitle = "Text Summarization Branches Out",
    month = jul,
    year = "2004",
    address = "Barcelona, Spain",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W04-1013",
    pages = "74--81",
}
"""

_DESCRIPTION = """\
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
evaluating automatic summarization and machine translation software in natural language processing.
The metrics compare an automatically produced summary or translation against a reference (or a set of references)
summary or translation, typically human-produced.

Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.

This metric is a wrapper around the Google Research reimplementation of ROUGE:
https://github.com/google-research/google-research/tree/master/rouge
"""

_KWARGS_DESCRIPTION = """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of references, one for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`): n-gram based scoring, where {n} is the n-gram order,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: like `"rougeL"`, but splits text on `"\n"` first.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregated scores if this is set to True, else one score per example.
Returns:
    rouge1: rouge_1 (precision, recall, f1),
    rouge2: rouge_2 (precision, recall, f1),
    rougeL: rouge_l (precision, recall, f1),
    rougeLsum: rouge_lsum (precision, recall, f1)
Examples:

    >>> rouge = datasets.load_metric('rouge')
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> results = rouge.compute(predictions=predictions, references=references)
    >>> print(list(results.keys()))
    ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
    >>> print(results["rouge1"])
    AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
    >>> print(results["rouge1"].mid.fmeasure)
    1.0
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Rouge(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                }
            ),
            codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
            reference_urls=[
                "https://en.wikipedia.org/wiki/ROUGE_(metric)",
                "https://github.com/google-research/google-research/tree/master/rouge",
            ],
        )

    def _compute(self, predictions, references, rouge_types=None, use_agregator=True, use_stemmer=False):
        if rouge_types is None:
            rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

        scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
        if use_agregator:
            aggregator = scoring.BootstrapAggregator()
        else:
            scores = []

        for ref, pred in zip(references, predictions):
            score = scorer.score(ref, pred)
            if use_agregator:
                aggregator.add_scores(score)
            else:
                scores.append(score)

        if use_agregator:
            result = aggregator.aggregate()
        else:
            result = {}
            for key in scores[0]:
                result[key] = list(score[key] for score in scores)

        return result
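
A quick sketch of the non-aggregated path (assuming the metric is loaded as in the docstring example): with use_agregator=False, each rouge type maps to one Score per example instead of a bootstrap AggregateScore.

rouge = datasets.load_metric("rouge")
results = rouge.compute(
    predictions=["hello there", "general kenobi"],
    references=["hello there", "general kenobi"],
    use_agregator=False,
)
print(results["rougeL"][0].fmeasure)  # per-example F-measure; 1.0 here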
@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SQuAD metric. """

import datasets

from .evaluate import evaluate


_CITATION = """\
@inproceedings{Rajpurkar2016SQuAD10,
    title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
    author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
    booktitle={EMNLP},
    year={2016}
}
"""

_DESCRIPTION = """
This metric wraps the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).

Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
from the corresponding reading passage, or the question might be unanswerable.
"""

_KWARGS_DESCRIPTION = """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answer dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answer dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for the answer, as a list of ints
            }
            Note that answer_start values are not taken into account to compute the metric.
Returns:
    'exact_match': Exact match (the normalized answer exactly matches the gold answer)
    'f1': The F-score of predicted tokens versus the gold answer
Examples:

    >>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
    >>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
    >>> squad_metric = datasets.load_metric("squad")
    >>> results = squad_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'exact_match': 100.0, 'f1': 100.0}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")},
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references):
        # Re-shape the flat prediction/reference lists into the nested
        # article/paragraph/qas structure the official script expects.
        pred_dict = {prediction["id"]: prediction["prediction_text"] for prediction in predictions}
        dataset = [
            {
                "paragraphs": [
                    {
                        "qas": [
                            {
                                "answers": [{"text": answer_text} for answer_text in ref["answers"]["text"]],
                                "id": ref["id"],
                            }
                            for ref in references
                        ]
                    }
                ]
            }
        ]
        score = evaluate(dataset=dataset, predictions=pred_dict)
        return score
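
A short sketch with one wrong prediction (hypothetical ids, loading the metric as in the docstring example) to show how EM and F1 average over examples:

predictions = [
    {"id": "q1", "prediction_text": "1976"},
    {"id": "q2", "prediction_text": "blue"},
]
references = [
    {"id": "q1", "answers": {"text": ["1976"], "answer_start": [97]}},
    {"id": "q2", "answers": {"text": ["red"], "answer_start": [10]}},
]
squad_metric = datasets.load_metric("squad")
print(squad_metric.compute(predictions=predictions, references=references))
# {'exact_match': 50.0, 'f1': 50.0}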
@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""SQUAD: The Stanford Question Answering Dataset."""

from __future__ import absolute_import, division, print_function

import json

import datasets


logger = datasets.logging.get_logger(__name__)


_CITATION = """\
@article{2016arXiv160605250R,
    author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},
              Konstantin and {Liang}, Percy},
    title = "{SQuAD: 100,000+ Questions for Machine Comprehension of Text}",
    journal = {arXiv e-prints},
    year = 2016,
    eid = {arXiv:1606.05250},
    pages = {arXiv:1606.05250},
    archivePrefix = {arXiv},
    eprint = {1606.05250},
}
"""

_DESCRIPTION = """\
Stanford Question Answering Dataset (SQuAD) is a reading comprehension \
dataset, consisting of questions posed by crowdworkers on a set of Wikipedia \
articles, where the answer to every question is a segment of text, or span, \
from the corresponding reading passage, or the question might be unanswerable.
"""

_URL = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
_URLS = {
    "train": _URL + "train-v1.1.json",
    "dev": _URL + "dev-v1.1.json",
}


class SquadConfig(datasets.BuilderConfig):
    """BuilderConfig for SQUAD."""

    def __init__(self, **kwargs):
        """BuilderConfig for SQUAD.

        Args:
            **kwargs: keyword arguments forwarded to super.
        """
        super(SquadConfig, self).__init__(**kwargs)


class Squad(datasets.GeneratorBasedBuilder):
    """SQUAD: The Stanford Question Answering Dataset. Version 1.1."""

    BUILDER_CONFIGS = [
        SquadConfig(
            name="plain_text",
            version=datasets.Version("1.0.0", ""),
            description="Plain text",
        ),
    ]

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "title": datasets.Value("string"),
                    "context": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "answers": datasets.features.Sequence(
                        {
                            "text": datasets.Value("string"),
                            "answer_start": datasets.Value("int32"),
                        }
                    ),
                }
            ),
            # No default supervised_keys (as we have to pass both question
            # and context as input).
            supervised_keys=None,
            homepage="https://rajpurkar.github.io/SQuAD-explorer/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        # The data files are supplied through the builder config ("train" and
        # "validation" keys) instead of being downloaded from _URLS.
        downloaded_files = self.config.data_files

        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["validation"]}),
        ]

    def _generate_examples(self, filepath):
        """This function returns the examples in the raw (text) form."""
        logger.info("generating examples from = %s", filepath)
        with open(filepath, encoding="utf-8") as f:
            squad = json.load(f)
            for article in squad["data"]:
                title = article.get("title", "").strip()
                for paragraph in article["paragraphs"]:
                    context = paragraph["context"].strip()
                    for qa in paragraph["qas"]:
                        question = qa["question"].strip()
                        id_ = qa["id"]

                        answer_starts = [answer["answer_start"] for answer in qa["answers"]]
                        answers = [answer["text"].strip() for answer in qa["answers"]]

                        # Features currently used are "context", "question", and "answers".
                        # Others are extracted here for the ease of future expansions.
                        yield id_, {
                            "title": title,
                            "context": context,
                            "question": question,
                            "id": id_,
                            "answers": {
                                "answer_start": answer_starts,
                                "text": answers,
                            },
                        }
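
A hedged loading sketch (the script path and JSON file names are placeholders): because _split_generators reads self.config.data_files rather than downloading from _URLS, the SQuAD JSON files have to be passed in explicitly.

from datasets import load_dataset

squad = load_dataset(
    "path/to/squad.py",
    data_files={"train": "train-v1.1.json", "validation": "dev-v1.1.json"},
)
print(squad["train"][0]["question"])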
@@ -0,0 +1,154 @@
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""XSum dataset."""


import json
import os

import datasets


_CITATION = """
@article{Narayan2018DontGM,
    title={Don't Give Me the Details, Just the Summary! Topic-Aware Convolutional Neural Networks for Extreme Summarization},
    author={Shashi Narayan and Shay B. Cohen and Mirella Lapata},
    journal={ArXiv},
    year={2018},
    volume={abs/1808.08745}
}
"""

_DESCRIPTION = """
Extreme Summarization (XSum) Dataset.

There are three features:
  - document: Input news article.
  - summary: One sentence summary of the article.
  - id: BBC ID of the article.

"""

# From https://github.com/EdinburghNLP/XSum/issues/12
_URL_DATA = "http://bollin.inf.ed.ac.uk/public/direct/XSUM-EMNLP18-Summary-Data-Original.tar.gz"
_URL_SPLITS = (
    "https://raw.githubusercontent.com/EdinburghNLP/XSum/master/XSum-Dataset/XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json"
)

_DOCUMENT = "document"
_SUMMARY = "summary"
_ID = "id"

_REMOVE_LINES = set(
    [
        "Share this with\n",
        "Email\n",
        "Facebook\n",
        "Messenger\n",
        "Twitter\n",
        "Pinterest\n",
        "WhatsApp\n",
        "Linkedin\n",
        "LinkedIn\n",
        "Copy this link\n",
        "These are external links and will open in a new window\n",
    ]
)


class Xsum(datasets.GeneratorBasedBuilder):
    """Extreme Summarization (XSum) Dataset."""

    # Version 1.2.0 expands coverage, includes ids, and removes web contents.
    VERSION = datasets.Version("1.2.0")

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    _DOCUMENT: datasets.Value("string"),
                    _SUMMARY: datasets.Value("string"),
                    _ID: datasets.Value("string"),
                }
            ),
            supervised_keys=(_DOCUMENT, _SUMMARY),
            homepage="https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # The extracted archive directory is supplied through the builder config
        # ("data" key) instead of being downloaded from _URL_DATA / _URL_SPLITS.
        data_files = self.config.data_files
        split_path = os.path.join(data_files["data"], "XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json")

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "split_path": split_path,
                    "split_name": "train",
                    "data_dir": os.path.join(data_files["data"], "bbc-summary-data"),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "split_path": split_path,
                    "split_name": "validation",
                    "data_dir": os.path.join(data_files["data"], "bbc-summary-data"),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "split_path": split_path,
                    "split_name": "test",
                    "data_dir": os.path.join(data_files["data"], "bbc-summary-data"),
                },
            ),
        ]

    def _generate_examples(self, split_path, split_name, data_dir):
        """Yields examples."""

        with open(split_path, "r", encoding="utf-8") as f:
            split_ids = json.load(f)

        for i in split_ids[split_name]:
            with open(os.path.join(data_dir, i + ".summary"), "r", encoding="utf-8") as f:
                text = "".join([line for line in f.readlines() if line not in _REMOVE_LINES and line.strip()])
                # Each file follows the format below:
                # [SN]URL[SN]
                # http://somelink
                #
                # [SN]TITLE[SN]
                # some intro
                #
                # [SN]FIRST-SENTENCE[SN]
                # some intro
                #
                # [SN]RESTBODY[SN]
                # text line.
                # another text line.
                # "another text line."

                # According to the following issue, FIRST-SENTENCE
                # is the reference summary and TITLE is unused:
                # https://github.com/EdinburghNLP/XSum/issues/22
                segs = text.split("[SN]")
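                # Illustrative index map (a comment added for clarity; it follows
                # from the layout above): the split yields ['', 'URL', <url>,
                # 'TITLE', <title>, 'FIRST-SENTENCE', <summary>, 'RESTBODY', <body>],
                # so segs[6] is the one-sentence summary and segs[8] is the body.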
                yield i, {_DOCUMENT: segs[8].strip(), _SUMMARY: segs[6].strip(), _ID: i}
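
A hedged loading sketch (the script path and data directory are placeholders): _split_generators reads self.config.data_files["data"], so the directory holding bbc-summary-data/ and the split JSON has to be passed in explicitly.

from datasets import load_dataset

xsum = load_dataset("path/to/xsum.py", data_files={"data": "path/to/extracted-xsum-data"})
print(xsum["train"][0]["summary"])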