versae's picture
Adding sampling to mc4
3f09f56
raw
history blame
12.4 kB
"""mC4 dataset based on Common Crawl."""
import gzip
import json
import datasets
import kenlm
import numpy as np
# Module-level logger obtained through the `datasets` logging helper.
logger = datasets.logging.get_logger(__name__)
# Free-text description handed to `datasets.DatasetInfo` by `Mc4._info`.
_DESCRIPTION = """\
A colossal, cleaned version of Common Crawl's web crawl corpus.
Based on Common Crawl dataset: "https://commoncrawl.org".
This is the processed version of Google's mC4 dataset by AllenAI.
"""
# BibTeX entry handed to `datasets.DatasetInfo` as the dataset citation.
_CITATION = """
@article{2019t5,
author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
journal = {arXiv e-prints},
year = {2019},
archivePrefix = {arXiv},
eprint = {1910.10683},
}
"""
# Homepage reported in the dataset info.
_URL = "https://github.com/allenai/allennlp/discussions/5056"
# URL template for a single gzipped shard.  `Mc4._split_generators` fills in:
#   language     - language code of the shard
#   split_suffix - "" for train, "-validation" for validation
#   index        - shard number, zero-padded to five digits
#   n_shards     - total shard count for that language/split, same padding
_DATA_URL = "https://huggingface.co/datasets/allenai/c4/resolve/1ddc917116b730e1859edef32896ec5c16be51d0/multilingual/c4-{language}{split_suffix}.tfrecord-{index:05d}-of-{n_shards:05d}.json.gz"
# Language codes for which a builder configuration is created, in
# alphabetical order.  Kept as a whitespace-separated table for compactness;
# `.split()` turns it into the list of codes.
_LANGUAGES = """
af am ar az be bg bg-Latn bn ca ceb co cs
cy da de el el-Latn en eo es et eu fa fi
fil fr fy ga gd gl gu ha haw hi hi-Latn hmn
ht hu hy id ig is it iw ja ja-Latn jv ka
kk km kn ko ku ky la lb lo lt lv mg
mi mk ml mn mr ms mt my ne nl no ny
pa pl ps pt ro ru ru-Latn sd si sk sl sm
sn so sq sr st su sv sw ta te tg th
tr uk und ur uz vi xh yi yo zh zh-Latn zu
""".split()
# Number of shard files per language and split.  Each row below is
# (language code, train shards, validation shards); the comprehension expands
# it into {language: {"train": ..., "validation": ...}}.
_N_SHARDS_PER_SPLIT = {
    lang_code: {"train": train_shards, "validation": validation_shards}
    for lang_code, train_shards, validation_shards in (
        ("af", 64, 1),
        ("am", 16, 1),
        ("ar", 1024, 4),
        ("az", 256, 1),
        ("be", 128, 1),
        ("bg", 1024, 1),
        ("bg-Latn", 4, 1),
        ("bn", 512, 1),
        ("ca", 512, 1),
        ("ceb", 8, 1),
        ("co", 8, 1),
        ("cs", 1024, 2),
        ("cy", 256, 1),
        ("da", 1024, 1),
        ("de", 2048, 16),
        ("el", 1024, 2),
        ("el-Latn", 16, 1),
        ("en", 11264, 128),
        ("eo", 32, 1),
        ("es", 2048, 16),
        ("et", 256, 1),
        ("eu", 64, 1),
        ("fa", 1024, 2),
        ("fi", 1024, 1),
        ("fil", 64, 1),
        ("fr", 2048, 16),
        ("fy", 16, 1),
        ("ga", 16, 1),
        ("gd", 16, 1),
        ("gl", 128, 1),
        ("gu", 64, 1),
        ("ha", 8, 1),
        ("haw", 2, 1),
        ("hi", 1024, 2),
        ("hi-Latn", 16, 1),
        ("hmn", 8, 1),
        ("ht", 8, 1),
        ("hu", 1024, 2),
        ("hy", 128, 1),
        ("id", 1024, 4),
        ("ig", 4, 1),
        ("is", 128, 1),
        ("it", 1024, 8),
        ("iw", 1024, 1),
        ("ja", 1024, 8),
        ("ja-Latn", 8, 1),
        ("jv", 8, 1),
        ("ka", 256, 1),
        ("kk", 256, 1),
        ("km", 64, 1),
        ("kn", 64, 1),
        ("ko", 1024, 1),
        ("ku", 16, 1),
        ("ky", 64, 1),
        ("la", 64, 1),
        ("lb", 32, 1),
        ("lo", 8, 1),
        ("lt", 512, 1),
        ("lv", 256, 1),
        ("mg", 8, 1),
        ("mi", 4, 1),
        ("mk", 128, 1),
        ("ml", 128, 1),
        ("mn", 128, 1),
        ("mr", 1024, 1),
        ("ms", 512, 1),
        ("mt", 128, 1),
        ("my", 64, 1),
        ("ne", 256, 1),
        ("nl", 1024, 4),
        ("no", 1024, 1),
        ("ny", 4, 1),
        ("pa", 32, 1),
        ("pl", 1024, 4),
        ("ps", 16, 1),
        ("pt", 1024, 4),
        ("ro", 1024, 2),
        ("ru", 4096, 32),
        ("ru-Latn", 32, 1),
        ("sd", 64, 1),
        ("si", 64, 1),
        ("sk", 512, 1),
        ("sl", 256, 1),
        ("sm", 4, 1),
        ("sn", 8, 1),
        ("so", 64, 1),
        ("sq", 128, 1),
        ("sr", 256, 1),
        ("st", 2, 1),
        ("su", 4, 1),
        ("sv", 1024, 2),
        ("sw", 32, 1),
        ("ta", 256, 1),
        ("te", 128, 1),
        ("tg", 64, 1),
        ("th", 1024, 1),
        ("tr", 1024, 4),
        ("uk", 1024, 2),
        ("und", 3072, 32),
        ("ur", 128, 1),
        ("uz", 32, 1),
        ("vi", 1024, 4),
        ("xh", 2, 1),
        ("yi", 16, 1),
        ("yo", 2, 1),
        ("zh", 1024, 2),
        ("zh-Latn", 8, 1),
        ("zu", 8, 1),
    )
}
class Mc4Config(datasets.BuilderConfig):
    """Builder configuration describing which mC4 languages to load."""

    def __init__(self, *args, languages, **kwargs):
        """Create a configuration named after its languages.

        Args:
            languages (:obj:`List[str]`): language codes to load; the
                configuration name is these codes joined with ``"+"``.
            *args: positional arguments forwarded to the parent class.
            **kwargs: keyword arguments forwarded to the parent class.
        """
        config_name = "+".join(languages)
        super().__init__(*args, name=config_name, **kwargs)
        self.languages = languages
class Mc4(datasets.GeneratorBasedBuilder):
    """mC4, a colossal, cleaned version of Common Crawl's web crawl corpus.

    In addition to yielding every document, the builder can subsample
    documents according to their perplexity under a KenLM model.  Sampling is
    enabled by passing ``sampling_method`` (plus ``perplexity_model`` and,
    optionally, ``sampling_factor`` and ``boundaries``) as keyword arguments.
    """

    BUILDER_CONFIGS = [Mc4Config(languages=[lang]) for lang in _LANGUAGES]
    BUILDER_CONFIG_CLASS = Mc4Config

    # Perplexity boundaries used when the caller supplies none.  The step
    # sampler treats the gaps between them as quartile ranges; the gaussian
    # sampler centres its curve on the middle value.
    _DEFAULT_BOUNDARIES = (536394.99320948, 662247.50212365, 919250.87225178)

    def __init__(self, *args, writer_batch_size=None, **kwargs):
        """Set up optional perplexity sampling, then initialise the builder.

        Args:
            *args: positional arguments forwarded to the parent builder.
            writer_batch_size: forwarded to the parent builder.
            **kwargs: forwarded to the parent builder after removing the
                sampling options below.
                sampling_method: falsy/absent disables sampling;
                    ``"gaussian"`` selects gaussian sampling; any other
                    truthy value selects step-wise sampling.
                perplexity_model: path of the KenLM model; required when
                    sampling is enabled.
                sampling_factor: ``factor`` passed to the sampler
                    (``None`` keeps the sampler's own default).
                boundaries: three perplexity boundaries passed to the
                    sampler (``None`` keeps the built-in defaults).

        Raises:
            ValueError: if sampling is requested without ``perplexity_model``.
        """
        # Pop every sampling option unconditionally (with a default) so that a
        # plain, non-sampling load works and none of them is forwarded to the
        # parent constructor.
        self.sampling_method = kwargs.pop("sampling_method", None)
        self.perplexity_model = kwargs.pop("perplexity_model", None)
        self.sampling_factor = kwargs.pop("sampling_factor", None)
        self.boundaries = kwargs.pop("boundaries", None)
        if self.sampling_method:
            if self.perplexity_model is None:
                raise ValueError(
                    "`perplexity_model` is required when `sampling_method` is set"
                )
            # Loading 5-gram model
            # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
            logger.info("loading model = %s", self.perplexity_model)
            self.pp_model = kenlm.Model(self.perplexity_model)
            if self.sampling_method == "gaussian":
                self.should_keep_doc = self._should_keep_doc_gaussian
            else:
                # Any non-gaussian method falls back to step-wise sampling.
                self.should_keep_doc = self._should_keep_doc_step
        super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)

    def get_perplexity(self, doc):
        """Return the perplexity of ``doc`` under the loaded KenLM model.

        Each line of the document is scored separately; the scores and the
        per-line lengths (whitespace token count plus one) are summed and
        combined as ``10 ** (-total_score / total_length)``.
        NOTE(review): assumes ``pp_model.score`` returns a log10 probability
        and that the ``+ 1`` accounts for an end-of-sentence token - confirm
        against the KenLM documentation.
        """
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            doc_log_score += self.pp_model.score(line)
            doc_length += len(line.split()) + 1
        return 10.0 ** (-doc_log_score / doc_length)

    def _should_keep_doc_step(self, doc, factor=1, boundaries=None):
        """Randomly decide whether to keep ``doc`` using step-wise sampling.

        The keep probability is ``factor`` divided by the width of the
        boundary interval that contains the document's perplexity; documents
        at or above the last boundary use ``100 * boundaries[2]`` as width.

        Args:
            doc: document text.
            factor: numerator of the keep probability (``None`` means 1).
            boundaries: three increasing perplexity boundaries (``None``
                means the built-in defaults).

        Returns:
            ``True`` if the document should be kept.
        """
        perplexity = self.get_perplexity(doc)
        if factor is None:
            factor = 1
        if boundaries is None:
            boundaries = self._DEFAULT_BOUNDARIES
        # The intervals are closed on the right so that a perplexity exactly
        # equal to a boundary always falls in some interval.
        if perplexity <= boundaries[0]:
            quartile_range = boundaries[0]
        elif perplexity <= boundaries[1]:
            quartile_range = boundaries[1] - boundaries[0]
        elif perplexity < boundaries[2]:
            quartile_range = boundaries[2] - boundaries[1]
        else:
            quartile_range = 100 * boundaries[2]
        probability = factor / quartile_range
        return np.random.uniform() < probability

    def _should_keep_doc_gaussian(self, doc, factor=0.4, boundaries=None):
        """Randomly decide whether to keep ``doc`` using gaussian sampling.

        The keep probability is
        ``factor * exp(-9/2 * ((perplexity - m) / m) ** 2)`` where ``m`` is
        the middle boundary.

        Args:
            doc: document text.
            factor: scale of the keep probability (``None`` means 0.4).
            boundaries: perplexity boundaries; only ``boundaries[1]`` is used
                (``None`` means the built-in default).

        Returns:
            ``True`` if the document should be kept.
        """
        perplexity = self.get_perplexity(doc)
        if factor is None:
            factor = 0.4
        if boundaries is None:
            boundaries = self._DEFAULT_BOUNDARIES
        m = boundaries[1]
        weighted_perplexity = factor * np.exp(-9 / 2 * ((perplexity - m) / m) ** 2)
        return np.random.uniform() < weighted_perplexity

    def _info(self):
        """Return the dataset metadata: three string features per example."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                    "timestamp": datasets.Value("string"),
                    "url": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
            homepage=_URL,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download all shards of the configured languages for both splits."""
        data_urls = {}
        for split in ["train", "validation"]:
            split_suffix = "-validation" if split == "validation" else ""
            data_urls[split] = [
                _DATA_URL.format(
                    # Use the individual language, not the config name: a
                    # multi-language config is named e.g. "en+es", which is
                    # not a valid shard prefix.
                    language=lang,
                    split_suffix=split_suffix,
                    index=index,
                    n_shards=_N_SHARDS_PER_SPLIT[lang][split],
                )
                for lang in self.config.languages
                for index in range(_N_SHARDS_PER_SPLIT[lang][split])
            ]
        train_downloaded_files = dl_manager.download(data_urls["train"])
        validation_downloaded_files = dl_manager.download(data_urls["validation"])
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"filepaths": train_downloaded_files},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={"filepaths": validation_downloaded_files},
            ),
        ]

    def _generate_examples(self, filepaths):
        """Yield ``(id, example)`` pairs from the gzipped JSON-lines files.

        When sampling is enabled, only documents accepted by
        ``self.should_keep_doc`` are yielded; ids are consecutive over the
        yielded examples.
        """
        id_ = 0
        for filepath in filepaths:
            logger.info("generating examples from = %s", filepath)
            # Manage the raw handle explicitly: closing the gzip wrapper does
            # not close the file object it was given.
            with open(filepath, "rb") as raw_file, gzip.open(
                raw_file, "rt", encoding="utf-8"
            ) as f:
                if self.sampling_method:
                    logger.info("sampling method = %s", self.sampling_method)
                for line in f:
                    # Lines keep their newline, so test the stripped text to
                    # actually skip blank lines.
                    if not line.strip():
                        continue
                    example = json.loads(line)
                    if self.sampling_method and not self.should_keep_doc(
                        example["text"],
                        factor=self.sampling_factor,
                        boundaries=self.boundaries,
                    ):
                        continue
                    yield id_, example
                    id_ += 1