diff --git a/fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py b/fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py new file mode 100644 index 0000000000000000000000000000000000000000..9a643b9289501b639a3e8722e2d877744fffae84 --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py @@ -0,0 +1,33 @@ +import argparse +import json +from collections import defaultdict +import os +import soundfile as sf +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + parser.add_argument('--exp', type=str) + parser.add_argument('--dump', type=str) + args = parser.parse_args() + + langs = [d for d in os.listdir(args.dump) if os.path.isdir(os.path.join(args.dump, d))] + + data = {} + + for lang in langs: + ids = [int(x.strip()) for x in open(args.dump + "/" + lang + "/ids.txt", "r").readlines()] + word_hyps = [x.strip() for x in open(args.exp + "/" + lang + "/hypo.word.reord", "r").readlines()] + scores = [x.strip() for x in open(args.exp + "/" + lang + "/asr_score.reord", "r").readlines()] + assert len(ids) == len(word_hyps) + assert len(ids) == len(scores) + for id, word_hyp, s in zip(ids, word_hyps, scores): + if id in data: + print("Duplicate ID found") + import pdb;pdb.set_trace() + data[id] = (word_hyp, s) + + with open(args.exp + "/nbest_asr_hyp", "w") as f1, open(args.exp + "/asr_score", "w") as f2: + for i in range(len(data.keys())): + f1.write(data[i][0] + "\n") + f2.write(data[i][1] + "\n") \ No newline at end of file diff --git a/fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py b/fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py new file mode 100644 index 0000000000000000000000000000000000000000..455ee25ec087467434287fb5d8b4dad92b90f833 --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py @@ -0,0 +1,23 @@ +import soundfile as sf +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + parser.add_argument('--src', type=str) + parser.add_argument('--dst', type=str) + args = parser.parse_args() + + wavs = [x.strip() for x in open(args.src, "r").readlines()] + + new_lines = ["/"] + for wav in wavs: + # Read the wav file + data, sample_rate = sf.read(wav) + + # Number of samples is the length of the data array + num_samples = len(data) + + new_lines.append(wav+"\t"+str(num_samples)) + + with open(args.dst, "w") as f: + f.writelines([x+"\n" for x in new_lines]) diff --git a/fairseq/examples/mms/lid_rerank/mms/split_by_lang.py b/fairseq/examples/mms/lid_rerank/mms/split_by_lang.py new file mode 100644 index 0000000000000000000000000000000000000000..b123e406f2db6d106711f1a3a5f8fa272876fa4b --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/mms/split_by_lang.py @@ -0,0 +1,90 @@ +import argparse +import json +from collections import defaultdict +import os +import soundfile as sf +from tqdm import tqdm + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + parser.add_argument('--wavs_tsv', type=str) + parser.add_argument('--lid_preds', type=str) + parser.add_argument('--dst', type=str) + parser.add_argument('--refs', type=str, default=None) + parser.add_argument('--langs', type=str, default=None) + parser.add_argument('--confs', type=str, default=None) + args = parser.parse_args() + + # split wavs into dst/lang/wav.txt and dst/lang/ids.txt + # uses lid_preds to create topk asr; 1 wav has k different lid + + wavs_tsv = [x for x in open(args.wavs_tsv, "r").readlines()] + 
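+    # A fairseq wav tsv (as written by prep_wav_list.py above) has the root directory on its first
+    # line and "<wav path>\t<num samples>" on every following line; the header line is kept separately
+    # here and re-emitted at the top of each per-language test.tsv below.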
root = wavs_tsv[0] + wavs = wavs_tsv[1:] + lid_preds = [eval(x) for x in open(args.lid_preds, "r").readlines()] + if args.refs is not None: + refs = [x.strip() for x in open(args.refs, "r").readlines()] + assert len(wavs) == len(refs) + refs_filt = [] + if args.langs is not None: + langs = [x.strip() for x in open(args.langs, "r").readlines()] + assert len(wavs) == len(langs) + langs_filt = [] + if args.confs is not None: + confs = [x.strip() for x in open(args.confs, "r").readlines()] + assert len(wavs) == len(confs) + confs_filt = [] + + assert len(wavs) == len(lid_preds) + + topk_wavs = [] + topk_langs = [] + + for i, (w, p) in enumerate(zip(wavs, lid_preds)): + if p == "n/a": + continue + + assert len(p) == len(lid_preds[0]) + + for l, _ in p: + topk_wavs.append(w) + topk_langs.append(l) + + if args.refs is not None: + refs_filt.append(refs[i]) + if args.langs is not None: + langs_filt.append(langs[i]) + if args.confs is not None: + confs_filt.append(confs[i]) + + lang_split = defaultdict(list) + for id, (wav,lid) in enumerate(zip(topk_wavs, topk_langs)): + lang_split[lid].append((id, wav)) + + for lang in tqdm(lang_split.keys()): + if not os.path.exists(args.dst + "/" + lang): + os.makedirs(args.dst + "/" + lang) + + with open(args.dst + "/" + lang + "/test.tsv", "w") as f1, \ + open(args.dst + "/" + lang + "/ids.txt", "w") as f2: + f1.write(root) + f1.writelines([x[1] for x in lang_split[lang]]) + f2.writelines([str(x[0]) + "\n" for x in lang_split[lang]]) + + with open(args.dst + "/" + lang + "/test.ltr", "w") as fw: + fw.write("d u m m y | d u m m y |\n"*len(lang_split[lang])) + with open(args.dst + "/" + lang + "/test.wrd", "w") as fw: + fw.write("dummy dummy\n"*len(lang_split[lang])) + + with open(args.dst + "/lid.txt", "w") as f: + f.writelines([x+"\n" for x in topk_langs]) + + if args.refs is not None: + with open(args.dst + "/refs.txt", "w") as f: + f.writelines([x+"\n" for x in refs_filt]) + if args.langs is not None: + with open(args.dst + "/langs.txt", "w") as f: + f.writelines([x+"\n" for x in langs_filt]) + if args.confs is not None: + with open(args.dst + "/confs.txt", "w") as f: + f.writelines([x+"\n" for x in confs_filt]) \ No newline at end of file diff --git a/fairseq/examples/mms/lid_rerank/nllb/infer.py b/fairseq/examples/mms/lid_rerank/nllb/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4d69a9078fd17846a615c27952aa2602e0dafb --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/nllb/infer.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +# -*- encoding: utf8 -*- +import fasttext +from tqdm import tqdm +import argparse +import os +import math + +parser = argparse.ArgumentParser() +parser.add_argument("--txt", type=str) +parser.add_argument("--dst", type=str) +parser.add_argument("--model", type=str) +parser.add_argument('--lid', type=str) +args = parser.parse_args() + +mapping = {"arb":"ara", "azj":"aze", "pes":"fas", "fuv":"ful", "lvs":"lav", "khk":"mon", "zsm":"zlm", "gaz":"orm", "pbt":"pus", "uzn":"uzb", "zho":"cmn"} + +def fix_code(x): + code = x.split("_")[-2] + if code in mapping: + code = mapping[code] + return code + +if __name__ == "__main__": + if not os.path.exists(args.dst): + os.makedirs(args.dst) + + pretrained_lang_model = args.model + model = fasttext.load_model(pretrained_lang_model) + + txts = [x.strip() for x in open(args.txt, "r").readlines()] + lids = [x.strip() for x in open(args.lid, "r").readlines()] + assert len(txts) == len(lids) + + with open(args.dst + "/wlid_score", "w") as f: + for t,l in tqdm(zip(txts, 
lids)): + predictions = model.predict(t, k=218) # max 218 + predictions = [(fix_code(x), y) for x, y in zip(predictions[0], predictions[1])] + + try: + pred_langs = [x[0] for x in predictions] + idx = pred_langs.index(l) + score = math.log(predictions[idx][-1]) + except: + score = -1000 + f.write(str(score) + "\n") \ No newline at end of file diff --git a/fairseq/examples/mms/lid_rerank/rerank/rerank.py b/fairseq/examples/mms/lid_rerank/rerank/rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..beea3e6a77342662f578b3900f3dd3cdbd6e3821 --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/rerank/rerank.py @@ -0,0 +1,132 @@ +import argparse +import json +from collections import defaultdict +import os +from tqdm import tqdm +import sys +import subprocess +import re +import math +import numpy as np +import editdistance +from sklearn.preprocessing import StandardScaler +from multiprocessing import Pool +from functools import partial +import random + +cer_langs = [x.strip() for x in open("cer_langs.txt", "r").readlines()] + +def select(w, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=10, exclude=None): + assert len(w) == len(feats[0]) + scores = [] + for f in feats: + s = 0 + for i in range(len(w)): + s += w[i]*f[i] + scores.append(s) + + lid_correct = 0 + lid_total = 0 + asr_err = 0 + asr_total = 0 + text = [] + lang = [] + + for i in range(len(ref_lid)): + if exclude is not None: + if ref_lid[i] in exclude: + continue + + start_idx = i * n + end_idx = start_idx + n + cand_scores = scores[start_idx:end_idx] + max_idx, max_val = max(enumerate(cand_scores), key=lambda x: x[1]) + + cand_feats = feats[start_idx:end_idx] + + lang.append(nbest_lid[start_idx:end_idx][max_idx]) + if ref_lid[i] == nbest_lid[start_idx:end_idx][max_idx]: + lid_correct += 1 + lid_total += 1 + + hyp = nbest_asr[start_idx:end_idx][max_idx] + text.append(hyp) + ref = ref_asr[i] + hyp = hyp.lower() + ref = ref.lower() + hyp = hyp.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + ref = ref.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + if ref_lid[i] in cer_langs: + hyp = " ".join(hyp) + ref = " ".join(ref) + + hyp_words = hyp.split() + tgt_words = ref.split() + errs = editdistance.eval(hyp_words, tgt_words) + asr_err += errs + asr_total += len(tgt_words) + + results = {"lid_acc": lid_correct / lid_total, "asr_wer": asr_err / asr_total, "weights": w} + + return results, text, lang + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + parser.add_argument('--slid', type=str) + parser.add_argument('--wlid', type=str) + parser.add_argument('--asr', type=str) + parser.add_argument('--lm', type=str) + parser.add_argument('--uasr', type=str) + parser.add_argument('--n', type=int, default=10) + parser.add_argument('--dst', type=str) + parser.add_argument('--ref_lid', type=str) + parser.add_argument('--nbest_lid', type=str) + parser.add_argument('--ref_asr', type=str) + parser.add_argument('--nbest_asr', type=str) + parser.add_argument('--w', type=str) + parser.add_argument('--tag', type=str, default = None) + parser.add_argument('--exclude', nargs="*", default=None) # exclude langs + args = parser.parse_args() + + slid = [float(x.strip()) for x in open(args.slid, "r").readlines()] + wlid = [float(x.strip()) for x in open(args.wlid, "r").readlines()] + asr = [float(x.strip()) for x 
in open(args.asr, "r").readlines()] + lm = [float(x.strip()) for x in open(args.lm, "r").readlines()] + uasr = [float(x.strip()) for x in open(args.uasr, "r").readlines()] + + assert len(slid) == len(wlid) + assert len(wlid) == len(asr) + assert len(asr) == len(lm) + assert len(lm) == len(uasr) + + ref_lid = [x.strip() for x in open(args.ref_lid, "r").readlines()] + nbest_lid= [x.strip() for x in open(args.nbest_lid, "r").readlines()] + ref_asr = [x.strip() for x in open(args.ref_asr, "r").readlines()] + nbest_asr = [x.strip() for x in open(args.nbest_asr, "r").readlines()] + + assert len(ref_lid) * args.n == len(nbest_lid) + assert len(ref_asr) * args.n == len(nbest_asr) + assert len(ref_lid) == len(ref_asr) + + lengths = [len(x) for x in nbest_asr] + + feats = [[s, w, a, l, u, le] for s,w,a,l,u,le in zip(slid, wlid, asr, lm, uasr, lengths)] + + weight = eval(open(args.w, "r").read())['weights'] + + results, text, lang = select(weight, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=args.n, exclude=args.exclude) + + if args.tag is not None: + tag_text = "." + args.tag + else: + tag_text = "" + + with open(args.dst + "/reranked_1best_asr_hyp" + tag_text, "w") as f_out: + f_out.writelines([x+"\n" for x in text]) + + with open(args.dst + "/reranked_1best_lid" + tag_text, "w") as f_out: + f_out.writelines([x+"\n" for x in lang]) + + with open(args.dst + "/text.result" + tag_text, "w") as f_out: + for k in results.keys(): + f_out.write(k + "\t" + str(results[k]) + "\n") diff --git a/fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py b/fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py new file mode 100644 index 0000000000000000000000000000000000000000..fc15f650a73c5bc28450acd8f2f83196fd7956a1 --- /dev/null +++ b/fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py @@ -0,0 +1,138 @@ +import argparse +import os +from tqdm import tqdm +import numpy as np +import editdistance +from multiprocessing import Pool +from functools import partial + +cer_langs = [x.strip() for x in open("cer_langs.txt", "r").readlines()] + +def compute(w, feats, ref_lid, nbest_lid, ref_asr, nbest_asr, n=10, exclude=None): + assert len(w) == len(feats[0]) + scores = [] + for f in feats: + s = 0 + for i in range(len(w)): + s += w[i]*f[i] + scores.append(s) + + lid_correct = 0 + lid_total = 0 + asr_err = 0 + asr_total = 0 + + for i in range(len(ref_lid)): + if exclude is not None: + if ref_lid[i] in exclude: + continue + + start_idx = i * n + end_idx = start_idx + n + cand_scores = scores[start_idx:end_idx] + max_idx, max_val = max(enumerate(cand_scores), key=lambda x: x[1]) + + if ref_lid[i] == nbest_lid[start_idx:end_idx][max_idx]: + lid_correct += 1 + lid_total += 1 + + hyp = nbest_asr[start_idx:end_idx][max_idx] + ref = ref_asr[i] + hyp = hyp.lower() + ref = ref.lower() + hyp = hyp.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + ref = ref.replace(".", "").replace(",", "").replace("?", "").replace("!", "").replace(":", "").replace(")", "").replace("(", "").replace("-", "") + if ref_lid[i] in cer_langs: + hyp = " ".join(hyp) + ref = " ".join(ref) + + hyp_words = hyp.split() + tgt_words = ref.split() + errs = editdistance.eval(hyp_words, tgt_words) + asr_err += errs + asr_total += len(tgt_words) + + return {"lid_acc": lid_correct / lid_total, "asr_wer": asr_err / asr_total, "weights": w} + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Example argument parser') + 
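+    # Each score file holds one value per n-best candidate, grouped in blocks of --n per utterance.
+    # The script samples --iters random weight vectors (scaled by the *_scale arguments), rescores every
+    # n-best list with a weighted sum of [slid, wlid, asr, lm, uasr, hypothesis length], and keeps the
+    # weights that give the lowest overall error rate (written to <dst>/best_coefficients).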
parser.add_argument('--slid', type=str) + parser.add_argument('--wlid', type=str) + parser.add_argument('--asr', type=str) + parser.add_argument('--lm', type=str) + parser.add_argument('--uasr', type=str) + parser.add_argument('--n', type=int, default=10) + parser.add_argument('--dst', type=str) + parser.add_argument('--ref_lid', type=str) + parser.add_argument('--nbest_lid', type=str) + parser.add_argument('--ref_asr', type=str) + parser.add_argument('--nbest_asr', type=str) + parser.add_argument('--iters', type=int, default=10000) + parser.add_argument('--slid_scale', type=int, default = 100) + parser.add_argument('--wlid_scale', type=int, default = 100) + parser.add_argument('--asr_scale', type=int, default = 10) + parser.add_argument('--lm_scale', type=int, default = 10) + parser.add_argument('--uasr_scale', type=int, default = 10) + parser.add_argument('--len_scale', type=int, default = 1) + parser.add_argument('--num_jobs', type=int, default = 64) + parser.add_argument('--exclude', nargs="*", default=None) # exclude langs + args = parser.parse_args() + + slid = [float(x.strip()) for x in open(args.slid, "r").readlines()] + wlid = [float(x.strip()) for x in open(args.wlid, "r").readlines()] + asr = [float(x.strip()) for x in open(args.asr, "r").readlines()] + lm = [float(x.strip()) for x in open(args.lm, "r").readlines()] + uasr = [float(x.strip()) for x in open(args.uasr, "r").readlines()] + + assert len(slid) == len(wlid) + assert len(wlid) == len(asr) + assert len(asr) == len(lm) + assert len(lm) == len(uasr) + + ref_lid = [x.strip() for x in open(args.ref_lid, "r").readlines()] + nbest_lid= [x.strip() for x in open(args.nbest_lid, "r").readlines()] + ref_asr = [x.strip() for x in open(args.ref_asr, "r").readlines()] + nbest_asr = [x.strip() for x in open(args.nbest_asr, "r").readlines()] + + assert len(ref_lid) * args.n == len(nbest_lid) + assert len(ref_asr) * args.n == len(nbest_asr) + assert len(ref_lid) == len(ref_asr) + + lengths = [len(x) for x in nbest_asr] + + feats = [[s, w, a, l, u, le] for s,w,a,l,u,le in zip(slid, wlid, asr, lm, uasr, lengths)] + + weights = [] + for i in range(args.iters): + s_w = np.random.rand() * args.slid_scale + w_w = np.random.rand() * args.wlid_scale + a_w = np.random.rand() * args.asr_scale + l_w = np.random.rand() * args.lm_scale + u_w = np.random.rand() * args.uasr_scale + le_w = (np.random.rand() -0.5) * args.len_scale + weights.append([s_w, w_w, a_w, l_w, u_w, le_w]) + + num_tries = len(weights) + print("Total number of search points", num_tries) + threads = args.num_jobs + pool = Pool(threads) + compute_fxn = partial(compute, feats=feats, ref_lid=ref_asr, nbest_lid=nbest_lid, ref_asr=ref_asr, nbest_asr=nbest_asr, n=args.n, exclude=args.exclude) + results = pool.map(compute_fxn, weights) + pool.close() + pool.join() + + assert len(results) == len(weights) + + wer_best = 100 + best = "" + if not os.path.exists(args.dst): + os.makedirs(args.dst) + with open(args.dst + "/results.all", "w") as f_out: + for result in results: + f_out.write(str(result)+"\n") + if result["asr_wer"] < wer_best: + wer_best = result["asr_wer"] + best = result + + with open(args.dst + "/best_coefficients", "w") as f_out: + f_out.write(str(best)+"\n") \ No newline at end of file diff --git a/fairseq/examples/mms/lid_rerank/whisper/infer_lid.py b/fairseq/examples/mms/lid_rerank/whisper/infer_lid.py new file mode 100644 index 0000000000000000000000000000000000000000..150e0bbcca378626dcbab47d4495ba8578949a9e --- /dev/null +++ 
b/fairseq/examples/mms/lid_rerank/whisper/infer_lid.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# -*- encoding: utf8 -*- +import argparse +import itertools +import os +import re +import sys +from pathlib import Path +import math + +import whisper +from tqdm import tqdm + + +parser = argparse.ArgumentParser() +parser.add_argument("--wavs", type=str) +parser.add_argument("--dst", type=str) +parser.add_argument("--model", type=str) +parser.add_argument("--n", type=int, default=10) +parser.add_argument("--mapping", type=str, default="whisper/lid_mapping.txt") +args = parser.parse_args() + +if __name__ == "__main__": + model = whisper.load_model(args.model) + + print(args) + + wavs = [x.strip() for x in open(args.wavs, "r").readlines()] + if not os.path.exists(args.dst): + os.makedirs(args.dst) + + if args.mapping is not None: + #whisper_lid_code:mms_lid_code + mapping = {x[0]:x[1] for x in [l.strip().split(";", 1) for l in open(args.mapping, "r").readlines()]} + else: + mapping = None + + with open(args.dst + "/predictions", "w") as f: + for wav in tqdm(wavs): + # load audio and pad/trim it to fit 30 seconds + audio = whisper.load_audio(wav) + audio = whisper.pad_or_trim(audio) + + # make log-Mel spectrogram and move to the same device as the model + mel = whisper.log_mel_spectrogram(audio).to(model.device) + + _, probs = model.detect_language(mel) + result = sorted(probs.items(), key=lambda x:x[1], reverse=True)[:args.n] + f.write(str(result) + "\n") + + lid_preds = [eval(x) for x in open(args.dst + "/predictions", "r").readlines()] + lids = [] + scores = [] + for p in lid_preds: + assert len(p) == len(lid_preds[0]) + for l, s in p: + if args.mapping is not None: + lids.append(mapping[l]) + else: + lids.append(l) + scores.append(math.log(s)) + with open(args.dst + "/nbest_lid", "w") as f: + f.writelines([x+"\n" for x in lids]) + with open(args.dst + "/slid_score", "w") as f: + f.writelines([str(x)+"\n" for x in scores]) \ No newline at end of file diff --git a/fairseq/examples/moe_lm/data_card.md b/fairseq/examples/moe_lm/data_card.md new file mode 100644 index 0000000000000000000000000000000000000000..54e694b62088f467483504f727a23565b894e0a2 --- /dev/null +++ b/fairseq/examples/moe_lm/data_card.md @@ -0,0 +1,221 @@ +# Data card for the paper "Efficient Large Scale Language Modeling with Mixtures of Experts" +## Version 1.0.0 + +We follow the recommendations of Gebru et al. (2018) and provide a datacard for the dataset used to train the 1.1T parameter model. + +## Motivation +* **For what purpose was the dataset created? Was there a specific task in mind? Was there a specific gap that needed to be filled? Please provide a description.** +The pre-training data for training the 1.1 T model was created by a union of six English language datasets, including five datasets used by RoBERTa (Liu et al 2019) and the English subset of CC 100. These purpose of creating this dataset was to pre-train the language model. + +* **Who created the dataset (e.g., which team, research group) and on behalf of which entity (e.g., company, institution, organization)?** +FAIR (Fundamental Artificial Intelligence Research) + +* **Who funded the creation of the dataset? If there is an associated grant, please provide the name of the grantor and the grant name and number.** +FAIR (Fundamental Artificial Intelligence Research) + +* **Any other comments?** +No. + +## Composition + +* **What do the instances that comprise the dataset represent (e.g., documents, photos, people, countries)? 
Are there multiple types of instances (e.g., movies, users, and ratings; people and interactions between them; nodes and edges)? Please provide a description.** +The instances are textual documents. The overall dataset is composed from a union of the following datasets - + * BookCorpus (Zhu et al., 2019) consists of more than 10K unpublished books (4GB); + * English Wikipedia, excluding lists, tables and headers (12GB); + * CC-News (Nagel,2016) contains 63 million English news articles crawled between September 2016 and February 2019 (76GB); + * OpenWebText (Gokaslan and Cohen, 2019), an open source recreation of the WebText dataset used to train GPT-2 (38GB); + * CC-Stories (Trinh and Le, 2018) contains a subset of CommonCrawl data filtered to match the story-like style of Winograd schemas (31GB); + * English CC100 (Wenzek et al., 2020), a dataset extracted from CommonCrawl snapshots between January 2018 and December 2018, filtered to match the style of Wikipedia (292GB). + +* **How many instances are there in total (of each type, if appropriate)?** +The training data contains 112B tokens corresponding to 453 GB of data. + +* **Does the dataset contain all possible instances or is it a sample (not necessarily random) of instances from a larger set? If the dataset is a sample, then what is the larger set? Is the sample representative of the larger set (e.g., geographic coverage)? If so, please describe how this representativeness was validated/verified. If it is not representative of the larger set, please describe why not (e.g., to cover a more diverse range of instances, because instances were withheld or unavailable).** +The English CC100 section of the dataset is a subset of CommonCrawl snapshots extracted between January 2018 to December 2018, filtered to match the style of Wikipedia. The CC-stories dataset contains a subset of CommonCrawl data filtered to match the story-like style of Winograd schemas. + +* **What data does each instance consist of? “Raw” data (e.g., unprocessed text or images) or features? In either case, please provide a description.** +Each instance consists of raw text data. + +* **Is there a label or target associated with each instance? If so, please provide a description.** +No. + +* **Is any information missing from individual instances? If so, please provide a description, explaining why this information is missing (e.g., because it was unavailable). This does not include intentionally removed information, but might include, e.g., redacted text.** +No. + +* **Are relationships between individual instances made explicit (e.g., users' movie ratings, social network links)? If so, please describe how these relationships are made explicit.** +There are no explicit relationships between individual instances. + +* **Are there recommended data splits (e.g., training, development/validation, testing)? If so, please provide a description of these splits, explaining the rationale behind them.** +We hold out a random validation set of approximately 150MB from the pretraining data, sampled proportionally to each dataset's size in the pretraining corpus. + +* **Are there any errors, sources of noise, or redundancies in the dataset? If so, please provide a description.** +N/A + +* **Is the dataset self-contained, or does it link to or otherwise rely on external resources (e.g., websites, tweets, other datasets)?** +It's self-contained. 
+ +* **Does the dataset contain data that might be considered confidential (e.g., data that is protected by legal privilege or by doctor-patient confidentiality, data that includes the content of individuals' non-public communications)? If so, please provide a description.** +The datasets used are publicly available, and the information in them is not considered confidential. + +* **Does the dataset contain data that, if viewed directly, might be offensive, insulting, threatening, or might otherwise cause anxiety? If so, please describe why.** +Parts of the dataset are a subset of public Common Crawl data, which could contain sentences that, if viewed directly, might be offensive, insulting, threatening, or might otherwise cause anxiety. + +* **Does the dataset relate to people? If not, you may skip the remaining questions in this section.** +Some documents of this data relate to people, such as news articles, Wikipedia descriptions, etc. + +* **Does the dataset identify any subpopulations (e.g., by age, gender)? If so, please describe how these subpopulations are identified and provide a description of their respective distributions within the dataset.** +No. + +* **Is it possible to identify individuals (i.e., one or more natural persons), either directly or indirectly (i.e., in combination with other data) from the dataset? If so, please describe how** +In addition to individuals who have Wikipedia pages (celebrities, politicians, etc.), it may be possible to identify other individuals by their names, Twitter account names, etc. if that information is present in Common Crawl. + +* **Does the dataset contain data that might be considered sensitive in any way (e.g., data that reveals racial or ethnic origins, sexual orientations, religious beliefs, political opinions or union memberships, or locations; financial or health data; biometric or genetic data; forms of government identification, such as social security numbers; criminal history)? If so, please provide a description.** +The training dataset is partially derived from Common Crawl, which may contain some sensitive information. + +* **Any other comments?** +No + + +## Collection Process + +* **How was the data associated with each instance acquired? Was the data directly observable (e.g., raw text, movie ratings), reported by subjects (e.g., survey responses), or indirectly inferred/ derived from other data (e.g., part-of-speech tags, model-based guesses for age or language)? If data was reported by subjects or indirectly inferred/derived from other data, was the data validated/verified? If so, please describe how.** +N/A. The dataset is a union of six publicly available datasets. + +* **What mechanisms or procedures were used to collect the data (e.g., hardware apparatus or sensor, manual human curation, software program, software API)? How were these mechanisms or procedures validated?** +N/A + +* **If the dataset is a sample from a larger set, what was the sampling strategy (e.g., deterministic, probabilistic with specific sampling probabilities)?** +Please refer to the main document for details. + +* **Who was involved in the data collection process (e.g., students, crowdworkers, contractors) and how were they compensated (e.g., how much were crowdworkers paid)?** +This data is mined, filtered and sampled by machines. + +* **Over what timeframe was the data collected? Does this timeframe match the creation timeframe of the data associated with the instances (e.g., recent crawl of old news articles)? 
If not, please describe the timeframe in which the data associated with the instances was created.** +Different parts of the dataset were mined over different time periods. +1. The CC-News dataset contains English news articles crawled between September 2016 and February 2019. +2. The English CC-100 dataset was extracted from CommonCrawl snapshots between January 2018 and December 2018. + +* **Were any ethical review processes conducted (e.g., by an institutional review board)? If so, please provide a description of these review processes, including the outcomes, as well as a link or other access point to any supporting documentation.** +No. + +* **Does the dataset relate to people? If not, you may skip the remainder of the questions in this section.** +No. + +* **Did you collect the data from the individuals in question directly, or obtain it via third parties or other sources (e.g., websites)?** +N/A + +* **Were the individuals in question notified about the data collection? If so, please describe (or show with screenshots or other information) how notice was provided, and provide a link or other access point to, or otherwise reproduce, the exact language of the notification itself.** +N/A + +* **Did the individuals in question consent to the collection and use of their data? If so, please describe (or show with screenshots or other information) how consent was requested and provided, and provide a link or other access point to, or otherwise reproduce, the exact language to which the individuals consented.** +N/A + +* **If consent was obtained, were the consenting individuals provided with a mechanism to revoke their consent in the future or for certain uses? If so, please provide a description, as well as a link or other access point to the mechanism (if appropriate).** +N/A + +* **Has an analysis of the potential impact of the dataset and its use on data subjects (e.g., a data protection impact analysis) been conducted? If so, please provide a description of this analysis, including the outcomes, as well as a link or other access point to any supporting documentation.** +Some responsible AI related evaluations were performed. Please refer to the main document and the model card for the paper. + +* **Any other comments?** +No + + +## Preprocessing/cleaning/labeling + + +* **Was any preprocessing/cleaning/labeling of the data done (e.g., discretization or bucketing, tokenization, part-of-speech tagging, SIFT feature extraction, removal of instances, processing of missing values)? If so, please provide a description. If not, you may skip the remainder of the questions in this section.** +The component datasets went through standard cleaning and re-formatting practices, including removing repetitive/non informative text like "Chapter One", or "This ebook by Project Gutenberg". + +* **Was the “raw” data saved in addition to the preprocessed/cleaned/labeled data (e.g., to support unanticipated future uses)? If so, please provide a link or other access point to the “raw” data.** +The "raw" component datasets is publicly available in their respective locations (more details can be seen in the respective papers linked in references). + +* **Is the software used to preprocess/clean/label the instances available? If so, please provide a link or other access point.** +The software is proprietary to Meta Platforms and currently unavailable publicly. + +* **Any other comments?** +No + + +## Uses + +* **Has the dataset been used for any tasks already? 
If so, please provide a description.** +Yes, this dataset was used to pre-train the models described in the paper. + +* **Is there a repository that links to any or all papers or systems that use the dataset? If so, please provide a link or other access point.** +No. + +* **What (other) tasks could the dataset be used for?** +This data can be used to pretrain English language models, which are foundation to many current and future language tasks. + +* **Is there anything about the composition of the dataset or the way it was collected and preprocessed/cleaned/labeled that might impact future uses? For example, is there anything that a future user might need to know to avoid uses that could result in unfair treatment of individuals or groups (e.g., stereotyping, quality of service issues) or other undesirable harms (e.g., financial harms, legal risks) If so, please provide a description. Is there anything a future user could do to mitigate these undesirable harms?** +The pipeline for creating this dataset paves a way for building a scalable infrastructure for mining datasets to be be used for training large-scale models. + +* **Are there tasks for which the dataset should not be used? If so, please provide a description.** +No. + +* **Any other comments?** +No. + +## Distribution + + +* **Will the dataset be distributed to third parties outside of the entity (e.g., company, institution, organization) on behalf of which the dataset was created? If so, please provide a description.** +No. + +* **How will the dataset will be distributed (e.g., tarball on website, API, GitHub)? Does the dataset have a digital object identifier (DOI)?** +N/A + +* **When will the dataset be distributed?** +No. + +* **Will the dataset be distributed under a copyright or other intellectual property (IP) license, and/or under applicable terms of use (ToU)? If so, please describe this license and/or ToU, and provide a link or other access point to, or otherwise reproduce, any relevant licensing terms or ToU, as well as any fees associated with these restrictions.** +No. + +* **Have any third parties imposed IP-based or other restrictions on the data associated with the instances? If so, please describe these restrictions, and provide a link or other access point to, or otherwise reproduce, any relevant licensing terms, as well as any fees associated with these restrictions.** +No. + +* **Do any export controls or other regulatory restrictions apply to the dataset or to individual instances? If so, please describe these restrictions, and provide a link or other access point to, or otherwise reproduce, any supporting documentation.** +N/A + +* **Any other comments?** +No. + +## Maintenance + +* **Who is supporting/hosting/maintaining the dataset?** +FAIR (Fundamental Artificial Intelligence Research) + +* **How can the owner/curator/manager of the dataset be contacted (e.g., email address)?** +Refer to the main document. + +* **Is there an erratum? If so, please provide a link or other access point.** +N/A + +* **Will the dataset be updated (e.g., to correct labeling errors, add new instances, delete instances)? If so, please describe how often, by whom, and how updates will be communicated to users (e.g., mailing list, GitHub)?** +No plan for updating. + +* **If the dataset relates to people, are there applicable limits on the retention of the data associated with the instances (e.g., were individuals in question told that their data would be retained for a fixed period of time and then deleted)? 
If so, please describe these limits and explain how they will be enforced.** +N/A + +* **Will older versions of the dataset continue to be supported/hosted/maintained? If so, please describe how. If not, please describe how its obsolescence will be communicated to users.** +N/A + +* **If others want to extend/augment/build on/contribute to the dataset, is there a mechanism for them to do so? If so, please provide a description. Will these contributions be validated/ verified? If so, please describe how. If not, why not? Is there a process for communicating/ distributing these contributions to other users? If so, please provide a description.** +No. + +* **Any other comments?** +No. + +## References +Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. 2019. Roberta: A robustly optimized bert pretraining approach. arXiv preprint arXiv:1907.11692. + +Yukun Zhu, Ryan Kiros, Richard Zemel, Ruslan Salakhutdinov, Raquel Urtasun, Antonio Torralba, and Sanja Fidler. 2019. Aligning books and movies: Towards story-like visual explanations by watching movies and reading books. arXiv:1506.06724. + +Sebastian Nagel. 2016. Cc-news. http: //web.archive.org/save/http: //commoncrawl.org/2016/10/news-dataset-available. + +Aaron Gokaslan and Vanya Cohen. 2019. Openwebtext corpus. http://web.archive.org/save/http://Skylion007.github.io/OpenWebTextCorpus + +Trieu H Trinh and Quoc V Le. 2018. A simple method for commonsense reasoning. arXiv preprint arXiv:1806.02847. + +Guillaume Wenzek, Marie-Anne Lachaux, Alexis Conneau, Vishrav Chaudhary, Francisco Guzmán, Armand Joulin, and Edouard Grave. 2020. CCNet: Extracting high quality monolingual datasets from web crawl data. In Proceedings of the 12th Language Resources and Evaluation Conference, pages 4003–4012, Marseille, France. European Language Resources Association. + diff --git a/fairseq/examples/moe_lm/model_card.md b/fairseq/examples/moe_lm/model_card.md new file mode 100644 index 0000000000000000000000000000000000000000..a1cd68116aeff8d401ddd96461b7159151840ac9 --- /dev/null +++ b/fairseq/examples/moe_lm/model_card.md @@ -0,0 +1,170 @@ +# Model card for the paper ``Efficient Large Scale Language Modeling with Mixtures of Experts" +## Version 1.0.0 + +### Model developer +FAIR (Fundamental Artificial Intelligence Research) + +### Model type +An autoregressive English language model trained on a union of six English language models. We explore dense and sparse (MoE based) architectures in the paper. +* Dense models - Our dense models range from 125M parameters to 13B parameters. +* Sparse (MoE) models - Our MoE based models range from 15B parameters to 1.1 Trillion parameters. +This model card focuses on the 1.1 Trillion parameter model, but the discussion +applies to all of the models explored in this work. + +### Citation details +Artetxe et al. (2021): Efficient Large Scale Language Modeling with Mixtures of Experts + +### Model Feedback Channel +fairseq + +## Intended use +### Primary intended use +For research purposes only, e.g. reproducing model evaluation results. Generation is only used in a limited capacity for explanation/justification or for prompting/probing/priming for class labels. + +### Out of scope uses +The primary purpose of the model is not to generate language, although the model is capable of doing that. + +## Factors influencing model performance +This section discusses potential risks associated with using the model. 
+ +### Relevant factors +Based on known problems with NLP technology, potential relevant factors include bias (gender, profession, race and religion). + +### Evaluation factors +The 1.1T model was evaluated on StereoSet and CrowS-Pairs datasets to quantify encoded bias in the model. + +## Metrics +### Model performance measures +The 1.1T parameter model was primarily evaluated on +1. In-domain and out-of-domain language modeling perplexity. +2. Zero-shot and few-shot priming. +3. Fully supervised finetuning. + +### Approaches to handle uncertainty +For few-shot learning, we report the average results across 25 runs, randomly sampling a different set of few-shot examples from the training set each time. + +## Evaluation data +## Zero Shot evaluation + +### HellaSwag +#### Description +HellaSwag is a dataset for evaluating commonsense reasoning. + +### PIQA +#### Description +PIQA is a dataset designed to evaluate reasoning about Physical Commonsense in Natural Language + +### ReCoRd +#### Description +Reading Comprehension with Commonsense Reasoning Dataset (ReCoRD) is a large-scale reading comprehension dataset which requires commonsense reasoning. ReCoRD consists of queries automatically generated from CNN/Daily Mail news articles; the answer to each query is a text span from a summarizing passage of the corresponding news. The goal of ReCoRD is to evaluate a machine's ability of commonsense reasoning in reading comprehension. + +## Few Shot evaluation +### Winogrande +#### Description +Winogrande is a benchmark for commonsense reasoning. The dataset contains pronoun resolution problems originally designed to be unsolvable for statistical models that rely on selectional preferences or word associations. + +### StoryCloze +#### Description +StoryCloze is a new commonsense reasoning framework for evaluating story understanding, story generation, and script learning. This test requires a system to choose the correct ending to a four-sentence story. + +### OpenBookQA +#### Description +OpenBookQA is a new kind of question-answering dataset modeled after open book exams for assessing human understanding of a subject. It consists of 5,957 multiple-choice elementary-level science questions (4,957 train, 500 dev, 500 test), which probe the understanding of a small “book” of 1,326 core science facts and the application of these facts to novel situations. + +## Fully supervised evaluation + +### BoolQ +#### Description +BoolQ is a question answering dataset for yes/no questions containing 15942 examples. These questions are naturally occurring – they are generated in unprompted and unconstrained settings. Each example is a triplet of (question, passage, answer), with the title of the page as optional additional context. + +### SST-2 +#### Description +SST-2 (or SST-binary) is a binary classification dataset where the goal is to differentiate between negative or somewhat negative vs somewhat positive or positive. + +### MNLI +#### Description +The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. 
+ +## Responsible AI (RAI) evaluation +### StereoSet +#### Description +A large-scale natural dataset in English to measure stereotypical biases in four domains: gender, profession, race, and religion + +#### Motivation for dataset use +The motivation for evaluating the 1.1T parameter model on this dataset is to evaluate the model's stereotype bias in gender, profession, race, and religion + +### CrowS +#### Description +Challenge Dataset for Measuring Social Biases in Masked Language Models + +#### Motivation for dataset use +The motivation for evaluating the 1.1T parameter model on this dataset is to evaluate the model’s bias in the domains of race, religion and age + +---- + +## Training data +### BookCorpus +#### Description +A dataset consisting of more than 10K unpublished books. 4GB in size. (Zhu et al., 2019) + +### English Wikipedia +#### Description +Data from English wikipedia, excluding lists, tables and headers. 12GB in size. + +### CC-News +#### Description +A dataset containing 63 millions English news articles crawled between September 2016 and February 2019. 76GB in size. (Nagel,2016) + +### OpenWebText +#### Description +An open source recreation of the WebText dataset used to train GPT-2. 38GB in size. (Gokaslan and Cohen, 2019) + +### CC-Stories +#### Description +A dataset containing a subset of CommonCrawl data filtered to match the story-like style of Winograd schemas. 31GB in size. (Trinh and Le, 2018) + +### English CC100 +#### Description +A dataset extracted from CommonCrawl snapshots between January 2018 and December 2018, filtered to match the style of Wikipedia following the methodology introduced in CCNet (https://arxiv.org/abs/1911.00359). 292GB in size. (Wenzek et al., 2020) + +## Responsible AI (RAI) Dimensions +### Fairness (Bias and inclusion) +The 1.1T parameter model was evaluated on the StereoSet and CrowS pairs dataset for inherent bias in the model, and bias as a result of the data. Similar to StereoSet, we observe that both the dense and MoE models get worse in terms of the Stereotype Score (SS) with scale. + +### Privacy and security +The 1.1T model did not have any special Privacy and Security considerations. The training data and evaluation data were both public and went through standard Meta privacy and licensing procedures. + +### Transparency and control +In the spirit of transparency and accountability we have created this model card for the 1.1T parameter model and a data card for the training data (referenced in Artetxe et al. (2021)). + +### Efficiency (Green AI) +The 1.1T parameter model is trained as a Mixture of Experts (MoE) model. Mixture of expert (MoE) models are efficient because they leverage sparse computation, i.e., only a small fraction of parameters are active for any given input. For instance, our 1.1T parameter MoE model requires only 30% more FLOPS compared to a 6.7B parameter dense model, i.e., a 160x increase in parameters with only a 30% increase in FLOPS. Notably, MoE models achieve much better validation perplexity for a given compute budget compared to dense models. + +## References +Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. 2019. HellaSwag: Can a machine really finish your sentence? In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4791– 4800, Florence, Italy. Association for Computational Linguistics. + +Yonatan Bisk, Rowan Zellers, Ronan Le bras, Jianfeng Gao, and Yejin Choi. 2020. 
PIQA: Reasoning about physical commonsense in natural language. Proceedings of the AAAI Conference on Artificial Intelligence, 34(05):7432–7439. + +Sheng Zhang, Xiaodong Liu, Jingjing Liu, Jianfeng Gao, Kevin Duh, and Benjamin Van Durme. 2018. ReCoRD: Bridging the gap between human and machine commonsense reading comprehension. arXiv preprint 1810.12885. + +Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. 2020. Winogrande: An adversarial Winograd schema challenge at scale. Proceedings of the AAAI Conference on Artificial Intelligence, 34(05):8732–8740. + +Nasrin Mostafazadeh, Nathanael Chambers, Xiaodong He, Devi Parikh, Dhruv Batra, Lucy Vanderwende, Pushmeet Kohli, and James Allen. 2016. A corpus and cloze evaluation for deeper understanding of commonsense stories. In Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 839–849, San Diego, California. Association for Computational Linguistics. + +Todor Mihaylov, Peter Clark, Tushar Khot, and Ashish Sabharwal. 2018. Can a suit of armor conduct electricity? A new dataset for open book question answering. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pages 2381–2391, Brussels, Belgium. Association for Computational Linguistics. + +Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. 2019. BoolQ: Exploring the surprising difficulty of natural yes/no questions. + +Moin Nadeem, Anna Bethke, and Siva Reddy. 2021. StereoSet: Measuring stereotypical bias in pretrained language models. In Association for Computational Linguistics (ACL). + +Nikita Nangia, Clara Vania, Rasika Bhalerao, and Samuel R. Bowman. 2020. CrowS-pairs: A challenge dataset for measuring social biases in masked language models. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 1953–1967, Online. Association for Computational Linguistics. + +Yukun Zhu, Ryan Kiros, Richard Zemel, Ruslan Salakhutdinov, Raquel Urtasun, Antonio Torralba, and Sanja Fidler. 2019. Aligning books and movies: Towards story-like visual explanations by watching movies and reading books. arXiv:1506.06724. + +Sebastian Nagel. 2016. CC-News. http://web.archive.org/save/http://commoncrawl.org/2016/10/news-dataset-available. + +Aaron Gokaslan and Vanya Cohen. 2019. OpenWebText corpus. http://web.archive.org/save/http://Skylion007.github.io/OpenWebTextCorpus + +Trieu H Trinh and Quoc V Le. 2018. A simple method for commonsense reasoning. arXiv preprint arXiv:1806.02847. + +Guillaume Wenzek, Marie-Anne Lachaux, Alexis Conneau, Vishrav Chaudhary, Francisco Guzmán, Armand Joulin, and Edouard Grave. 2020. CCNet: Extracting high quality monolingual datasets from web crawl data. In Proceedings of the 12th Language Resources and Evaluation Conference, pages 4003–4012, Marseille, France. European Language Resources Association.
diff --git a/fairseq/examples/mr_hubert/README.md b/fairseq/examples/mr_hubert/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e72c09c047958363dedae8fa49106f2209b098ba --- /dev/null +++ b/fairseq/examples/mr_hubert/README.md @@ -0,0 +1,187 @@ +# MR-HuBERT + +## Pre-trained models + +### Main models +Model | Pretraining Data | Model | Paper Reference +|---|---|---|--- +MR-HuBERT Base (~97M) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_base/mrhubert_mono_base.pt) | mono\_base +MR-HuBERT Base (~321M) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_large/mrhubert_mono_large.pt) | mono\_large +Multilingual MR-HuBERT Base (~97M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_base/multi_base.pt) | multi\_base +Multilingual MR-HuBERT Large (~321M) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download 400k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_400k.pt) or [download 600k steps](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large_600k.pt) | Not in the paper + + +### Abalation models +Model | Pretraining Data | Model | Paper Reference +|---|---|---|--- +MR-HuBERT Base (2-4-6 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-a/b1-a.pt) | (B.1)-a +MR-HuBERT Base (5-2-5 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-b/b1-b.pt) | (B.1)-b +MR-HuBERT Base (6-4-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-c/b1-c.pt) | (B.1)-c +MR-HuBERT Base (3res 3-2-2-2-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-a/b2-a.pt) | (B.2)-a +MR-HuBERT Base (3res 2-2-4-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-b/b2-b.pt) | (B.2)-b +MR-HuBERT Base (3res 2-2-2-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-c/b2-c.pt) | (B.2)-c +MR-HuBERT Base (Simple sampling) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b3-a/b3-a.pt) | (B.3)-a +MR-HuBERT Base (Single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-a/b4-a.pt) | (B.4)-a +MR-HuBERT Base (Simple Sampling + single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-b/b4-b.pt) | (B.4)-b +MR-HuBERT Base (Mono-resolution 20ms) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b5-a/b5-a.pt) | (B.5)-a +MR-HuBERT Base (3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-a/b6-a.pt) | (B.6)-a +MR-HuBERT Base (Mono-resolution 20ms, 3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-b/b6-b.pt) | (B.6)-b +MR-HuBERT Base (HuBERT 20ms&40ms units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-a/b7-a.pt) | (B.7)-a +MR-HuBERT Base (Encodec 50Hz unit) | 
[Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-b/b7-b.pt) | (B.7)-b +MR-HuBERT Base (Encodec 50Hz units and 25Hz units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-c/b7-c.pt) | (B.7)-c +MR-HuBERT Base (Encodec 50Hz units stream 0&1 ) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-d/b7-d.pt) | (B.7)-d +MR-HuBERT Large (no audio norm) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-a.pt) | (B.8)-a +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-b/b8-b.pt) | (B.8)-b +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-c/b8-c.pt) | (B.8)-c +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-d/b8-d.pt) | (B.8)-d +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-e/b8-e.pt) | (B.8)-e +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-f/b8-f.pt) | (B.8)-f +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-g/b8-g.pt) | (B.8)-g +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-h/b8-h.pt) | (B.8)-h +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-i/b8-i.pt) | (B.8)-i +MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-j/b8-j.pt) | (B.8)-j +Multilingual MR-HuBERT Large (Simple sampling) | [Voxpopuli](https://github.com/facebookresearch/voxpopuli) 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_large_simple/multi_large_simple.pt) | Not in paper +MR-HuBERT xLarge (from HuBERT-base label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v1.pt) | Not in paper +MR-HuBERT xLarge (from HuBERT-large label) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_xlarge/v2.pt) | Not in paper + +## Load a model +``` +ckpt_path = "/path/to/the/checkpoint.pt" +models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) +model = models[0] +``` + +## Train a new model + +### Data preparation + +Follow the steps in `./simple_kmeans` to create: +- `{train,valid}.tsv` waveform list files with length information +``` +/path/to/your/audio/files +file1.wav\t160000 +file2.wav\t154600 +... +filen.wav\t54362 +``` +- `{train,valid}.km` frame-aligned pseudo label files (the order is the same as wavefiles in the tsv file). 
+``` +44 44 44 48 48 962 962 962 962 962 962 962 962 967 967 967 967 967 967 967 967 370 852 370 ... 18 18 745 745 +44 44 44 48 48 962 962 962 147 147 147 147 147 147 147 147 147 147 147 147 176 176 271 271 ... 27 27 745 745 +... +44 44 44 48 962 962 962 962 962 962 377 377 377 77 77 852 696 694 433 578 578 82 740 622 ... 27 27 745 745 +``` +- `dict.km.txt` a dummy dictionary (first column is id, the second is dummy one) +``` +0 1 +1 1 +2 1 +... +999 1 +``` + +The `label_rate` is the same as the feature frame rate used for clustering, +which is 100Hz for MFCC features and 50Hz for HuBERT features by default. + +### Pre-train a MR-HuBERT model + +Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km` +are saved at `/path/to/labels`, and the label rate is 100Hz. + +To train a base model (12 layer transformer), run: +```sh +$ python fairseq_cli/hydra_train.py \ + --config-dir /path/to/fairseq-py/examples/mr_hubert/config/pretrain \ + --config-name mrhubert_base_librispeech \ + task.data=/path/to/data task.label_dir=/path/to/labels \ + task.labels='["km"]' model.label_rate=100 \ + task.label_rate_ratios='[1, 2]' \ +``` + +Please see sample pre-training scripts `train.sh` for an example script. + +### Fine-tune a MR-HuBERT model with a CTC loss + +Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their +corresponding character transcripts `{train,valid}.ltr` are saved at +`/path/to/trans`. A typical ltr file is with the same order of tsv waveform files as +``` +HOW | ARE | YOU +... +THANK | YOU +``` + +To fine-tune a pre-trained MR-HuBERT model at `/path/to/checkpoint`, run +```sh +$ python fairseq_cli/hydra_train.py \ + --config-dir /path/to/fairseq-py/examples/mr_hubert/config/finetune \ + --config-name base_10h \ + task.data=/path/to/data task.label_dir=/path/to/trans \ + model.w2v_path=/path/to/checkpoint +``` + +Please see sample fine-tuning scripts `finetune.sh` for an example script. + +### Decode a MR-HuBERT model + +Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of +the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is +saved at `/path/to/checkpoint`. + + +We support three decoding modes: +- Viterbi decoding: greedy decoding without a language model +- KenLM decoding: decoding with an arpa-format KenLM n-gram language model +- Fairseq-LM deocding: decoding with a Fairseq neural language model (not fully tested) + + +#### Viterbi decoding + +`task.normalize` needs to be consistent with the value used during fine-tuning. +Decoding results will be saved at +`/path/to/experiment/directory/decode/viterbi/test`. + +```sh +$ python examples/speech_recognition/new/infer.py \ + --config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \ + --config-name infer \ + task.data=/path/to/data \ + task.normalize=[true|false] \ + decoding.exp_dir=/path/to/experiment/directory \ + common_eval.path=/path/to/checkpoint + dataset.gen_subset=test \ +``` + +#### KenLM / Fairseq-LM decoding + +Suppose the pronunciation lexicon and the n-gram LM are saved at +`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be +saved at `/path/to/experiment/directory/decode/kenlm/test`. 
+ +```sh +$ python examples/speech_recognition/new/infer.py \ + --config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \ + --config-name infer_lm \ + task.data=/path/to/data \ + task.normalize=[true|false] \ + decoding.exp_dir=/path/to/experiment/directory \ + common_eval.path=/path/to/checkpoint + dataset.gen_subset=test \ + decoding.decoder.lexicon=/path/to/lexicon \ + decoding.decoder.lmpath=/path/to/arpa +``` + +The command above uses the default decoding hyperparameter, which can be found +in `examples/speech_recognition/hydra/decoder.py`. These parameters can be +configured from the command line. For example, to search with a beam size of +500, we can append the command above with `decoding.decoder.beam=500`. +Important parameters include: +- decoding.decoder.beam +- decoding.decoder.beamthreshold +- decoding.decoder.lmweight +- decoding.decoder.wordscore +- decoding.decoder.silweight + +To decode with a Fairseq LM, you may check the usage examples in wav2vec2 or hubert examples. + +Please see sample decoding scripts `decode.sh` for an example script. diff --git a/fairseq/examples/mr_hubert/config/decode/infer.yaml b/fairseq/examples/mr_hubert/config/decode/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eff39802e7a398c804e4c7eacdb2eee6ab33db81 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/decode/infer.yaml @@ -0,0 +1,30 @@ +# @package _group_ + +defaults: + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/viterbi + sweep: + dir: ${common_eval.results_path} + subdir: viterbi + +task: + _name: multires_hubert_pretraining + single_target: true + fine_tuning: true + label_rate_ratios: ??? + data: ??? + normalize: false + +decoding: + type: viterbi + unique_wer_file: true +common_eval: + results_path: ??? + path: ??? + post_process: letter +dataset: + max_tokens: 1100000 + gen_subset: ??? diff --git a/fairseq/examples/mr_hubert/config/decode/infer_lm.yaml b/fairseq/examples/mr_hubert/config/decode/infer_lm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..535b95077560dd1c64a9fa2d6ae43641068afc59 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/decode/infer_lm.yaml @@ -0,0 +1,37 @@ +# @package _group_ + +defaults: + - model: null + +hydra: + run: + dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} + sweep: + dir: ${common_eval.results_path} + subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} + +task: + _name: multires_hubert_pretraining + single_target: true + fine_tuning: true + data: ??? + label_rate_ratios: ??? + normalize: ??? + +decoding: + type: kenlm + lexicon: ??? + lmpath: ??? + beamthreshold: 100 + beam: 500 + lmweight: 1.5 + wordscore: -1 + silweight: 0 + unique_wer_file: true +common_eval: + results_path: ??? + path: ??? + post_process: letter +dataset: + max_tokens: 1100000 + gen_subset: ??? 
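+# Note: the decoding defaults above (beam, beamthreshold, lmweight, wordscore,
+# silweight) can be overridden from the command line, as described in the README.
+# Illustrative example only -- append, e.g.,
+#   decoding.beam=500 decoding.lmweight=2.0 decoding.wordscore=-0.5
+# to the KenLM decoding command. The README refers to these knobs as
+# decoding.decoder.*, while this config defines them under decoding.*; use the
+# key path that matches your fairseq version.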
diff --git a/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml b/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b8065832ecacf9dd4fe4e99c87941e00fb3ef7f --- /dev/null +++ b/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 1 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml b/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f669f376312dbfe4611cc08f4996a314155fb87 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml @@ -0,0 +1,17 @@ +# @package _global_ +hydra: + launcher: + cpus_per_task: ${distributed_training.distributed_world_size} + gpus_per_node: ${distributed_training.distributed_world_size} + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 1 + mem_gb: 200 + timeout_min: 4320 + max_num_timeout: 50 + name: ${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir}/submitit + +distributed_training: + distributed_world_size: 8 + distributed_no_spawn: true + distributed_port: 29761 diff --git a/fairseq/examples/mr_hubert/config/finetune/base_100h.yaml b/fairseq/examples/mr_hubert/config/finetune/base_100h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c52a118cb819d8504eb44501fed90a3e8f68bc9f --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_100h.yaml @@ -0,0 +1,97 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_100h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [3e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? 
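+  # The settings below apply SpecAugment-style masking during CTC fine-tuning:
+  # mask_prob/mask_length control time masking and the mask_channel_* options
+  # control feature-channel masking (descriptive gloss; see the model
+  # implementation for the exact semantics).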
+ apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml b/fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d0c0da3dbc336495640f7c1fc35c9b72814021f --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml @@ -0,0 +1,97 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 1600000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_100h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [3e-5] + sentence_avg: true + update_freq: [2] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? 
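+    # run.dir and sweep.dir are required (???); supply them on the command line,
+    # e.g. hydra.run.dir=/path/to/exp_dir (the pre-training script train.sh sets
+    # hydra.sweep.dir=${exp_dir} in the same way).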
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/finetune/base_10h.yaml b/fairseq/examples/mr_hubert/config/finetune/base_10h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25123e44816d13fd5d53e77bcd6e9e6fa0369030 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_10h.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 5 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [2e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml b/fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..65448c7722c7cc059fcb7e153c86d1583fff10bd --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 5 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? 
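+  # label_rate_ratios should match the ratios used when pre-training the loaded
+  # checkpoint, e.g. '[1, 2]' as in the pre-training example in the README.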
+ normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [2e-5] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/finetune/base_1h.yaml b/fairseq/examples/mr_hubert/config/finetune/base_1h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7459c3fc4c4ccdd96ae6648d1f95d7e2bb59ab1f --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_1h.yaml @@ -0,0 +1,100 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 50 + keep_interval_updates: 1 + save_interval_updates: 1000 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: false # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 3200000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 1000 + train_subset: train_1h + valid_subset: dev_other + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [5e-5] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? 
+ sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml b/fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34ef4dc19de87b20a53af49b527b5fc0002625d2 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml @@ -0,0 +1,99 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tblog + seed: 1337 + +checkpoint: + save_interval: 1000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +distributed_training: + ddp_backend: c10d + find_unused_parameters: true + distributed_world_size: 8 + distributed_port: 29671 + nprocs_per_node: 8 + +task: + _name: multires_hubert_pretraining + data: ??? + fine_tuning: true + label_dir: ??? + label_rate_ratios: ??? + normalize: true # must be consistent with pre-training + labels: ["ltr"] + single_target: true + +dataset: + num_workers: 0 + max_tokens: 1280000 + validate_after_updates: ${model.freeze_finetune_updates} + validate_interval: 5 + train_subset: train_10h + valid_subset: dev + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 25000 + lr: [3e-4] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: multires_hubert_ctc + multires_hubert_path: ??? + apply_mask: true + mask_selection: static + mask_length: 10 + mask_other: 0 + mask_prob: 0.75 + mask_channel_selection: static + mask_channel_length: 64 + mask_channel_other: 0 + mask_channel_prob: 0.5 + layerdrop: 0.1 + dropout: 0.0 + activation_dropout: 0.1 + attention_dropout: 0.0 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + - model.multires_hubert_path + - dataset.train_subset + - dataset.valid_subset + - criterion.wer_kenlm_model + - criterion.wer_lexicon + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml b/fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16a35d340a6994210da6d73280369b4635cb9c61 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml @@ -0,0 +1,103 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + min_loss_scale: 1e-8 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: multires_hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + label_rate_ratios: ??? 
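+  # Audio settings below: 16 kHz input; training examples are randomly cropped
+  # to at most 250000 samples (~15.6 s) and utterances shorter than 32000
+  # samples (~2 s) are skipped (approximate gloss of max/min_sample_size).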
+ sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + # max_keep_size: 300000 + # max_keep_size: 50000 + + +dataset: + num_workers: 0 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: multires_hubert + label_rate: ??? + label_rate_ratios: ${task.label_rate_ratios} + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: default + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + final_dim: 256 + encoder_layers: 4 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + conv_adapator_kernal: 1 + use_single_target: true + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '/' + exclude_keys: + - run + - task.data + - task.label_dir + - common.min_loss_scale + - common.log_interval + - optimization.clip_norm diff --git a/fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml b/fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml new file mode 100644 index 0000000000000000000000000000000000000000..423f3b25c22bfa5a18216ddd317546eedb785544 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml @@ -0,0 +1,107 @@ +# @package _group_ + +common: + memory_efficient_fp16: true + log_format: json + log_interval: 200 + seed: 1337 + tensorboard_logdir: tblog + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 128 + distributed_port: 29671 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: multires_hubert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + label_rate_ratios: ??? + sample_rate: 16000 + max_sample_size: 250000 + min_sample_size: 32000 + pad_audio: false + random_crop: true + normalize: true # must be consistent with extractor + # max_keep_size: 50000 + +dataset: + num_workers: 0 + max_tokens: 300000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 5 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10,] + +optimization: + max_update: 400000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [3] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: multires_hubert + label_rate: ??? 
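+  # label_rate must equal the frame rate of the clustering labels
+  # (100 Hz for MFCC-based labels, 50 Hz for HuBERT-feature labels by default,
+  # per the data preparation notes in the README).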
+ label_rate_ratios: ${task.label_rate_ratios} + encoder_layers: 8 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + final_dim: 768 + skip_masked: false + skip_nomask: false + mask_prob: 0.80 + extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + layer_norm_first: true + feature_grad_mult: 1.0 + untie_final_proj: true + activation_dropout: 0.0 + conv_adapator_kernal: 1 + use_single_target: true + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + run: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + sweep: + dir: /checkpoint/wnhsu/w2v/hubert_final/hydra_pt + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml b/fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46c979cd2835fe026b0a532a54533904d1001e54 --- /dev/null +++ b/fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +hydra: + launcher: + cpus_per_task: 8 + gpus_per_node: 8 + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 4 + comment: null + mem_gb: 384 + timeout_min: 4320 + max_num_timeout: 100 + constraint: volta32gb + name: ${hydra.job.config_name}/${hydra.job.override_dirname} + submitit_folder: ${hydra.sweep.dir}/submitit/%j + +distributed_training: + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 diff --git a/fairseq/examples/mr_hubert/train.sh b/fairseq/examples/mr_hubert/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..da561eb171416f3c321f5dac681cdc9603bebe83 --- /dev/null +++ b/fairseq/examples/mr_hubert/train.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +FAIRSEQ= # Setup your fairseq directory + +config_dir=${FAIRSEQ}/examples/mr_hubert/config +config_name=mr_hubert_base_librispeech + +# Prepared Data Directory +data_dir=librispeech +# -- data_dir +# -- train.tsv +# -- valid.tsv + +label_dir=labels +# -- label_dir +# -- train.km +# -- valid.km +# -- dict.km.txt + + +exp_dir=exp # Target experiments directory +ratios="[1, 2]" # Default label rate ratios +label_rate=50 # Base label rate + + +_opts= + +# If use slurm, uncomment this line and modify the job submission at +# _opts="${_opts} hydra/launcher=submitit_slurm +hydra.launcher.partition=${your_slurm_partition} +run=submitit_reg" + +# If want to set additional experiment tag, uncomment this line +# _opts="${_opts} hydra.sweep.subdir=${your_experiment_tag}" + + +python ${FAIRSEQ}/fairseq_cli/hydra_train.py \ + -m --config-dir ${config_dir} --config-name ${config_name} ${_opts} \ + task.data=${data_dir} \ + task.label_dir=${label_dir} \ + task.labels='["km"]' \ + model.label_rate=${label_rate} \ + task.label_rate_ratios='${ratios}' \ + hydra.sweep.dir=${exp_dir} & + + + diff --git a/fairseq/examples/multilingual/ML50_langs.txt b/fairseq/examples/multilingual/ML50_langs.txt new file mode 100644 index 0000000000000000000000000000000000000000..558abbc785072629de8000e343fc02a32c0afb97 --- /dev/null +++ b/fairseq/examples/multilingual/ML50_langs.txt @@ -0,0 +1,52 @@ +ar_AR +cs_CZ +de_DE +en_XX +es_XX +et_EE +fi_FI +fr_XX +gu_IN +hi_IN +it_IT +ja_XX +kk_KZ +ko_KR +lt_LT +lv_LV +my_MM +ne_NP +nl_XX +ro_RO +ru_RU +si_LK +tr_TR +vi_VN +zh_CN 
+af_ZA +az_AZ +bn_IN +fa_IR +he_IL +hr_HR +id_ID +ka_GE +km_KH +mk_MK +ml_IN +mn_MN +mr_IN +pl_PL +ps_AF +pt_XX +sv_SE +sw_KE +ta_IN +te_IN +th_TH +tl_XX +uk_UA +ur_PK +xh_ZA +gl_ES +sl_SI \ No newline at end of file diff --git a/fairseq/examples/multilingual/README.md b/fairseq/examples/multilingual/README.md new file mode 100644 index 0000000000000000000000000000000000000000..46ff9c351b1030e0729f89f246e0cd86444c1633 --- /dev/null +++ b/fairseq/examples/multilingual/README.md @@ -0,0 +1,158 @@ +# Multilingual Translation + +[[Multilingual Translation with Extensible Multilingual Pretraining and Finetuning, https://arxiv.org/abs/2008.00401]](https://arxiv.org/abs/2008.00401) + +## Introduction + +This work is for training multilingual translation models with multiple bitext datasets. This multilingual translation framework supports (see [[training section]](#Training) and [[finetuning section]](#Finetuning) for examples) + +* temperature based sampling over unbalancing datasets of different translation directions + - --sampling-method' with + choices=['uniform', 'temperature', 'concat'] + - --sampling-temperature +* configurable to automatically add source and/or target language tokens to source/target sentences using data which are prepared in the same way as bilignual training + - --encoder-langtok with choices=['src', 'tgt', None] to specify whether to add source or target language tokens to the source sentences + - --decoder-langtok (binary option) to specify whether to add target language tokens to the target sentences or not +* finetuning mBART pretrained models for multilingual translation + - --finetune-from-model to specify the path from which to load the pretrained model + +## Preprocessing data +Multilingual training requires a joint BPE vocab. Please follow [mBART's preprocessing steps](https://github.com/pytorch/fairseq/tree/main/examples/mbart#bpe-data) to reuse our pretrained sentence-piece model. + +You can also train a joint BPE model on your own dataset and then follow the steps in [[link]](https://github.com/pytorch/fairseq/tree/main/examples/translation#multilingual-translation). + +## Training + + +```bash +lang_pairs= +path_2_data= +lang_list= + +fairseq-train $path_2_data \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 +``` + +## Finetuning +We can also finetune multilingual models from a monolingual pretrained models, e.g. [mMBART](https://github.com/pytorch/fairseq/tree/main/examples/mbart). 
+```bash +lang_pairs= +path_2_data= +lang_list= +pretrained_model= + +fairseq-train $path_2_data \ + --finetune-from-model $pretrained_model \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 +``` +## Generate +The following command uses the multilingual task (translation_multi_simple_epoch) to generate translation from $source_lang to $target_lang on the test dataset. During generaton, the source language tokens are added to source sentences and the target language tokens are added as the starting token to decode target sentences. Options --lang-dict and --lang-pairs are needed to tell the generation process the ordered list of languages and translation directions that the trained model are awared of; they will need to be consistent with the training. + +```bash +model= +source_lang= +target_lang= + +fairseq-generate $path_2_data \ + --path $model \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang $source_lang \ + --target-lang $target_lang + --sacrebleu --remove-bpe 'sentencepiece'\ + --batch-size 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" > ${source_lang}_${target_lang}.txt +``` +Fairseq will generate translation into a file {source_lang}_${target_lang}.txt with sacreblue at the end. + +You can also use costomized tokenizer to compare the performance with the literature. For example, you get a tokenizer [here](https://github.com/rsennrich/wmt16-scripts) and do the following: +```bash +TOKENIZER= +TOK_CMD=<"$TOKENIZER $target_lang" or cat for sacrebleu> + +cat {source_lang}_${target_lang}.txt | grep -P "^H" |sort -V |cut -f 3- |$TOK_CMD > ${source_lang}_${target_lang}.hyp +cat {source_lang}_${target_lang}.txt | grep -P "^T" |sort -V |cut -f 2- |$TOK_CMD > ${source_lang}_${target_lang}.ref +sacrebleu -tok 'none' -s 'none' ${source_lang}_${target_lang}.ref < ${source_lang}_${target_lang}.hyp +``` + +# mBART50 models + +* [mMBART 50 pretrained model](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.pretrained.tar.gz). +* [mMBART 50 finetuned many-to-one](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.n1.tar.gz). +* [mMBART 50 finetuned one-to-many](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.1n.tar.gz). +* [mMBART 50 finetuned many-to-many](https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.nn.tar.gz). + +Please download and extract from the above tarballs. 
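+For example, to fetch and unpack the many-to-many checkpoint listed above (illustrative commands; the other tarballs are handled the same way):
+```bash
+wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/mbart50.ft.nn.tar.gz
+tar -xzvf mbart50.ft.nn.tar.gz
+```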
Each tarball contains +* The fairseq model checkpoint: model.pt +* The list of supported languages: ML50_langs.txt +* Sentence piece model: sentence.bpe.model +* Fairseq dictionary of each language: dict.{lang}.txt (please replace lang with a language specified in ML50_langs.txt) + +To use the trained models, +* use the tool [binarize.py](./data_scripts/binarize.py) to binarize your data using sentence.bpe.model and dict.{lang}.txt, and copy the dictionaries to your data path +* then run the generation command: +```bash +path_2_data= +model=/model.pt +lang_list=/ML50_langs.txt +source_lang= +target_lang= + +fairseq-generate $path_2_data \ + --path $model \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang $source_lang \ + --target-lang $target_lang + --sacrebleu --remove-bpe 'sentencepiece'\ + --batch-size 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" +``` + +## Citation + +```bibtex +@article{tang2020multilingual, + title={Multilingual Translation with Extensible Multilingual Pretraining and Finetuning}, + author={Yuqing Tang and Chau Tran and Xian Li and Peng-Jen Chen and Naman Goyal and Vishrav Chaudhary and Jiatao Gu and Angela Fan}, + year={2020}, + eprint={2008.00401}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/fairseq/examples/multilingual/data_scripts/README.md b/fairseq/examples/multilingual/data_scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cc610c0c9e936a5ae4659ceda691c6db6d387296 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/README.md @@ -0,0 +1,24 @@ + +# Install dependency +```bash +pip install -r requirement.txt +``` + +# Download the data set +```bash +export WORKDIR_ROOT= + +``` +The downloaded data will be at $WORKDIR_ROOT/ML50 + +# preprocess the data +Install SPM [here](https://github.com/google/sentencepiece) +```bash +export WORKDIR_ROOT= +export SPM_PATH= +``` +* $WORKDIR_ROOT/ML50/raw: extracted raw data +* $WORKDIR_ROOT/ML50/dedup: dedup data +* $WORKDIR_ROOT/ML50/clean: data with valid and test sentences removed from the dedup data + + diff --git a/fairseq/examples/multilingual/data_scripts/binarize.py b/fairseq/examples/multilingual/data_scripts/binarize.py new file mode 100644 index 0000000000000000000000000000000000000000..ee54c6aabf021ca526743f8f1f67b91889e1e335 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/binarize.py @@ -0,0 +1,200 @@ +import shutil +import os, sys +from subprocess import check_call, check_output +import glob +import argparse +import shutil +import pathlib +import itertools + +def call_output(cmd): + print(f"Executing: {cmd}") + ret = check_output(cmd, shell=True) + print(ret) + return ret + +def call(cmd): + print(cmd) + check_call(cmd, shell=True) + + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + +SPM_PATH = os.environ.get('SPM_PATH', None) + +if SPM_PATH is None or not SPM_PATH.strip(): + print("Please install sentence piecence from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. 
Exitting...") + sys.exit(-1) + + +SPM_MODEL = f'{WORKDIR_ROOT}/sentence.bpe.model' +SPM_VOCAB = f'{WORKDIR_ROOT}/dict_250k.txt' + +SPM_ENCODE = f'{SPM_PATH}' + +if not os.path.exists(SPM_MODEL): + call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/sentence.bpe.model -O {SPM_MODEL}") + + +if not os.path.exists(SPM_VOCAB): + call(f"wget https://dl.fbaipublicfiles.com/fairseq/models/mbart50/dict_250k.txt -O {SPM_VOCAB}") + + + +def get_data_size(raw): + cmd = f'wc -l {raw}' + ret = call_output(cmd) + return int(ret.split()[0]) + +def encode_spm(model, direction, prefix='', splits=['train', 'test', 'valid'], pairs_per_shard=None): + src, tgt = direction.split('-') + + for split in splits: + src_raw, tgt_raw = f'{RAW_DIR}/{split}{prefix}.{direction}.{src}', f'{RAW_DIR}/{split}{prefix}.{direction}.{tgt}' + if os.path.exists(src_raw) and os.path.exists(tgt_raw): + cmd = f"""python {SPM_ENCODE} \ + --model {model}\ + --output_format=piece \ + --inputs {src_raw} {tgt_raw} \ + --outputs {BPE_DIR}/{direction}{prefix}/{split}.bpe.{src} {BPE_DIR}/{direction}{prefix}/{split}.bpe.{tgt} """ + print(cmd) + call(cmd) + + +def binarize_( + bpe_dir, + databin_dir, + direction, spm_vocab=SPM_VOCAB, + splits=['train', 'test', 'valid'], +): + src, tgt = direction.split('-') + + try: + shutil.rmtree(f'{databin_dir}', ignore_errors=True) + os.mkdir(f'{databin_dir}') + except OSError as error: + print(error) + cmds = [ + "fairseq-preprocess", + f"--source-lang {src} --target-lang {tgt}", + f"--destdir {databin_dir}/", + f"--workers 8", + ] + if isinstance(spm_vocab, tuple): + src_vocab, tgt_vocab = spm_vocab + cmds.extend( + [ + f"--srcdict {src_vocab}", + f"--tgtdict {tgt_vocab}", + ] + ) + else: + cmds.extend( + [ + f"--joined-dictionary", + f"--srcdict {spm_vocab}", + ] + ) + input_options = [] + if 'train' in splits and glob.glob(f"{bpe_dir}/train.bpe*"): + input_options.append( + f"--trainpref {bpe_dir}/train.bpe", + ) + if 'valid' in splits and glob.glob(f"{bpe_dir}/valid.bpe*"): + input_options.append(f"--validpref {bpe_dir}/valid.bpe") + if 'test' in splits and glob.glob(f"{bpe_dir}/test.bpe*"): + input_options.append(f"--testpref {bpe_dir}/test.bpe") + if len(input_options) > 0: + cmd = " ".join(cmds + input_options) + print(cmd) + call(cmd) + + +def binarize( + databin_dir, + direction, spm_vocab=SPM_VOCAB, prefix='', + splits=['train', 'test', 'valid'], + pairs_per_shard=None, +): + def move_databin_files(from_folder, to_folder): + for bin_file in glob.glob(f"{from_folder}/*.bin") \ + + glob.glob(f"{from_folder}/*.idx") \ + + glob.glob(f"{from_folder}/dict*"): + try: + shutil.move(bin_file, to_folder) + except OSError as error: + print(error) + bpe_databin_dir = f"{BPE_DIR}/{direction}{prefix}_databin" + bpe_dir = f"{BPE_DIR}/{direction}{prefix}" + if pairs_per_shard is None: + binarize_(bpe_dir, bpe_databin_dir, direction, spm_vocab=spm_vocab, splits=splits) + move_databin_files(bpe_databin_dir, databin_dir) + else: + # binarize valid and test which will not be sharded + binarize_( + bpe_dir, bpe_databin_dir, direction, + spm_vocab=spm_vocab, splits=[s for s in splits if s != "train"]) + for shard_bpe_dir in glob.glob(f"{bpe_dir}/shard*"): + path_strs = os.path.split(shard_bpe_dir) + shard_str = path_strs[-1] + shard_folder = f"{bpe_databin_dir}/{shard_str}" + databin_shard_folder = f"{databin_dir}/{shard_str}" + print(f'working from {shard_folder} to {databin_shard_folder}') + os.makedirs(databin_shard_folder, exist_ok=True) + binarize_( + shard_bpe_dir, shard_folder, direction, + 
spm_vocab=spm_vocab, splits=["train"]) + + for test_data in glob.glob(f"{bpe_databin_dir}/valid.*") + glob.glob(f"{bpe_databin_dir}/test.*"): + filename = os.path.split(test_data)[-1] + try: + os.symlink(test_data, f"{databin_shard_folder}/{filename}") + except OSError as error: + print(error) + move_databin_files(shard_folder, databin_shard_folder) + + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() for l in fr] + return langs + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--data_root", default=f"{WORKDIR_ROOT}/ML50") + parser.add_argument("--raw-folder", default='raw') + parser.add_argument("--bpe-folder", default='bpe') + parser.add_argument("--databin-folder", default='databin') + + args = parser.parse_args() + + DATA_PATH = args.data_root #'/private/home/yuqtang/public_data/ML50' + RAW_DIR = f'{DATA_PATH}/{args.raw_folder}' + BPE_DIR = f'{DATA_PATH}/{args.bpe_folder}' + DATABIN_DIR = f'{DATA_PATH}/{args.databin_folder}' + os.makedirs(BPE_DIR, exist_ok=True) + + raw_files = itertools.chain( + glob.glob(f'{RAW_DIR}/train*'), + glob.glob(f'{RAW_DIR}/valid*'), + glob.glob(f'{RAW_DIR}/test*'), + ) + + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + + for direction in directions: + prefix = "" + splits = ['train', 'valid', 'test'] + try: + shutil.rmtree(f'{BPE_DIR}/{direction}{prefix}', ignore_errors=True) + os.mkdir(f'{BPE_DIR}/{direction}{prefix}') + os.makedirs(DATABIN_DIR, exist_ok=True) + except OSError as error: + print(error) + spm_model, spm_vocab = SPM_MODEL, SPM_VOCAB + encode_spm(spm_model, direction=direction, splits=splits) + binarize(DATABIN_DIR, direction, spm_vocab=spm_vocab, splits=splits) diff --git a/fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py b/fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e2eb0f15699f1b458a8445d0c1dd6229a21f77 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os, sys +import subprocess +import re +from subprocess import check_call, check_output + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + + +BLEU_REGEX = re.compile("^BLEU\\S* = (\\S+) ") +def run_eval_bleu(cmd): + output = check_output(cmd, shell=True, stderr=subprocess.STDOUT).decode("utf-8").strip() + print(output) + bleu = -1.0 + for line in output.strip().split('\n'): + m = BLEU_REGEX.search(line) + if m is not None: + bleu = m.groups()[0] + bleu = float(bleu) + break + return bleu + +def check_data_test_bleu(raw_folder, data_lang_pairs): + not_matchings = [] + for sacrebleu_set, src_tgts in data_lang_pairs: + for src_tgt in src_tgts: + print(f'checking test bleus for: {src_tgt} at {sacrebleu_set}') + src, tgt = src_tgt.split('-') + ssrc, stgt = src[:2], tgt[:2] + if os.path.exists(f'{raw_folder}/test.{tgt}-{src}.{src}'): + # reversed direction may have different test set + test_src = f'{raw_folder}/test.{tgt}-{src}.{src}' + else: + test_src = f'{raw_folder}/test.{src}-{tgt}.{src}' + cmd1 = f'cat {test_src} | sacrebleu -t "{sacrebleu_set}" -l {stgt}-{ssrc}; [ $? -eq 0 ] || echo ""' + test_tgt = f'{raw_folder}/test.{src}-{tgt}.{tgt}' + cmd2 = f'cat {test_tgt} | sacrebleu -t "{sacrebleu_set}" -l {ssrc}-{stgt}; [ $? -eq 0 ] || echo ""' + bleu1 = run_eval_bleu(cmd1) + if bleu1 != 100.0: + not_matchings.append(f'{sacrebleu_set}:{src_tgt} source side not matching: {test_src}') + bleu2 = run_eval_bleu(cmd2) + if bleu2 != 100.0: + not_matchings.append(f'{sacrebleu_set}:{src_tgt} target side not matching: {test_tgt}') + return not_matchings + +if __name__ == "__main__": + to_data_path = f'{WORKDIR_ROOT}/iwsltv2' + not_matching = check_data_test_bleu( + f'{to_data_path}/raw', + [ + ('iwslt17', ['en_XX-ar_AR', 'en_XX-ko_KR', 'ar_AR-en_XX', 'ko_KR-en_XX']), + ('iwslt17', ['en_XX-it_IT', 'en_XX-nl_XX', 'it_IT-en_XX', 'nl_XX-en_XX']), + ('iwslt17/tst2015', ['en_XX-vi_VN', "vi_VN-en_XX"]), + ] + ) + if len(not_matching) > 0: + print('the following datasets do not have matching test datasets:\n\t', '\n\t'.join(not_matching)) + diff --git a/fairseq/examples/multilingual/data_scripts/check_self_overlaps.py b/fairseq/examples/multilingual/data_scripts/check_self_overlaps.py new file mode 100644 index 0000000000000000000000000000000000000000..07b338dcfd2d7f10317608274631d0edd93ba889 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/check_self_overlaps.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import glob +import argparse +from utils.dedup import deup +import sys + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + +def get_directions(folder): + raw_files = glob.glob(f'{folder}/train*') + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + return directions + +def diff_list(lhs, rhs): + return set(lhs).difference(set(rhs)) + +def check_diff( + from_src_file, from_tgt_file, + to_src_file, to_tgt_file, +): + seen_in_from = set() + seen_src_in_from = set() + seen_tgt_in_from = set() + from_count = 0 + with open(from_src_file, encoding='utf-8') as fsrc, \ + open(from_tgt_file, encoding='utf-8') as ftgt: + for s, t in zip(fsrc, ftgt): + seen_in_from.add((s, t)) + seen_src_in_from.add(s) + seen_tgt_in_from.add(t) + from_count += 1 + common = 0 + common_src = 0 + common_tgt = 0 + to_count = 0 + seen = set() + + with open(to_src_file, encoding='utf-8') as fsrc, \ + open(to_tgt_file, encoding='utf-8') as ftgt: + for s, t in zip(fsrc, ftgt): + to_count += 1 + if (s, t) not in seen: + if (s, t) in seen_in_from: + common += 1 + if s in seen_src_in_from: + common_src += 1 + seen_src_in_from.remove(s) + if t in seen_tgt_in_from: + common_tgt += 1 + seen_tgt_in_from.remove(t) + seen.add((s, t)) + return common, common_src, common_tgt, from_count, to_count + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--folder", type=str, required=True, + help="the data folder ") + parser.add_argument("--split", type=str, default='test', + help="split (valid, test) to check against training data") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + + if args.directions is None: + directions = set(get_directions(args.folder)) + directions = sorted(directions) + else: + directions = args.directions.split(',') + directions = sorted(set(directions)) + + results = [] + print(f'checking where {args.split} split data are in training') + print(f'direction\tcommon_count\tsrc common\ttgt common\tfrom_size\tto_size') + + for direction in directions: + src, tgt = direction.split('-') + from_src_file = f'{args.folder}/{args.split}.{src}-{tgt}.{src}' + from_tgt_file = f'{args.folder}/{args.split}.{src}-{tgt}.{tgt}' + if not os.path.exists(from_src_file): + # some test/valid data might in reverse directinos: + from_src_file = f'{args.folder}/{args.split}.{tgt}-{src}.{src}' + from_tgt_file = f'{args.folder}/{args.split}.{tgt}-{src}.{tgt}' + to_src_file = f'{args.folder}/train.{src}-{tgt}.{src}' + to_tgt_file = f'{args.folder}/train.{src}-{tgt}.{tgt}' + if not os.path.exists(to_src_file) or not os.path.exists(from_src_file): + continue + r = check_diff(from_src_file, from_tgt_file, to_src_file, to_tgt_file) + results.append(r) + print(f'{direction}\t', '\t'.join(map(str, r))) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py b/fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py new file mode 100644 index 0000000000000000000000000000000000000000..40fa9aecdf9108e095feb3661236453c0f7ed7c4 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
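+# This script checks for train/test leakage: it loads the test sentences of the
+# given directions from --test-data and counts how many training sentences,
+# read from <--folder>/en_XX/<direction>/all.<lang>, also appear in a test set.
+# Usage sketch (paths are illustrative):
+#   python check_valid_test_overlaps.py --folder /path/to/mined_data \
+#       --test-data $WORKDIR_ROOT/ML50/raw --directions et_EE-en_XX,si_LK-en_XX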
+ + +import os +import argparse +import pandas as pd +import sys + + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() for l in fr] + return langs + + + +def load_sentences(raw_data, split, direction): + src, tgt = direction.split('-') + src_path = f"{raw_data}/{split}.{direction}.{src}" + tgt_path = f"{raw_data}/{split}.{direction}.{tgt}" + if os.path.exists(src_path) and os.path.exists(tgt_path): + return [(src, open(src_path).read().splitlines()), (tgt, open(tgt_path).read().splitlines())] + else: + return [] + +def swap_direction(d): + src, tgt = d.split('-') + return f'{tgt}-{src}' + +def get_all_test_data(raw_data, directions, split='test'): + test_data = [ + x + for dd in directions + for d in [dd, swap_direction(dd)] + for x in load_sentences(raw_data, split, d) + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_data = {} + for lang, d in test_data: + for s in d: + s = s.strip() + lgs = all_test_data.get(s, set()) + lgs.add(lang) + all_test_data[s] = lgs + return all_test_data, test_data + + +def check_train_sentences(src_path, tgt_path, direction, all_test_data, mess_up_train={}): + # src, tgt = direction.split('-') + print(f'check training data for {direction} in {src_path} and {tgt_path}') + size = 0 + overlapped_size_counted_dup = 0 + if not os.path.exists(tgt_path) or not os.path.exists(src_path): + return mess_up_train, size, overlapped_size_counted_dup + + with open(src_path) as f, open(tgt_path) as g: + for src_line, tgt_line in zip(f, g): + s = src_line.strip() + t = tgt_line.strip() + size += 1 + if s in all_test_data: + langs = mess_up_train.get(s, set()) + langs.add(direction) + mess_up_train[s] = langs + overlapped_size_counted_dup += 1 + if t in all_test_data: + langs = mess_up_train.get(t, set()) + langs.add(direction) + mess_up_train[t] = langs + overlapped_size_counted_dup += 1 + print(f'{direction}: size={size}, overlapped={overlapped_size_counted_dup}') + return mess_up_train, size, overlapped_size_counted_dup + +def check_train_all(raw_data, directions, all_test_data): + mess_up_train = {} + data_sizes = {} + # raw_data = '~chau/data-bin/MineBART/multilingual_mined_100M/en_XX/et_EE-en_XX/all.{en_XX, et_EE}' + print(f'checking training data againsts # {len(all_test_data)} sentences') + print(f'example test data: ', [s for i, s in enumerate(all_test_data.keys()) if i < 10]) + for direction in directions: + src, tgt = direction.split('-') + path = f'{raw_data}/en_XX/{direction}/all' + src_path = f'{path}.{src}' + tgt_path = f'{path}.{tgt}' + print(f'checking {src_path} {tgt_path}') + _, size, overlapped_size_counted_dup = check_train_sentences(src_path, tgt_path, direction, all_test_data, mess_up_train) + data_sizes[direction] = (size, overlapped_size_counted_dup) + return mess_up_train, data_sizes + + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--folder", type=str, required=True, + help="the data folder ") + parser.add_argument("--test-data", type=str, required=True, + help="the test data folder ") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + directions = args.directions.split(',') + directions = sorted(set(directions)) + + results = [] + # print(f'checking where {args.split} split data are in 
training') + # print(f'direction\tcommon_count\tsrc common\ttgt common\tfrom_size\tto_size') + raw_data = args.folder + all_test_data, test_data = get_all_test_data(args.test_data, directions, split='test') + mess_up_train, data_sizes = check_train_all(raw_data, directions, all_test_data) + print(data_sizes) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/multilingual/data_scripts/dedup_all.py b/fairseq/examples/multilingual/data_scripts/dedup_all.py new file mode 100644 index 0000000000000000000000000000000000000000..ef39c05ee606aaeda1d9e94970932d2241a8b281 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/dedup_all.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + + +import os +import glob +import argparse +from utils.dedup import deup + +import sys +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--from-folder", type=str, required=True, + help="the data folder to be dedup") + parser.add_argument("--to-folder", type=str, required=True, + help="the data folder to save deduped data") + parser.add_argument('--directions', type=str, default=None, required=False) + + args = parser.parse_args() + + if args.directions is None: + raw_files = glob.glob(f'{args.from_folder}/train*') + + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + else: + directions = args.directions.split(',') + directions = sorted(set(directions)) + + for direction in directions: + src, tgt = direction.split('-') + src_file = f'{args.from_folder}/train.{src}-{tgt}.{src}' + tgt_file = f'{args.from_folder}/train.{src}-{tgt}.{tgt}' + src_file_out = f'{args.to_folder}/train.{src}-{tgt}.{src}' + tgt_file_out = f'{args.to_folder}/train.{src}-{tgt}.{tgt}' + assert src_file != src_file_out + assert tgt_file != tgt_file_out + print(f'deduping {src_file}, {tgt_file}') + deup(src_file, tgt_file, src_file_out, tgt_file_out) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh b/fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh new file mode 100644 index 0000000000000000000000000000000000000000..99fbc75920836a4b4bbdbd6b523749843288e450 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." 
+ exit +fi + +# first run download_wmt20.sh; it will install a few useful tools for other scripts +# TODO: need to print out instructions on downloading a few files which requires manually authentication from the websites +bash ./download_wmt20.sh + +python ./download_wmt19_and_before.py +bash ./download_wat19_my.sh +python ./download_ted_and_extract.py +bash ./download_lotus.sh +bash ./download_iitb.sh +bash ./download_af_xh.sh + + +# IWSLT downloading URLs have changed in between; TODO: fix them: +bash ./download_iwslt_and_extract.sh + +# TODO: globalvoices URLs changed; need to be fixed +bash ./download_flores_data.sh diff --git a/fairseq/examples/multilingual/data_scripts/download_af_xh.sh b/fairseq/examples/multilingual/data_scripts/download_af_xh.sh new file mode 100644 index 0000000000000000000000000000000000000000..a78fbbbbccb6f6ae005a1f03b97f083a2d958ebe --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_af_xh.sh @@ -0,0 +1,164 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# set -x -e + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + +# put intermediate files +TMP_DIR=$WORKDIR_ROOT/temp/af_xhv2 +# output {train,valid,test} files to dest +DEST=${WORKDIR_ROOT}/ML50/raw + + + +ROOT=${WORKDIR_ROOT} +UTILS=$PWD/utils +TMX2CORPUS="${UTILS}/tmx2corpus" +TMX_TOOL="python ${TMX2CORPUS}/tmx2corpus.py" + +mkdir -p $TMP_DIR +mkdir -p $DEST +mkdir -p $UTILS + +function download_opus(){ + src=$1 + tgt=$2 + subset=$3 + ulr=$4 + + mkdir extract_$subset.$src-$tgt + pushd extract_$subset.$src-$tgt + if [ ! 
-f "$subset.$src-$tgt.tmx.gz" ]; then + wget $url -O "$subset.$src-$tgt.tmx.gz" + gzip -d "$subset.$src-$tgt.tmx.gz" + f=$subset.$src-$tgt.tmx + $TMX_TOOL $f + mv bitext.$src ../$subset.$src-$tgt.$src + mv bitext.$tgt ../$subset.$src-$tgt.$tgt + fi + popd +} + +function concat_subsets(){ + src=$1 + tgt=$2 + subsets=$3 + src_train=raw_train.$src-$tgt.$src + tgt_train=raw_train.$src-$tgt.$tgt + > $src_train + > $tgt_train + for subset in $subsets; do + cat $subset.$src-$tgt.$src >> $src_train + cat $subset.$src-$tgt.$tgt >> $tgt_train + done +} + + + +function get_seeded_random() +{ + seed="$1" + openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \ + /dev/null +} + +function split_train_valid(){ + src=$1 + tgt=$2 + raw_src_train=raw_train.$src-$tgt.$src + raw_tgt_train=raw_train.$src-$tgt.$tgt + + shuf --random-source=<(get_seeded_random 43) $raw_src_train > shuffled.$src-$tgt.$src + shuf --random-source=<(get_seeded_random 43) $raw_tgt_train > shuffled.$src-$tgt.$tgt + + head -n 1500 shuffled.$src-$tgt.$src > valid.$src-$tgt.$src + head -n 1500 shuffled.$src-$tgt.$tgt > valid.$src-$tgt.$tgt + + tail +1501 shuffled.$src-$tgt.$src > train.$src-$tgt.$src + tail +1501 shuffled.$src-$tgt.$tgt > train.$src-$tgt.$tgt +} + +function copy2dst(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + + + cp valid.$src-$tgt.$src $DEST/valid.$lsrc-$ltgt.$lsrc + cp valid.$src-$tgt.$tgt $DEST/valid.$lsrc-$ltgt.$ltgt + + cp train.$src-$tgt.$src $DEST/train.$lsrc-$ltgt.$lsrc + cp train.$src-$tgt.$tgt $DEST/train.$lsrc-$ltgt.$ltgt +} + + + + +#for xh-en +declare -A xh_en_urls +xh_en_urls=( + [Tatoeba]=https://object.pouta.csc.fi/OPUS-Tatoeba/v20190709/tmx/en-xh.tmx.gz + [wikimedia]=https://object.pouta.csc.fi/OPUS-wikimedia/v20190628/tmx/en-xh.tmx.gz + [memat]=https://object.pouta.csc.fi/OPUS-memat/v1/tmx/en-xh.tmx.gz + [uedin]=https://object.pouta.csc.fi/OPUS-bible-uedin/v1/tmx/en-xh.tmx.gz + [GNOME]=https://object.pouta.csc.fi/OPUS-GNOME/v1/tmx/en-xh.tmx.gz + [XhosaNavy]=https://object.pouta.csc.fi/OPUS-XhosaNavy/v1/tmx/en-xh.tmx.gz + [KDE4]=https://object.pouta.csc.fi/OPUS-KDE4/v2/tmx/en-xh.tmx.gz + [Ubuntu]=https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/tmx/en-xh.tmx.gz +) + +mkdir $TMP_DIR/xh-en +pushd $TMP_DIR/xh-en +for k in "${!xh_en_urls[@]}" +do + name=$k + url=${xh_en_urls[$k]} + echo "$name: $url" + download_opus xh en $name $ulr +done +concat_subsets xh en "${!xh_en_urls[@]}" +split_train_valid xh en +copy2dst xh_ZA en_XX +popd + + +## +#for af-en +declare -A af_en_urls +af_en_urls=( + [Tatoeba]=https://object.pouta.csc.fi/OPUS-Tatoeba/v20190709/tmx/af-en.tmx.gz + [uedin]=https://object.pouta.csc.fi/OPUS-bible-uedin/v1/tmx/af-en.tmx.gz + [GNOME]=https://object.pouta.csc.fi/OPUS-GNOME/v1/tmx/af-en.tmx.gz + [QED]=https://object.pouta.csc.fi/OPUS-QED/v2.0a/tmx/af-en.tmx.gz + [KDE4]=https://object.pouta.csc.fi/OPUS-KDE4/v2/tmx/af-en.tmx.gz + [OpenSubtitles]=https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2018/tmx/af-en.tmx.gz + [SPC]=https://object.pouta.csc.fi/OPUS-SPC/v1/tmx/af-en.tmx.gz + [Ubuntu]=https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/tmx/af-en.tmx.gz +) + +mkdir $TMP_DIR/af-en +pushd $TMP_DIR/af-en +for k in "${!af_en_urls[@]}" +do + name=$k + url=${af_en_urls[$k]} + echo "$name: $url" + download_opus af en $name $ulr +done +concat_subsets af en "${!af_en_urls[@]}" +split_train_valid af en +copy2dst af_ZA en_XX +popd + + diff --git a/fairseq/examples/multilingual/data_scripts/download_flores_data.sh b/fairseq/examples/multilingual/data_scripts/download_flores_data.sh new file mode 
100644 index 0000000000000000000000000000000000000000..e6175ce0c38b06a1ebddaeca808f71b47f77f500 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_flores_data.sh @@ -0,0 +1,246 @@ +#!/bin/bash + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + +set -e +set -o pipefail + +SRC=en +SI_TGT=si +NE_TGT=ne + +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + +ROOT=${WORKDIR_ROOT}/tmp +mkdir -p $ROOT +DATA=$ROOT/data +NE_ROOT=$DATA/all-clean-ne +SI_ROOT=$DATA/all-clean-si + +mkdir -p $DATA $NE_ROOT $SI_ROOT + +SI_OPUS_DATASETS=( + "$SI_ROOT/GNOME.en-si" + "$SI_ROOT/Ubuntu.en-si" + "$SI_ROOT/KDE4.en-si" + "$SI_ROOT/OpenSubtitles.en-si" +) + +SI_OPUS_URLS=( + "https://object.pouta.csc.fi/OPUS-GNOME/v1/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-KDE4/v2/moses/en-si.txt.zip" + "https://object.pouta.csc.fi/OPUS-OpenSubtitles/v2018/moses/en-si.txt.zip" +) + +NE_OPUS_DATASETS=( + "$NE_ROOT/GNOME.en-ne" + "$NE_ROOT/Ubuntu.en-ne" + "$NE_ROOT/KDE4.en-ne" +) + +NE_OPUS_URLS=( + "https://object.pouta.csc.fi/OPUS-GNOME/v1/moses/en-ne.txt.zip" + "https://object.pouta.csc.fi/OPUS-Ubuntu/v14.10/moses/en-ne.txt.zip" + "https://object.pouta.csc.fi/OPUS-KDE4/v2/moses/en-ne.txt.zip" +) + +REMOVE_FILE_PATHS=() + +# Download data +download_data() { + CORPORA=$1 + URL=$2 + + if [ -f $CORPORA ]; then + echo "$CORPORA already exists, skipping download" + else + echo "Downloading $URL" + wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA + if [ -f $CORPORA ]; then + echo "$URL successfully downloaded." + else + echo "$URL not successfully downloaded." 
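+        # download failed: remove any partial file and abort so a corrupt corpus is never used downstream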
+ rm -f $CORPORA + exit -1 + fi + fi +} + +# Example: download_opus_data $LANG_ROOT $TGT +download_opus_data() { + LANG_ROOT=$1 + TGT=$2 + + if [ "$TGT" = "si" ]; then + URLS=("${SI_OPUS_URLS[@]}") + DATASETS=("${SI_OPUS_DATASETS[@]}") + else + URLS=("${NE_OPUS_URLS[@]}") + DATASETS=("${NE_OPUS_DATASETS[@]}") + fi + + # Download and extract data + for ((i=0;i<${#URLS[@]};++i)); do + URL=${URLS[i]} + CORPORA=${DATASETS[i]} + + download_data $CORPORA $URL + unzip -o $CORPORA -d $LANG_ROOT + REMOVE_FILE_PATHS+=( $CORPORA $CORPORA.xml $CORPORA.ids $LANG_ROOT/README $LANG_ROOT/LICENSE ) + done + + cat ${DATASETS[0]}.$SRC ${DATASETS[1]}.$SRC ${DATASETS[2]}.$SRC > $LANG_ROOT/GNOMEKDEUbuntu.$SRC-$TGT.$SRC + cat ${DATASETS[0]}.$TGT ${DATASETS[1]}.$TGT ${DATASETS[2]}.$TGT > $LANG_ROOT/GNOMEKDEUbuntu.$SRC-$TGT.$TGT + + REMOVE_FILE_PATHS+=( ${DATASETS[0]}.$SRC ${DATASETS[1]}.$SRC ${DATASETS[2]}.$SRC ) + REMOVE_FILE_PATHS+=( ${DATASETS[0]}.$TGT ${DATASETS[1]}.$TGT ${DATASETS[2]}.$TGT ) +} + +download_opus_data $SI_ROOT $SI_TGT +cp ${SI_OPUS_DATASETS[3]}.$SRC $SI_ROOT/OpenSubtitles2018.$SRC-$SI_TGT.$SRC +cp ${SI_OPUS_DATASETS[3]}.$SI_TGT $SI_ROOT/OpenSubtitles2018.$SRC-$SI_TGT.$SI_TGT +REMOVE_FILE_PATHS+=( ${SI_OPUS_DATASETS[3]}.$SRC ${SI_OPUS_DATASETS[3]}.$SI_TGT ) + +download_opus_data $NE_ROOT $NE_TGT + + +# Download and extract Global Voices data +GLOBAL_VOICES="$NE_ROOT/globalvoices.2018q4.ne-en" +GLOBAL_VOICES_URL="http://www.casmacat.eu/corpus/global-voices/globalvoices.ne-en.xliff.gz" + +download_data $GLOBAL_VOICES.gz $GLOBAL_VOICES_URL +gunzip -Nf $GLOBAL_VOICES.gz + +sed -ne 's?.*<source>\(.*\)</source>.*?\1?p' $GLOBAL_VOICES > $GLOBAL_VOICES.$NE_TGT +sed -ne 's?.*<target[^>]*>\(.*\)</target>.*?\1?p' $GLOBAL_VOICES > $GLOBAL_VOICES.$SRC + +REMOVE_FILE_PATHS+=( $GLOBAL_VOICES ) + +# Download and extract the bible dataset +BIBLE_TOOLS=bible-corpus-tools +XML_BIBLES=XML_Bibles +XML_BIBLES_DUP=XML_Bibles_dup + +if [ ! -e $BIBLE_TOOLS ]; then + echo "Cloning bible-corpus-tools repository..."
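+    # bible-corpus-tools provides the Java utilities (CreateMLBooks, CreateVerseAlignedBooks) compiled and run below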
+ git clone https://github.com/christos-c/bible-corpus-tools.git +fi + +mkdir -p $BIBLE_TOOLS/bin $XML_BIBLES $XML_BIBLES_DUP +javac -cp "$BIBLE_TOOLS/lib/*" -d $BIBLE_TOOLS/bin $BIBLE_TOOLS/src/bible/readers/*.java $BIBLE_TOOLS/src/bible/*.java + +download_data bible.tar.gz "https://github.com/christos-c/bible-corpus/archive/v1.2.1.tar.gz" +tar xvzf bible.tar.gz + +cp bible-corpus-1.2.1/bibles/{Greek.xml,English.xml,Nepali.xml} $XML_BIBLES/ +cp bible-corpus-1.2.1/bibles/{Greek.xml,English-WEB.xml,Nepali.xml} $XML_BIBLES_DUP/ + +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateMLBooks $XML_BIBLES +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateMLBooks $XML_BIBLES_DUP +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateVerseAlignedBooks $XML_BIBLES +java -cp $BIBLE_TOOLS/lib/*:$BIBLE_TOOLS/bin bible.CreateVerseAlignedBooks $XML_BIBLES_DUP + +cat $XML_BIBLES/aligned/*/English.txt > $NE_ROOT/bible.$SRC-$NE_TGT.$SRC +cat $XML_BIBLES/aligned/*/Nepali.txt > $NE_ROOT/bible.$SRC-$NE_TGT.$NE_TGT +cat $XML_BIBLES_DUP/aligned/*/English-WEB.txt > $NE_ROOT/bible_dup.$SRC-$NE_TGT.$SRC +cat $XML_BIBLES_DUP/aligned/*/Nepali.txt > $NE_ROOT/bible_dup.$SRC-$NE_TGT.$NE_TGT +REMOVE_FILE_PATHS+=( bible-corpus-1.2.1 bible.tar.gz $BIBLE_TOOLS $XML_BIBLES $XML_BIBLES_DUP ) + +# Download and extract the Penn Treebank dataset +NE_TAGGED=$ROOT/new_submissions_parallel_corpus_project_Nepal +NE_TAGGED_URL="http://www.cle.org.pk/Downloads/ling_resources/parallelcorpus/NepaliTaggedCorpus.zip" +EN_TAGGED_PATCH_URL="https://dl.fbaipublicfiles.com/fairseq/data/nepali-penn-treebank.en.patch" +NE_TAGGED_PATCH_URL="https://dl.fbaipublicfiles.com/fairseq/data/nepali-penn-treebank.ne.patch" +MOSES=mosesdecoder +MOSES_TOK=$MOSES/scripts/tokenizer +EN_PATCH_REGEX="{s:\\\/:\/:g;s/\*\T\*\-\n+//g;s/\-LCB\-/\{/g;s/\-RCB\-/\}/g; s/\-LSB\-/\[/g; s/\-RSB\-/\]/g;s/\-LRB\-/\(/g; s/\-RRB\-/\)/g; s/\'\'/\"/g; s/\`\`/\"/g; s/\ +\'s\ +/\'s /g; s/\ +\'re\ +/\'re /g; s/\"\ +/\"/g; s/\ +\"/\"/g; s/\ n't([\ \.\"])/n't\1/g; s/\r+(.)/\1/g;}" +NE_PATCH_REGEX="{s:\p{Cf}::g;s:\\\/:\/:g;s/\*\T\*\-\n+//g;s/\-LCB\-/\{/g;s/\-RCB\-/\}/g; s/\-LSB\-/\[/g; s/\-RSB\-/\]/g;s/\-LRB\-/\(/g; s/\-RRB\-/\)/g; s/\'\'/\"/g; s/\`\`/\"/g; s/\ +\'s\ +/\'s /g; s/\ +\'re\ +/\'re /g; s/\"\ +/\"/g; s/\ +\"/\"/g; s/\ n't([\ \.\"])/n't\1/g; s/\r+(.)/\1/g;}" + +download_data $DATA/nepali-penn-treebank.$SRC.patch $EN_TAGGED_PATCH_URL +download_data $DATA/nepali-penn-treebank.$NE_TGT.patch $NE_TAGGED_PATCH_URL +download_data original.zip $NE_TAGGED_URL +unzip -o original.zip -d $ROOT + +cat $NE_TAGGED/00.txt $NE_TAGGED/01.txt $NE_TAGGED/02.txt > $NE_TAGGED/nepali-penn-treebank.$SRC +cat $NE_TAGGED/00ne_revised.txt $NE_TAGGED/01ne_revised.txt $NE_TAGGED/02ne_revised.txt > $NE_TAGGED/nepali-penn-treebank.$NE_TGT + +patch $NE_TAGGED/nepali-penn-treebank.$SRC -i $DATA/nepali-penn-treebank.$SRC.patch -o $NE_TAGGED/nepali-penn-treebank-patched.$SRC +patch $NE_TAGGED/nepali-penn-treebank.$NE_TGT -i $DATA/nepali-penn-treebank.$NE_TGT.patch -o $NE_TAGGED/nepali-penn-treebank-patched.$NE_TGT + +if [ ! -e $MOSES ]; then + echo "Cloning moses repository..." 
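+    # mosesdecoder is cloned only for its tokenizer/detokenizer perl scripts ($MOSES_TOK), used to clean the Penn Treebank text below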
+ git clone https://github.com/moses-smt/mosesdecoder.git +fi + +cat $NE_TAGGED/nepali-penn-treebank-patched.$SRC | \ + perl -anpe "$EN_PATCH_REGEX" | \ + $MOSES_TOK/tokenizer.perl -l $SRC | \ + $MOSES_TOK/detokenizer.perl -l $SRC > $NE_ROOT/nepali-penn-treebank.$SRC + +cat $NE_TAGGED/nepali-penn-treebank-patched.$NE_TGT | \ + perl -CIO -anpe "$NE_PATCH_REGEX" | \ + $MOSES_TOK/detokenizer.perl -l $SRC > $NE_ROOT/nepali-penn-treebank.$NE_TGT + + +# Download nepali dictionary data +NE_DICT=$NE_ROOT/dictionaries +download_data $NE_DICT "http://www.seas.upenn.edu/~nlp/resources/TACL-data-release/dictionaries.tar.gz" +tar xvzf $NE_DICT +cp dictionaries/dict.ne $NE_ROOT/dictionary.$NE_TGT-$SRC +REMOVE_FILE_PATHS+=( $NE_DICT dictionaries ) + +REMOVE_FILE_PATHS+=( $MOSES $NE_TAGGED original.zip $DATA/nepali-penn-treebank.$SRC.patch $DATA/nepali-penn-treebank.$NE_TGT.patch ) + + +# Remove the temporary files +for ((i=0;i<${#REMOVE_FILE_PATHS[@]};++i)); do + rm -rf ${REMOVE_FILE_PATHS[i]} +done + +# Copy the training data +si=si_LK +ne=ne_NP +en=en_XX +cat $SI_ROOT/GNOMEKDEUbuntu.en-si.si $SI_ROOT/OpenSubtitles2018.en-si.si > $DESTDIR/train.$si-$en.$si +cat $SI_ROOT/GNOMEKDEUbuntu.en-si.en $SI_ROOT/OpenSubtitles2018.en-si.en > $DESTDIR/train.$si-$en.$en + +cat $NE_ROOT/bible_dup.en-ne.ne $NE_ROOT/bible.en-ne.ne $NE_ROOT/globalvoices.2018q4.ne-en.ne $NE_ROOT/GNOMEKDEUbuntu.en-ne.ne $NE_ROOT/nepali-penn-treebank.ne > $DESTDIR/train.$ne-$en.$ne +cat $NE_ROOT/bible_dup.en-ne.en $NE_ROOT/bible.en-ne.en $NE_ROOT/globalvoices.2018q4.ne-en.en $NE_ROOT/GNOMEKDEUbuntu.en-ne.en $NE_ROOT/nepali-penn-treebank.en > $DESTDIR/train.$ne-$en.$en + + +#Download the test sets +wget https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz +tar -xvzf wikipedia_en_ne_si_test_sets.tgz + +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.ne-en.ne $DESTDIR/valid.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.ne-en.en $DESTDIR/valid.$ne-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.si-en.si $DESTDIR/valid.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.dev.si-en.en $DESTDIR/valid.$si-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.ne-en.ne $DESTDIR/devtest.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.ne-en.en $DESTDIR/devtest.$ne-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.si-en.si $DESTDIR/devtest.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.devtest.si-en.en $DESTDIR/devtest.$si-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.test.ne-en.ne $DESTDIR/test.$ne-$en.$ne +cp wikipedia_en_ne_si_test_sets/wikipedia.test.ne-en.en $DESTDIR/test.$ne-$en.$en + +cp wikipedia_en_ne_si_test_sets/wikipedia.test.si-en.si $DESTDIR/test.$si-$en.$si +cp wikipedia_en_ne_si_test_sets/wikipedia.test.si-en.en $DESTDIR/test.$si-$en.$en + +rm -rf wikipedia_en_ne_si_test_sets.tgz wikipedia_en_ne_si_test_sets diff --git a/fairseq/examples/multilingual/data_scripts/download_iitb.sh b/fairseq/examples/multilingual/data_scripts/download_iitb.sh new file mode 100644 index 0000000000000000000000000000000000000000..a884e20839e2a41a57405cb6af362e37bd16ab6f --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_iitb.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
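+# Fetches the IIT Bombay English-Hindi parallel corpus and copies its train/dev/test
+# splits into ${WORKDIR_ROOT}/ML50/raw with ML50 naming (e.g. train.hi_IN-en_XX.hi_IN).
+# A typical invocation (an assumption for illustration, not part of the original script) would be:
+#   WORKDIR_ROOT=/path/to/workdir bash download_iitb.sh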
+ + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + +IITB=$WORKDIR_ROOT/IITB +mkdir -p $IITB +pushd $IITB + +wget http://www.cfilt.iitb.ac.in/~moses/iitb_en_hi_parallel/iitb_corpus_download/parallel.tgz +tar -xvzf parallel.tgz + +wget http://www.cfilt.iitb.ac.in/~moses/iitb_en_hi_parallel/iitb_corpus_download/dev_test.tgz +tar -xvzf dev_test.tgz + +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + +cp parallel/IITB.en-hi.en $DESTDIR/train.hi_IN-en_XX.en_XX +cp parallel/IITB.en-hi.hi $DESTDIR/train.hi_IN-en_XX.hi_IN + +cp dev_test/dev.en $DESTDIR/valid.hi_IN-en_XX.en_XX +cp dev_test/dev.hi $DESTDIR/valid.hi_IN-en_XX.hi_IN + +cp dev_test/test.en $DESTDIR/test.hi_IN-en_XX.en_XX +cp dev_test/test.hi $DESTDIR/test.hi_IN-en_XX.hi_IN +popd \ No newline at end of file diff --git a/fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh b/fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh new file mode 100644 index 0000000000000000000000000000000000000000..ca3591b3db1715f136773d62e4b9b9ede97d436c --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh @@ -0,0 +1,225 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +#echo 'Cloning Moses github repository (for tokenization scripts)...' +#git clone https://github.com/moses-smt/mosesdecoder.git + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + + +data_root=${WORKDIR_ROOT}/iwsltv2 +DESTDIR=${WORKDIR_ROOT}/ML50/raw + + +langs="ar_AR it_IT nl_XX ko_KR vi_VN" +echo "data_root: $data_root" + +download_path=${data_root}/downloads +raw=${DESTDIR} +tmp=${data_root}/tmp +orig=${data_root}/orig + +mkdir -p $download_path $orig $raw $tmp +####################### +download_iwslt(){ + iwslt_key=$1 + src=$2 + tgt=$3 + save_prefix=$4 + pushd ${download_path} + if [[ ! -f ${save_prefix}$src-$tgt.tgz ]]; then + wget https://wit3.fbk.eu/archive/${iwslt_key}/texts/$src/$tgt/$src-$tgt.tgz -O ${save_prefix}$src-$tgt.tgz + [ $? -eq 0 ] && return 0 + fi + popd +} + +extract_iwslt(){ + src=$1 + tgt=$2 + prefix=$3 + pushd $orig + tar zxvf ${download_path}/${prefix}$src-${tgt}.tgz + popd +} + +generate_train(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + for ll in $lsrc $ltgt; do + l=${ll:0:2} + f="$orig/*/train.tags.$src-$tgt.$l" + f_raw=$raw/train.$lsrc-$ltgt.$ll + cat $f \ + | grep -v '<url>' \ + | grep -v '<talkid>' \ + | grep -v '<keywords>' \ + | grep -v '<speaker>' \ + | grep -v '<reviewer' \ + | grep -v '<translator' \ + | grep -v '<doc' \ + | grep -v '</doc>' \ + | sed -e 's/<title>//g' \ + | sed -e 's/<\/title>//g' \ + | sed -e 's/<description>//g' \ + | sed -e 's/<\/description>//g' \ + | sed 's/^\s*//g' \ + | sed 's/\s*$//g' \ + > $f_raw + [ $?
-eq 0 ] && echo "extracted $f to $f_raw" + done + return 0 +} + +convert_valid_test(){ + src=$1 + tgt=$2 + for l in $src $tgt; do + echo "lang: ${l}" + for o in `ls $orig/*/IWSLT*.TED*.$src-$tgt.$l.xml`; do + fname=${o##*/} + f=$tmp/${fname%.*} + echo "$o => $f" + grep '<seg id' $o \ + | sed -e 's/<seg id="[0-9]*">\s*//g' \ + | sed -e 's/\s*<\/seg>\s*//g' \ + | sed -e "s/\’/\'/g" \ + > $f + echo "" + done + done +} + +generate_subset(){ + lsrc=$1 + ltgt=$2 + src=${lsrc:0:2} + tgt=${ltgt:0:2} + subset=$3 + prefix=$4 + for ll in $lsrc $ltgt; do + l=${ll:0:2} + f=$tmp/$prefix.${src}-${tgt}.$l + if [[ -f $f ]]; then + cp $f $raw/$subset.${lsrc}-$ltgt.${ll} + fi + done +} +################# + +echo "downloading iwslt training and dev data" +# using multilingual for it, nl +download_iwslt "2017-01-trnmted" DeEnItNlRo DeEnItNlRo +download_iwslt "2017-01-trnted" ar en +download_iwslt "2017-01-trnted" en ar +download_iwslt "2017-01-trnted" ko en +download_iwslt "2017-01-trnted" en ko +download_iwslt "2015-01" vi en +download_iwslt "2015-01" en vi + +echo "donwloading iwslt test data" +download_iwslt "2017-01-mted-test" it en "test." +download_iwslt "2017-01-mted-test" en it "test." +download_iwslt "2017-01-mted-test" nl en "test." +download_iwslt "2017-01-mted-test" en nl "test." + +download_iwslt "2017-01-ted-test" ar en "test." +download_iwslt "2017-01-ted-test" en ar "test." +download_iwslt "2017-01-ted-test" ko en "test." +download_iwslt "2017-01-ted-test" en ko "test." +download_iwslt "2015-01-test" vi en "test." +download_iwslt "2015-01-test" en vi "test." + +echo "extract training data tar balls" +extract_iwslt DeEnItNlRo DeEnItNlRo +extract_iwslt ar en +extract_iwslt en ar +extract_iwslt ko en +extract_iwslt en ko +extract_iwslt vi en +extract_iwslt en vi + + +echo "extracting iwslt test data" +for lang in $langs; do + l=${lang:0:2} + extract_iwslt $l en "test." + extract_iwslt en $l "test." 
+done + +echo "convert dev and test data" +for lang in $langs; do + s_lang=${lang:0:2} + convert_valid_test $s_lang en + convert_valid_test en $s_lang +done + + + +echo "creating training data into $raw" +for lang in $langs; do + generate_train $lang en_XX + generate_train en_XX $lang +done + +echo "creating iwslt dev data into raw" +generate_subset en_XX vi_VN valid "IWSLT15.TED.tst2013" +generate_subset vi_VN en_XX valid "IWSLT15.TED.tst2013" + +generate_subset en_XX ar_AR valid "IWSLT17.TED.tst2016" +generate_subset ar_AR en_XX valid "IWSLT17.TED.tst2016" +generate_subset en_XX ko_KR valid "IWSLT17.TED.tst2016" +generate_subset ko_KR en_XX valid "IWSLT17.TED.tst2016" + + +generate_subset en_XX it_IT valid "IWSLT17.TED.tst2010" +generate_subset it_IT en_XX valid "IWSLT17.TED.tst2010" +generate_subset en_XX nl_XX valid "IWSLT17.TED.tst2010" +generate_subset nl_XX en_XX valid "IWSLT17.TED.tst2010" + +echo "creating iwslt test data into raw" +generate_subset en_XX vi_VN test "IWSLT15.TED.tst2015" +generate_subset vi_VN en_XX test "IWSLT15.TED.tst2015" + +generate_subset en_XX ar_AR test "IWSLT17.TED.tst2017" +generate_subset ar_AR en_XX test "IWSLT17.TED.tst2017" +generate_subset en_XX ko_KR test "IWSLT17.TED.tst2017" +generate_subset ko_KR en_XX test "IWSLT17.TED.tst2017" + +generate_subset en_XX it_IT test "IWSLT17.TED.tst2017.mltlng" +generate_subset it_IT en_XX test "IWSLT17.TED.tst2017.mltlng" +generate_subset en_XX nl_XX test "IWSLT17.TED.tst2017.mltlng" +generate_subset nl_XX en_XX test "IWSLT17.TED.tst2017.mltlng" + +# normalize iwslt directions into x-en +pushd $raw +for lang in $langs; do + for split in test valid; do + x_en_f1=$split.$lang-en_XX.en_XX + x_en_f2=$split.$lang-en_XX.${lang} + + en_x_f1=$split.en_XX-$lang.en_XX + en_x_f2=$split.en_XX-$lang.${lang} + + if [ -f $en_x_f1 ] && [ ! -f $x_en_f1 ]; then + echo "cp $en_x_f1 $x_en_f1" + cp $en_x_f1 $x_en_f1 + fi + if [ -f $en_x_f2 ] && [ ! -f $x_en_f2 ]; then + echo "cp $en_x_f2 $x_en_f2" + cp $en_x_f2 $x_en_f2 + fi + done +done +popd \ No newline at end of file diff --git a/fairseq/examples/multilingual/data_scripts/download_lotus.sh b/fairseq/examples/multilingual/data_scripts/download_lotus.sh new file mode 100644 index 0000000000000000000000000000000000000000..c08c701314a8e575637deff78381ab02c2ef6728 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_lotus.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..."
+ exit +fi + + +SRCDIR=$WORKDIR_ROOT/indic_languages_corpus +DESTDIR=${WORKDIR_ROOT}/ML50/raw/ +mkdir -p $SRCDIR +mkdir -p $DESTDIR + +cd $SRCDIR +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/indic-multilingual/indic_languages_corpus.tar.gz +tar -xvzf indic_languages_corpus.tar.gz + +SRC_EXTRACT_DIR=$SRCDIR/indic_languages_corpus/bilingual + +cp $SRC_EXTRACT_DIR/ml-en/train.ml $DESTDIR/train.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/train.en $DESTDIR/train.ml_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ml-en/dev.ml $DESTDIR/valid.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/dev.en $DESTDIR/valid.ml_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ml-en/test.ml $DESTDIR/test.ml_IN-en_XX.ml_IN +cp $SRC_EXTRACT_DIR/ml-en/test.en $DESTDIR/test.ml_IN-en_XX.en_XX + +cp $SRC_EXTRACT_DIR/ur-en/train.ur $DESTDIR/train.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/train.en $DESTDIR/train.ur_PK-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ur-en/dev.ur $DESTDIR/valid.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/dev.en $DESTDIR/valid.ur_PK-en_XX.en_XX +cp $SRC_EXTRACT_DIR/ur-en/test.ur $DESTDIR/test.ur_PK-en_XX.ur_PK +cp $SRC_EXTRACT_DIR/ur-en/test.en $DESTDIR/test.ur_PK-en_XX.en_XX + +cp $SRC_EXTRACT_DIR/te-en/train.te $DESTDIR/train.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/train.en $DESTDIR/train.te_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/te-en/dev.te $DESTDIR/valid.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/dev.en $DESTDIR/valid.te_IN-en_XX.en_XX +cp $SRC_EXTRACT_DIR/te-en/test.te $DESTDIR/test.te_IN-en_XX.te_IN +cp $SRC_EXTRACT_DIR/te-en/test.en $DESTDIR/test.te_IN-en_XX.en_XX diff --git a/fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py b/fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..eb756680fa7dc31a14ba45c216776a6d60c16b60 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +import os +import csv +from collections import defaultdict +from six.moves import zip +import io +import wget +import sys + +from subprocess import check_call, check_output + +# scripts and data locations +CWD = os.getcwd() +UTILS = f"{CWD}/utils" + +MOSES = f"{UTILS}/mosesdecoder" + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. 
Exitting..."') + sys.exit(-1) + + +# please donwload mosesdecoder here: +detok_cmd = f'{MOSES}/scripts/tokenizer/detokenizer.perl' + + +def call(cmd): + print(f"Executing: {cmd}") + check_call(cmd, shell=True) + +class MultiLingualAlignedCorpusReader(object): + """A class to read TED talk dataset + """ + + def __init__(self, corpus_path, delimiter='\t', + target_token=True, bilingual=True, corpus_type='file', + lang_dict={'source': ['fr'], 'target': ['en']}, + eval_lang_dict=None, zero_shot=False, + detok=True, + ): + + self.empty_line_flag = 'NULL' + self.corpus_path = corpus_path + self.delimiter = delimiter + self.bilingual = bilingual + self.lang_dict = lang_dict + self.lang_set = set() + self.target_token = target_token + self.zero_shot = zero_shot + self.eval_lang_dict = eval_lang_dict + self.corpus_type = corpus_type + self.detok = detok + + for list_ in self.lang_dict.values(): + for lang in list_: + self.lang_set.add(lang) + + self.data = dict() + self.data['train'] = self.read_aligned_corpus(split_type='train') + self.data['test'] = self.read_aligned_corpus(split_type='test') + self.data['dev'] = self.read_aligned_corpus(split_type='dev') + + def read_data(self, file_loc_): + data_list = list() + with io.open(file_loc_, 'r', encoding='utf8') as fp: + for line in fp: + try: + text = line.strip() + except IndexError: + text = self.empty_line_flag + data_list.append(text) + return data_list + + def filter_text(self, dict_): + if self.target_token: + field_index = 1 + else: + field_index = 0 + data_dict = defaultdict(list) + list1 = dict_['source'] + list2 = dict_['target'] + for sent1, sent2 in zip(list1, list2): + try: + src_sent = ' '.join(sent1.split()[field_index: ]) + except IndexError: + src_sent = 'NULL' + + if src_sent.find(self.empty_line_flag) != -1 or len(src_sent) == 0: + continue + + elif sent2.find(self.empty_line_flag) != -1 or len(sent2) == 0: + continue + + else: + data_dict['source'].append(sent1) + data_dict['target'].append(sent2) + return data_dict + + def read_file(self, split_type, data_type): + return self.data[split_type][data_type] + + def save_file(self, path_, split_type, data_type, lang): + tok_file = tok_file_name(path_, lang) + with io.open(tok_file, 'w', encoding='utf8') as fp: + for line in self.data[split_type][data_type]: + fp.write(line + '\n') + if self.detok: + de_tok(tok_file, lang) + + def add_target_token(self, list_, lang_id): + new_list = list() + token = '__' + lang_id + '__' + for sent in list_: + new_list.append(token + ' ' + sent) + return new_list + + def read_from_single_file(self, path_, s_lang, t_lang): + data_dict = defaultdict(list) + with io.open(path_, 'r', encoding='utf8') as fp: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + for row in reader: + data_dict['source'].append(row[s_lang]) + data_dict['target'].append(row[t_lang]) + + if self.target_token: + text = self.add_target_token(data_dict['source'], t_lang) + data_dict['source'] = text + + return data_dict['source'], data_dict['target'] + + def read_aligned_corpus(self, split_type='train'): + data_dict = defaultdict(list) + iterable = [] + s_list = [] + t_list = [] + + if self.zero_shot: + if split_type == "train": + iterable = zip(self.lang_dict['source'], self.lang_dict['target']) + else: + iterable = zip(self.eval_lang_dict['source'], self.eval_lang_dict['target']) + + elif self.bilingual: + iterable = itertools.product(self.lang_dict['source'], self.lang_dict['target']) + + for s_lang, t_lang in iterable: + if s_lang == t_lang: + continue + if 
self.corpus_type == 'file': + split_type_file_path = os.path.join(self.corpus_path, + "all_talks_{}.tsv".format(split_type)) + s_list, t_list = self.read_from_single_file(split_type_file_path, + s_lang=s_lang, + t_lang=t_lang) + data_dict['source'] += s_list + data_dict['target'] += t_list + new_data_dict = self.filter_text(data_dict) + return new_data_dict + + +def read_langs(corpus_path): + split_type_file_path = os.path.join(corpus_path, 'extracted', + "all_talks_dev.tsv") + with io.open(split_type_file_path, 'r', encoding='utf8') as fp: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + header = next(reader) + return [k for k in header.keys() if k != 'talk_name'] + +def extra_english(corpus_path, split): + split_type_file_path = os.path.join(corpus_path, + f"all_talks_{split}.tsv") + output_split_type_file_path = os.path.join(corpus_path, + f"all_talks_{split}.en") + with io.open(split_type_file_path, 'r', encoding='utf8') as fp, io.open(output_split_type_file_path, 'w', encoding='utf8') as fw: + reader = csv.DictReader(fp, delimiter='\t', quoting=csv.QUOTE_NONE) + for row in reader: + line = row['en'] + fw.write(line + '\n') + de_tok(output_split_type_file_path, 'en') + + + +def tok_file_name(filename, lang): + seps = filename.split('.') + seps.insert(-1, 'tok') + tok_file = '.'.join(seps) + return tok_file + +def de_tok(tok_file, lang): + # seps = tok_file.split('.') + # seps.insert(-1, 'detok') + # de_tok_file = '.'.join(seps) + de_tok_file = tok_file.replace('.tok.', '.') + cmd = 'perl {detok_cmd} -l {lang} < {tok_file} > {de_tok_file}'.format( + detok_cmd=detok_cmd, tok_file=tok_file, + de_tok_file=de_tok_file, lang=lang[:2]) + call(cmd) + +def extra_bitex( + ted_data_path, + lsrc_lang, + ltrg_lang, + target_token, + output_data_path, +): + def get_ted_lang(lang): + long_langs = ['pt-br', 'zh-cn', 'zh-tw', 'fr-ca'] + if lang[:5] in long_langs: + return lang[:5] + elif lang[:4] =='calv': + return lang[:5] + elif lang in ['pt_BR', 'zh_CN', 'zh_TW', 'fr_CA']: + return lang.lower().replace('_', '-') + return lang[:2] + src_lang = get_ted_lang(lsrc_lang) + trg_lang = get_ted_lang(ltrg_lang) + train_lang_dict={'source': [src_lang], 'target': [trg_lang]} + eval_lang_dict = {'source': [src_lang], 'target': [trg_lang]} + + obj = MultiLingualAlignedCorpusReader(corpus_path=ted_data_path, + lang_dict=train_lang_dict, + target_token=target_token, + corpus_type='file', + eval_lang_dict=eval_lang_dict, + zero_shot=False, + bilingual=True) + + os.makedirs(output_data_path, exist_ok=True) + lsrc_lang = lsrc_lang.replace('-', '_') + ltrg_lang = ltrg_lang.replace('-', '_') + obj.save_file(output_data_path + f"/train.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='train', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/train.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='train', data_type='target', lang=trg_lang) + + obj.save_file(output_data_path + f"/test.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='test', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/test.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='test', data_type='target', lang=trg_lang) + + obj.save_file(output_data_path + f"/valid.{lsrc_lang}-{ltrg_lang}.{lsrc_lang}", + split_type='dev', data_type='source', lang=src_lang) + obj.save_file(output_data_path + f"/valid.{lsrc_lang}-{ltrg_lang}.{ltrg_lang}", + split_type='dev', data_type='target', lang=trg_lang) + + +def bar_custom(current, total, width=80): + print("Downloading: %d%% [%d / %d] Ks" % 
(current / total * 100, current / 1000, total / 1000), end='\r') + + +def download_and_extract(download_to, extract_to): + url = 'http://phontron.com/data/ted_talks.tar.gz' + filename = f"{download_to}/ted_talks.tar.gz" + if os.path.exists(filename): + print(f'{filename} has already been downloaded so skip') + else: + filename = wget.download(url, filename, bar=bar_custom) + if os.path.exists(f'{extract_to}/all_talks_train.tsv'): + print(f'Already extracted so skip') + else: + extract_cmd = f'tar xzfv "{filename}" -C "{extract_to}"' + call(extract_cmd) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--ted_data_path', type=str, default=WORKDIR_ROOT, required=False) + parser.add_argument( + '--direction-list', + type=str, + # default=None, + #for ML50 + default=( + "bn_IN-en_XX,he_IL-en_XX,fa_IR-en_XX,id_ID-en_XX,sv_SE-en_XX,pt_XX-en_XX,ka_GE-en_XX,ka_GE-en_XX,th_TH-en_XX," + "mr_IN-en_XX,hr_HR-en_XX,uk_UA-en_XX,az_AZ-en_XX,mk_MK-en_XX,gl_ES-en_XX,sl_SI-en_XX,mn_MN-en_XX," + #non-english directions + # "fr_XX-de_DE," # replaced with wmt20 + # "ja_XX-ko_KR,es_XX-pt_XX,ru_RU-sv_SE,hi_IN-bn_IN,id_ID-ar_AR,cs_CZ-pl_PL,ar_AR-tr_TR" + ), + required=False) + parser.add_argument('--target-token', action='store_true', default=False) + parser.add_argument('--extract-all-english', action='store_true', default=False) + + args = parser.parse_args() + + import sys + import json + + # TED Talks data directory + ted_data_path = args.ted_data_path + + download_to = f'{ted_data_path}/downloads' + extract_to = f'{ted_data_path}/extracted' + + #DESTDIR=${WORKDIR_ROOT}/ML50/raw/ + output_path = f'{ted_data_path}/ML50/raw' + os.makedirs(download_to, exist_ok=True) + os.makedirs(extract_to, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + download_and_extract(download_to, extract_to) + + + if args.extract_all_english: + for split in ['train', 'dev', 'test']: + extra_english(ted_data_path, split) + exit(0) + if args.direction_list is not None: + directions = args.direction_list.strip().split(',') + directions = [tuple(d.strip().split('-', 1)) for d in directions if d] + else: + langs = read_langs(ted_data_path) + # directions = [ + # '{}.{}'.format(src, tgt) + # for src in langs + # for tgt in langs + # if src < tgt + # ] + directions = [('en', tgt) for tgt in langs if tgt != 'en'] + print(f'num directions={len(directions)}: {directions}') + + for src_lang, trg_lang in directions: + print('--working on {}-{}'.format(src_lang, trg_lang)) + extra_bitex( + extract_to, + src_lang, + trg_lang, + target_token=args.target_token, + output_data_path=output_path + ) diff --git a/fairseq/examples/multilingual/data_scripts/download_wat19_my.sh b/fairseq/examples/multilingual/data_scripts/download_wat19_my.sh new file mode 100644 index 0000000000000000000000000000000000000000..c1e2d47287a29af4576e7a63641e8152ecb63c44 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_wat19_my.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." 
+ exit +fi + + +SRCDIR=$WORKDIR_ROOT/indic_languages_corpus +DESTDIR=$WORKDIR_ROOT/ML50/raw +mkdir -p $SRCDIR +mkdir -p $DESTDIR + +WAT_MY_EN=wat2020.my-en.zip +cd $SRCDIR +# please refer to http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/ for latest URL if the following url expired +#- The data used for WAT2020 are identical to those used in WAT2019. +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/$WAT_MY_EN +unzip $WAT_MY_EN + + +SRC_EXTRACT_DIR=$SRCDIR/wat2020.my-en/alt + +cp $SRC_EXTRACT_DIR/train.alt.en $DESTDIR/train.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/train.alt.my $DESTDIR/train.my_MM-en_XX.my_MM +cp $SRC_EXTRACT_DIR/dev.alt.en $DESTDIR/valid.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/dev.alt.my $DESTDIR/valid.my_MM-en_XX.my_MM +cp $SRC_EXTRACT_DIR/test.alt.en $DESTDIR/test.my_MM-en_XX.en_XX +cp $SRC_EXTRACT_DIR/test.alt.my $DESTDIR/test.my_MM-en_XX.my_MM diff --git a/fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py b/fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py new file mode 100644 index 0000000000000000000000000000000000000000..3465731eb3e55047c44d1b336a97e99cb3a89a53 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py @@ -0,0 +1,899 @@ +from typing import NamedTuple, List +from urllib.parse import urlparse +import os, sys +import subprocess +from subprocess import check_call, check_output +import glob +import wget +import re +import multiprocessing as mp +from functools import partial +import pathlib +from collections import OrderedDict + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + +# scripts and data locations +CWD = os.getcwd() +UTILS = f"{CWD}/utils" + +MOSES = f"{UTILS}/mosesdecoder" +SGM_TOOL = f'{MOSES}/scripts/ems/support/input-from-sgm.perl' + +TMX2CORPUS = f"{UTILS}/tmx2corpus" +TMX_TOOL = f'python {TMX2CORPUS}/tmx2corpus.py' + +to_data_path = f'{WORKDIR_ROOT}/wmt' +download_to = f'{to_data_path}/downloads' +manually_downloads = f'{to_data_path}/downloads' +extract_to = f'{to_data_path}/extracted' +#DESTDIR=${WORKDIR_ROOT}/ML50/raw/ +raw_data = f'{WORKDIR_ROOT}/ML50/raw' +#### + +class DLDataset(NamedTuple): + name: str + train_urls: List[str] + valid_urls: List[str] + test_urls: List[str] + train_files_patterns: List[str] = [] + valid_files_patterns: List[str] = [] + test_files_patterns: List[str] = [] + + + +def bar_custom(current, total, width=80): + print("Downloading: %d%% [%d / %d] Ks" % (current / total * 100, current / 1000, total / 1000), end='\r') + +def get_downloaded_file(dl_folder, url): + if isinstance(url, tuple): + url, f = url + else: + url_f = urlparse(url) + # f = os.path.split(url_f.path)[-1] + f = '_'.join(url_f.path.split('/')[1:]) + return url, f"{dl_folder}/{f}" + +def download_parts_and_combine(dl_folder, urls, filename): + parts = [] + for url_record in urls: + url, part_file = get_downloaded_file(dl_folder, url_record) + if os.path.exists(part_file): + print(f'{part_file} has already been downloaded so skip') + else: + part_file = wget.download(url, part_file, bar=bar_custom) + parts.append(part_file) + + def get_combine_cmd(parts): + #default as tar.gz.?? 
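+        # split archives (e.g. the UNv1.0 *.tar.gz.00/.01 pieces used by some WMT datasets) are simply
+        # byte-concatenated, in the order their URLs are listed, into one file that is extracted later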
+ return f'cat {" ".join(parts)} > {filename}' + + combine_cmd = get_combine_cmd(parts) + call(combine_cmd, debug=True) + return filename + +def download_a_url(dl_folder, url): + url, filename = get_downloaded_file(dl_folder, url) + if os.path.exists(filename): + print(f'{filename} has already been downloaded so skip') + return filename + + print(f'downloading {url} to {filename}') + if isinstance(url, list) or isinstance(url, tuple): + download_parts_and_combine(dl_folder, url, filename) + else: + wget.download(url, filename, bar=bar_custom) + print(f'dowloaded: {filename}') + return filename + +def download_files(dl_folder, urls, completed_urls={}): + for url_record in urls: + url, _ = get_downloaded_file(dl_folder, url_record) + filename = download_a_url(dl_folder, url_record) + completed_urls[str(url)] = filename + return completed_urls + +def check_need_manual_downalod(dl_folder, to_manually_download_urls): + to_be_manually_dowloaded = [] + manually_completed_urls = {} + for url_record, instruction in to_manually_download_urls: + url, filename = get_downloaded_file(dl_folder, url_record) + if not os.path.exists(filename): + print(f'{url} need to be download manually, please download it manually following {instruction}; and copy it to {filename}') + to_be_manually_dowloaded.append((url, filename)) + else: + manually_completed_urls[url] = filename + # if len(to_be_manually_dowloaded) > 0: + # raise ValueError('Missing files that need to be downloaded manually; stop the process now.') + return to_be_manually_dowloaded + +def download_dataset(to_folder, dl_dataset, completed_urls={}): + download_files(to_folder, dl_dataset.train_urls, completed_urls) + download_files(to_folder, dl_dataset.valid_urls, completed_urls) + download_files(to_folder, dl_dataset.test_urls, completed_urls) + print('completed downloading') + return completed_urls + +def call(cmd, debug=False): + if debug: + print(cmd) + check_call(cmd, shell=True) + + +def get_extract_name(file_path): + path = os.path.split(file_path) + return path[-1] + '_extract' #.split('.')[0] + +def extract_file(downloaded_file, extract_folder, get_extract_name=get_extract_name, debug=False): + extract_name = get_extract_name(downloaded_file) + extract_to = f'{extract_folder}/{extract_name}' + os.makedirs(extract_to, exist_ok=True) + if os.path.exists(f'{extract_to}/DONE'): + print(f'{downloaded_file} has already been extracted to {extract_to} so skip') + return extract_to + def get_extract_cmd(filename): + if filename.endswith('.tgz') or filename.endswith('tar.gz'): + return f'tar xzfv {filename} -C {extract_to}' + elif filename.endswith('.gz.tar'): + return f'tar xfv {filename} -C {extract_to}; (cd {extract_to}; gzip -d *.gz; [ $? 
-eq 0 ] || gzip -d */*.gz)' + elif filename.endswith('.tar'): + return f'tar xfv {filename} -C {extract_to}' + elif filename.endswith('.gz'): + return f'cp {filename} {extract_to}; (cd {extract_to}; gzip -d *.gz)' + elif filename.endswith('.zip'): + return f'unzip {filename} -d {extract_to}' + extract_cmd = get_extract_cmd(downloaded_file) + print(f'extracting {downloaded_file}') + if isinstance(extract_cmd, list): + for c in extract_cmd: + call(c, debug=debug) + else: + call(extract_cmd, debug=debug) + call(f'echo DONE > {extract_to}/DONE') + return extract_to + + +def extract_all_files( + completed_urls, extract_folder, + get_extract_name=get_extract_name, + completed_extraction={}, + debug=False): + extracted_folders = OrderedDict() + for url, downloaded_file in set(completed_urls.items()): + if downloaded_file in completed_extraction: + print(f'{downloaded_file} is already extracted; so skip') + continue + folder = extract_file(downloaded_file, extract_folder, get_extract_name, debug) + extracted_folders[url] = folder + return extracted_folders + + +def my_glob(folder): + for p in [f'{folder}/*', f'{folder}/*/*', f'{folder}/*/*/*']: + for f in glob.glob(p): + yield f + + +def sgm2raw(sgm, debug): + to_file = sgm[0:len(sgm) - len('.sgm')] + if os.path.exists(to_file): + debug and print(f'{sgm} already converted to {to_file}; so skip') + return to_file + cmd = f'{SGM_TOOL} < {sgm} > {to_file}' + call(cmd, debug) + return to_file + +def tmx2raw(tmx, debug): + to_file = tmx[0:len(tmx) - len('.tmx')] + to_folder = os.path.join(*os.path.split(tmx)[:-1]) + if os.path.exists(f'{to_folder}/bitext.en'): + debug and print(f'{tmx} already extracted to {to_file}; so skip') + return to_file + cmd = f'(cd {to_folder}; {TMX_TOOL} {tmx})' + call(cmd, debug) + return to_file + +CZENG16_REGEX = re.compile(r'.*?data.plaintext-format/0[0-9]train$') +WMT19_WIKITITLES_REGEX = re.compile(r'.*?wikititles-v1.(\w\w)-en.tsv.gz') +TSV_REGEX = re.compile(r'.*?(\w\w)-(\w\w).tsv$') + + + +def cut_wikitles(wiki_file, debug): + # different languages have different file names: + if wiki_file.endswith('wiki/fi-en/titles.fi-en'): + to_file1 = f'{wiki_file}.fi' + to_file2 = f'{wiki_file}.en' + BACKSLASH = '\\' + cmd1 = f"cat {wiki_file} | sed 's/|||/{BACKSLASH}t/g' |cut -f1 |awk '{{$1=$1}};1' > {to_file1}" + cmd2 = f"cat {wiki_file} | sed 's/|||/{BACKSLASH}t/g' |cut -f2 |awk '{{$1=$1}};1' > {to_file2}" +# elif WMT19_WIKITITLES_REGEX.match(wiki_file): +# src = WMT19_WIKITITLES_REGEX.match(wiki_file).groups()[0] +# to_file1 = f'{wiki_file}.{src}' +# to_file2 = f'{wiki_file}.en' +# cmd1 = f"cat {wiki_file} | cut -f1 |awk '{{$1=$1}};1' > {to_file1}" +# cmd2 = f"cat {wiki_file} | cut -f2 |awk '{{$1=$1}};1' > {to_file2}" + else: + return None + if os.path.exists(to_file1) and os.path.exists(to_file2): + debug and print(f'{wiki_file} already processed to {to_file1} and {to_file2}; so skip') + return wiki_file + + call(cmd1, debug=debug) + call(cmd2, debug=debug) + return wiki_file + +def cut_tsv(file, debug): + m = TSV_REGEX.match(file) + if m is None: + raise ValueError(f'{file} is not matching tsv pattern') + src = m.groups()[0] + tgt = m.groups()[1] + + to_file1 = f'{file}.{src}' + to_file2 = f'{file}.{tgt}' + cmd1 = f"cat {file} | cut -f1 |awk '{{$1=$1}};1' > {to_file1}" + cmd2 = f"cat {file} | cut -f2 |awk '{{$1=$1}};1' > {to_file2}" + if os.path.exists(to_file1) and os.path.exists(to_file2): + debug and print(f'{file} already processed to {to_file1} and {to_file2}; so skip') + return file + + call(cmd1, debug=debug) + 
call(cmd2, debug=debug) + return file + + +def convert_file_if_needed(file, debug): + if file.endswith('.sgm'): + return sgm2raw(file, debug) + elif file.endswith('.tmx'): + return tmx2raw(file, debug) + elif file.endswith('wiki/fi-en/titles.fi-en'): + return cut_wikitles(file, debug) +# elif WMT19_WIKITITLES_REGEX.match(file): +# return cut_wikitles(file, debug) + elif file.endswith('.tsv'): + return cut_tsv(file, debug) + elif CZENG16_REGEX.match(file): + return convert2czeng17(file, debug) + else: + return file + + +def convert_files_if_needed(extracted_foldrs, my_glob=my_glob, debug=False): + return { + url: list(sorted(set(convert_file_if_needed(f, debug)) for f in sorted(set(my_glob(folder))))) + for url, folder in extracted_foldrs.items() + } + +def match_patt(file_path, file_pattern, src, tgt, lang): + return file_pattern.format(src=src, tgt=tgt, lang=lang) in file_path + +def match_patts(file_path, file_patterns, src, tgt, lang): + for file_pattern in file_patterns: + params = { k: v for k, v in [('src', src), ('tgt', tgt), ('lang', lang)] if k in file_pattern} + matching = file_pattern.format(**params) + + if isinstance(file_pattern, tuple): + pattern, directions = file_pattern + if f'{src}-{tgt}' in directions and matching in file_path: + return True + else: + if matching in file_path: + return True + return False + +def extracted_glob(extracted_folder, file_patterns, src, tgt, lang): + def get_matching_pattern(file_pattern): + params = { + k: v + for k, v in [('src', src), ('tgt', tgt), ('lang', lang)] + if '{' + k + '}' in file_pattern + } + file_pattern = re.sub(r'{src:(.*?)}', r'\1' if lang == src else '', file_pattern) + file_pattern = re.sub(r'{tgt:(.*?)}', r'\1' if lang == tgt else '', file_pattern) + file_pattern = file_pattern.format(**params) + return file_pattern + for file_pattern in file_patterns: + if isinstance(file_pattern, tuple): + file_pattern, lang_pairs = file_pattern + if f'{src}-{tgt}' not in lang_pairs: + continue +# print('working on pattern: ', file_pattern, lang_pairs ) + matching_pattern = get_matching_pattern(file_pattern) + if matching_pattern is None: + continue + glob_patterns = f'{extracted_folder}/{matching_pattern}' +# print('glob_patterns: ', glob_patterns) + for f in glob.glob(glob_patterns): + yield f + +# for debug usage +def all_extracted_files(split, src, tgt, extracted_folders, split_urls): + def get_url(url): + if isinstance(url, tuple): + url, downloaded_file = url + return url + return [ + f + for url in split_urls + for f in my_glob(extracted_folders[str(get_url(url))]) + ] + +def concat_files(split, src, tgt, extracted_folders, split_urls, path_patterns, to_folder, debug=False): +# if debug: +# print('extracted files to be filtered by patterns: ', +# '\n\t'.join(sorted(all_extracted_files(split, src, tgt, extracted_folders, split_urls)))) + for lang in [src, tgt]: + to_file = f'{to_folder}/{split}.{src}-{tgt}.{lang}' + s_src, s_tgt, s_lang = src.split('_')[0], tgt.split('_')[0], lang.split('_')[0] + files = [] + for url in split_urls: + if isinstance(url, tuple): + url, downloaded_file = url + if str(url) not in extracted_folders: + print(f'warning: {url} not in extracted files') + for extracted_file in set( + extracted_glob( + extracted_folders[str(url)], path_patterns, + s_src, s_tgt, s_lang)): + files.append(extracted_file) + if len(files) == 0: + print('warning: ', f'No files found for split {to_file}') + continue + files = sorted(set(files)) + print(f'concating {len(files)} files into {to_file}') + cmd = ['cat'] + [f'"{f}"' for 
f in files] + [f'>{to_file}'] + cmd = " ".join(cmd) + call(cmd, debug=debug) + +UTILS = os.path.join(pathlib.Path(__file__).parent, 'utils') +LID_MODEL = f'{download_to}/lid.176.bin' +LID_MULTI = f'{UTILS}/fasttext_multi_filter.py' + +def lid_filter(split, src, tgt, from_folder, to_folder, debug=False): + if not os.path.exists(LID_MODEL): + call(f'wget -nc https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -O {LID_MODEL}') + from_prefix = f'{from_folder}/{split}.{src}-{tgt}' + to_prefix = f'{to_folder}/{split}.{src}-{tgt}' + if os.path.exists(f'{from_prefix}.{src}') and os.path.exists(f'{from_prefix}.{tgt}'): + s_src, s_tgt = src.split('_')[0], tgt.split('_')[0] + cmd = ( + f'python {LID_MULTI} --model {LID_MODEL} --inputs {from_prefix}.{src} {from_prefix}.{tgt} ' + f'--langs {s_src} {s_tgt} --outputs {to_prefix}.{src} {to_prefix}.{tgt}' + ) + print(f'filtering {from_prefix}') + call(cmd, debug=debug) + +def concat_into_splits(dl_dataset, src, tgt, extracted_folders, to_folder, debug): + to_folder_tmp = f"{to_folder}_tmp" + os.makedirs(to_folder_tmp, exist_ok=True) + concat_files('train', src, tgt, + extracted_folders, + split_urls=dl_dataset.train_urls, + path_patterns=dl_dataset.train_files_patterns, + to_folder=to_folder_tmp, debug=debug) + lid_filter('train', src, tgt, to_folder_tmp, to_folder, debug) + + concat_files('valid', src, tgt, + extracted_folders, + split_urls=dl_dataset.valid_urls, + path_patterns=dl_dataset.valid_files_patterns, + to_folder=to_folder, debug=debug) + concat_files('test', src, tgt, + extracted_folders, + split_urls=dl_dataset.test_urls, + path_patterns=dl_dataset.test_files_patterns, + to_folder=to_folder, debug=debug) + + +def download_multi(dl_folder, extract_folder, urls, num_processes=8, debug=False): + pool = mp.Pool(processes=num_processes) + download_f = partial(download_a_url, dl_folder) + downloaded_files = pool.imap_unordered(download_f, urls) + pool.close() + pool.join() + +BLEU_REGEX = re.compile("^BLEU\\S* = (\\S+) ") +def run_eval_bleu(cmd): + output = check_output(cmd, shell=True, stderr=subprocess.STDOUT).decode("utf-8").strip() + print(output) + bleu = -1.0 + for line in output.strip().split('\n'): + m = BLEU_REGEX.search(line) + if m is not None: + bleu = m.groups()[0] + bleu = float(bleu) + break + return bleu + +def check_wmt_test_bleu(raw_folder, wmt_lang_pairs): + not_matchings = [] + for wmt, src_tgts in wmt_lang_pairs: + for src_tgt in src_tgts: + print(f'checking test bleus for: {src_tgt} at {wmt}') + src, tgt = src_tgt.split('-') + ssrc, stgt = src[:2], tgt[:2] + if os.path.exists(f'{raw_folder}/test.{tgt}-{src}.{src}'): + # reversed direction may have different test set + test_src = f'{raw_folder}/test.{tgt}-{src}.{src}' + else: + test_src = f'{raw_folder}/test.{src}-{tgt}.{src}' + cmd1 = f'cat {test_src} | sacrebleu -t "{wmt}" -l {stgt}-{ssrc}; [ $? -eq 0 ] || echo ""' + test_tgt = f'{raw_folder}/test.{src}-{tgt}.{tgt}' + cmd2 = f'cat {test_tgt} | sacrebleu -t "{wmt}" -l {ssrc}-{stgt}; [ $? 
-eq 0 ] || echo ""' + bleu1 = run_eval_bleu(cmd1) + if bleu1 != 100.0: + not_matchings.append(f'{wmt}:{src_tgt} source side not matching: {test_src}') + bleu2 = run_eval_bleu(cmd2) + if bleu2 != 100.0: + not_matchings.append(f'{wmt}:{src_tgt} target side not matching: {test_tgt}') + return not_matchings + +def download_and_extract( + to_folder, lang_pairs, dl_dataset, + to_manually_download_urls, + completed_urls={}, completed_extraction={}, + debug=False): + + dl_folder = f'{to_folder}/downloads' + extract_folder = f'{to_folder}/extracted' + raw_folder = f'{to_folder}/raw' + lid_filtered = f'{to_folder}/lid_filtered' + + os.makedirs(extract_folder, exist_ok=True) + os.makedirs(raw_folder, exist_ok=True) + os.makedirs(lid_filtered, exist_ok=True) + + + to_be_manually_dowloaded = check_need_manual_downalod(dl_folder, to_manually_download_urls) + + completed_urls = download_dataset( + dl_folder, dl_dataset, completed_urls) + if debug: + print('completed urls: ', completed_urls) + + + extracted_folders = extract_all_files( + completed_urls, + extract_folder=extract_folder, + completed_extraction=completed_extraction, + debug=debug) + if debug: + print('download files have been extracted to folders: ', extracted_folders) + + converted_files = convert_files_if_needed(extracted_folders, debug=False) + for src_tgt in lang_pairs: + print(f'working on {dl_dataset.name}: {src_tgt}') + src, tgt = src_tgt.split('-') + concat_into_splits(dl_dataset, + src=src, tgt=tgt, + extracted_folders=extracted_folders, + to_folder=raw_folder, debug=debug) + print('completed data into: ', raw_folder) + +def download_czang16(download_to, username=None): + wgets = [ + f'wget --user={username} --password=czeng -P {download_to} http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar' + for i in range(10)] + cmds = [] + for i, cmd in enumerate(wgets): + filename = f'{download_to}/data-plaintext-format.{i}.tar' + if os.path.exists(filename): + print(f'{filename} has already been downloaded; so skip') + continue + cmds.append(cmd) + if cmds and username is None: + raise ValueError('No czeng username is given; please register at http://ufal.mff.cuni.cz/czeng/czeng16 to obtain username to download') + for cmd in cmds: + call(cmd) + print('done with downloading czeng1.6') + +def download_czeng17_script(download_to, extract_folder, debug=False): + url = 'http://ufal.mff.cuni.cz/czeng/download.php?f=convert_czeng16_to_17.pl.zip' + filename = f'{download_to}/convert_czeng16_to_17.pl.zip' + extract_to = f'{extract_folder}/{get_extract_name(filename)}' + script_path = f'{extract_to}/convert_czeng16_to_17.pl' + + if not os.path.exists(script_path): + wget.download(url, filename, bar=bar_custom) + extract_to = extract_file(f'{download_to}/convert_czeng16_to_17.pl.zip', extract_folder, get_extract_name=get_extract_name, debug=debug) + return script_path + +czeng17_script_path = "" +def convert2czeng17(file, debug): + en_file = f'{file}.en' + cs_file = f'{file}.cs' + + if not os.path.exists(en_file) or not os.path.exists(cs_file): + cs_cmd = f'cat {file} | perl {czeng17_script_path} | cut -f3 > {cs_file}' + en_cmd = f'cat {file} | perl {czeng17_script_path} | cut -f4 > {en_file}' + call(cs_cmd, debug) + call(en_cmd, debug) + else: + print(f'already extracted: {en_file} and {cs_file}') + return file + +def extract_czeng17(extract_folder, debug=False): + url = 'http://ufal.mff.cuni.cz/czeng/download.php?f=convert_czeng16_to_17.pl.zip' + filename = f'{download_to}/convert_czeng16_to_17.pl.zip' + extract_to = 
f'{extract_folder}/{get_extract_name(filename)}' + script_path = f'{extract_to}/convert_czeng16_to_17.pl' + + if not os.path.exists(script_path): + wget.download(url, filename, bar=bar_custom) + extract_to = extract_file(f'{download_to}/convert_czeng16_to_17.pl.zip', extract_folder, get_extract_name=get_extract_name, debug=debug) + return script_path + +######### +# definitions of wmt data sources +# for es-en +# Punctuation in the official test sets will be encoded with ASCII characters (not complex Unicode characters) as much as possible. You may want to normalize your system's output before submission. You are able able to use a rawer version of the test sets that does not have this normalization. +# script to normalize punctuation: http://www.statmt.org/wmt11/normalize-punctuation.perl +wmt13_es_en = DLDataset( + name='wmt13_es-en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://www.statmt.org/wmt13/training-parallel-un.tgz', + 'http://www.statmt.org/wmt13/training-parallel-nc-v8.tgz', + ], + valid_urls=[ + ('http://www.statmt.org/wmt13/dev.tgz', 'wmt13_dev.tgz') + ], + test_urls=[ + ('http://www.statmt.org/wmt13/test.tgz', 'wmt13_test.tgz') + ], + train_files_patterns=[ + ('*/europarl-v7.{src}-{tgt}.{lang}', ['es-en']), + ('*commoncrawl.{src}-{tgt}.{lang}', ['es-en']), + ('*/news-commentary-v8.{src}-{tgt}.{lang}', ['es-en']), + ('un/*undoc.2000.{src}-{tgt}.{lang}', ['es-en']), + ] , + valid_files_patterns=[ + ('dev/newstest2012.{lang}', ['es-en']) + ], + test_files_patterns=[ + ('test/newstest*.{lang}', ['es-en']) + ], +) + +wmt14_de_fr_en = DLDataset( + name='wmt14_de_fr_en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://www.statmt.org/wmt13/training-parallel-un.tgz', + 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz', + ('http://www.statmt.org/wmt10/training-giga-fren.tar', 'training-giga-fren.gz.tar'), #it is actuall a gz.tar + ], + valid_urls=[ + ('http://www.statmt.org/wmt14/dev.tgz', 'wmt14_dev.tgz'), + ], + test_urls=[ + ('http://www.statmt.org/wmt14/test-full.tgz', 'wmt14_test_full.tgz'), # cleaned test sets + ], + train_files_patterns=[ + ('*/europarl-v7.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('*commoncrawl.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('*/*news-commentary-v9.{src}-{tgt}.{lang}', ['fr-en', 'de-en']), + ('un/undoc.2000.{src}-{tgt}.{lang}', ['fr-en']), + ('*giga-{src}{tgt}*{lang}', ['fr-en']) + ], + valid_files_patterns=[ + ('dev/newstest2013.{lang}', ['fr-en', 'de-en']) + ], + test_files_patterns=[ + ('test-full/newstest*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['en-de', 'de-en', 'fr-en', 'en-fr']), + ], +) + +# pip install git+https://github.com/amake/tmx2corpus.git +wmt16_ro_en = DLDataset( + name='wmt16_ro-en', + train_urls=[ + ('http://data.statmt.org/wmt16/translation-task/training-parallel-ep-v8.tgz', 'wmt16_training-parallel-ep-v8.tgz'), + ('http://opus.nlpl.eu/download.php?f=SETIMES/v2/tmx/en-ro.tmx.gz', 'en-ro.tmx.gz'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt16/translation-task/dev-romanian-updated.tgz', 'wmt16_dev.tgz') + ], + test_urls=[ + ('http://data.statmt.org/wmt16/translation-task/test.tgz', 'wmt16_test.tgz') + ], + train_files_patterns=[ + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['ro-en']), + ('bitext.{lang}', ['ro-en']) #setimes from tmux + ] , + valid_files_patterns=[ + 
('dev/newsdev2016*{src}{tgt}*.{lang}', ['ro-en', 'ro-en']) + ], + test_files_patterns=[ + ('test/newstest*{src}{tgt}*.{lang}', ['ro-en', 'en-ro']) + ], +) + +cwmt_wmt_instruction = 'cwmt download instruction at: http://nlp.nju.edu.cn/cwmt-wmt' +wmt17_fi_lv_tr_zh_en_manual_downloads = [ + # fake urls to have unique keys for the data + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASIA2015.zip', 'CASIA2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2011.zip', 'CASICT2011.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2015.zip', 'CASICT2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2015.zip', 'Datum2015.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2017.zip', 'Datum2017.zip'), cwmt_wmt_instruction), + ( ('http://nlp.nju.edu.cn/cwmt-wmt/NEU2017.zip', 'NEU2017.zip'), cwmt_wmt_instruction), +] +wmt17_fi_lv_tr_zh_en = DLDataset( + name='wmt17_fi_lv_tr_zh_en', + train_urls=[ + ('http://data.statmt.org/wmt17/translation-task/training-parallel-ep-v8.tgz', 'wmt17_training-parallel-ep-v8.tgz'), + 'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz', + 'http://www.statmt.org/wmt15/wiki-titles.tgz', + ('http://opus.nlpl.eu/download.php?f=SETIMES/v2/tmx/en-tr.tmx.gz', 'en-tr.tmx.gz'), + ('http://data.statmt.org/wmt17/translation-task/rapid2016.tgz', 'wmt17_rapid2016.tgz'), + 'http://data.statmt.org/wmt17/translation-task/leta.v1.tgz', + 'http://data.statmt.org/wmt17/translation-task/dcep.lv-en.v1.tgz', + 'http://data.statmt.org/wmt17/translation-task/books.lv-en.v1.tgz', + (('https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01',), 'UNv1.0.en-zh.tar.gz'), + #manually download files: + ('http://nlp.nju.edu.cn/cwmt-wmt/CASIA2015.zip', 'CASIA2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2011.zip', 'CASICT2011.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/CASICT2015.zip', 'CASICT2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2015.zip', 'Datum2015.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/Datum2017.zip', 'Datum2017.zip'), + ('http://nlp.nju.edu.cn/cwmt-wmt/NEU2017.zip', 'NEU2017.zip'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt17/translation-task/dev.tgz', 'wmt17_dev.tgz'), + ], + test_urls=[ + #NEW: Improved translations for zh test sets + ('http://data.statmt.org/wmt17/translation-task/test-update-1.tgz', 'wmt17_test_zh_en.tgz'), + ('http://data.statmt.org/wmt17/translation-task/test.tgz', 'wmt17_test_others.tgz') + ], + train_files_patterns=[ + ('casict*/cas*{src:ch}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('casia*/cas*{src:ch}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('dataum*/Book*{src:cn}{tgt:en}.txt', ['zh-en', 'zh-en']), + ('neu*/NEU*{src:cn}{tgt:en}.txt', ['zh-en', 'zh-en'] ), + ('*/*UNv1.0.en-zh.{src:zh}{tgt:en}', ['zh-en']), + ('training/*news-commentary-v12.{src}-{tgt}.{lang}', ['zh-en', ]), + + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['fi-en', 'lv-en']), + ('wiki/fi-en/titles.{src}-{tgt}.{lang}', ['fi-en', ]), + ('rapid2016.{tgt}-{src}.{lang}', ['fi-en', 'lv-en']), + ('*/leta.{lang}', ['lv-en']), + ('*/dcep.{lang}', ['lv-en']), + ('*/farewell.{lang}', ['lv-en']), + ('bitext.{lang}', ['tr-en']), + ] , + valid_files_patterns=[ + ('dev/newsdev2017*{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'lv-en', 'tr-en', 'zh-en', + 'en-fi', 'en-lv', 'en-tr', 'en-zh' + ]), + ('dev/newstest2016*{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'tr-en', + 
'en-fi', 'en-tr', + ]), + ], + test_files_patterns=[ + ('test/newstest2017-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'fi-en', 'lv-en', 'tr-en', + 'en-fi', 'en-lv', 'en-tr', + ]), + ('newstest2017-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + [ + 'zh-en', + 'en-zh' + ]), + ], +) + +czeng_instruction = 'download instruction at: http://ufal.mff.cuni.cz/czeng/czeng16' +#alternative: use the prepared data but detokenize it? +wmt18_cs_et_en_manual_downloads = [ +#for cs, need to register and download; Register and download CzEng 1.6. +#Better results can be obtained by using a subset of sentences, released under a new version name CzEng 1.7. + # ((f'http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar', + # f'data-plaintext-format.{i}.tar'), czeng_instruction) + # for i in range(10) +] + +wmt18_cs_et_en = DLDataset( + name='wmt18_cs_et_en', + train_urls=[ + 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz', + 'http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-cs.zipporah0-dedup-clean.tgz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-et.zipporah0-dedup-clean.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz', + ('http://data.statmt.org/wmt18/translation-task/rapid2016.tgz', 'wmt18_rapid2016.tgz'), + # (tuple( + # (f'http://ufallab.ms.mff.cuni.cz/~bojar/czeng16-data/data-plaintext-format.{i}.tar', + # f'data-plaintext-format.{i}.tar') + # for i in range(10) + # ), + # 'czeng16_data_plaintext.gz.tar'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt18/translation-task/dev.tgz', 'wmt18_dev.tgz'), + ], + test_urls=[ + ('http://data.statmt.org/wmt18/translation-task/test.tgz', 'wmt18_test.tgz'), + ], + train_files_patterns=[ + # ('*/*europarl-v7.{src}-{tgt}.{lang}', ['cs-en']), + ('*/*europarl-v8.{src}-{tgt}.{lang}', ['et-en']), + # ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['cs-en', 'et-en']), + ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['et-en']), + # ('*commoncrawl.{src}-{tgt}.{lang}', ['cs-en']), + # ('*/news-commentary-v13.{src}-{tgt}.{lang}', ['cs-en']), + # ('data.plaintext-format/*train.{lang}', ['cs-en']), + ('rapid2016.{tgt}-{src}.{lang}', ['et-en']), + ] , + valid_files_patterns=[ + ('dev/newsdev2018*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['et-en']), + # ('dev/newstest2017*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['cs-en']) + ], + test_files_patterns=[ + ('test/newstest2018-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + # ['cs-en', 'et-en']), + ['et-en']), + ] +) + +ru_en_yandex_instruction = 'Yandex Corpus download instruction at: https://translate.yandex.ru/corpus?lang=en' +wmt19_ru_gu_kk_lt_manual_downloads = [ + (('https://translate.yandex.ru/corpus?lang=en', 'wmt19_1mcorpus.zip'), ru_en_yandex_instruction) +] +wmt19_ru_gu_kk_lt = DLDataset( + name='wmt19_ru_gu_kk_lt', + train_urls=[ + 'http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release3/en-lt.bicleaner07.tmx.gz', + 'https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz', + 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz', + 'http://data.statmt.org/news-commentary/v14/training/news-commentary-v14-wmt19.en-kk.tsv.gz', + 
'http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-ru.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz', + 'http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz', + (('https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01', + 'https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02',), + 'wmt19_UNv1.0.en-ru.tar.gz'), + 'https://tilde-model.s3-eu-west-1.amazonaws.com/rapid2016.en-lt.tmx.zip', + ('https://translate.yandex.ru/corpus?lang=en', 'wmt19_1mcorpus.zip'), + ], + valid_urls=[ + ('http://data.statmt.org/wmt19/translation-task/dev.tgz', 'wmt19_dev.tgz'), + ], + test_urls=[ + ('http://data.statmt.org/wmt19/translation-task/test.tgz', 'wmt19_test.tgz'), + ], + train_files_patterns=[ + ('*europarl-v9.{src}-{tgt}.tsv.{lang}', ['lt-en']), + #paracrawl + ('*paracrawl-release1.{tgt}-{src}.zipporah0-dedup-clean.{lang}', ['ru-en']), + ('bitext.{lang}', ['lt-en',]), + ('*commoncrawl.{src}-{tgt}.{lang}', ['ru-en',]), + ('*news-commentary-v14-wmt19.{tgt}-{src}.tsv.{lang}', ['kk-en', ]), + ('*news-commentary-v14.{tgt}-{src}.tsv.{lang}', ['ru-en']), + #yandex + ('corpus.{tgt}_{src}.1m.{lang}', ['ru-en']), + ('wikititles_v1_wikititles-v1.{src}-{tgt}.tsv.{lang}', ['ru-en', 'kk-en', 'lt-en', 'gu-en']), + ('*/UNv1.0.{tgt}-{src}.{lang}', ['ru-en']), + #rapid + ('bitext.{lang}', ['lt-en']) + ], + valid_files_patterns=[ + ('dev/newsdev2019*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['gu-en', 'kk-en', 'lt-en']), + ('dev/newstest2018*{src}{tgt}-{src:src}{tgt:ref}.{lang}', ['ru-en']), + ], + test_files_patterns=[ + ('sgm/newstest2019-{src}{tgt}-{src:src}{tgt:ref}.{lang}', + ['ru-en', 'gu-en', 'kk-en', 'lt-en', 'en-ru', 'en-gu', 'en-kk', 'en-lt']), + ] +) + + +######### + +if __name__ == "__main__": + # speed up the downloads with multiple processing + dl_folder = f'{to_data_path}/downloads' + extract_folder = f'{to_data_path}/extracted' + + urls = [ + url + for dataset in [wmt13_es_en, wmt14_de_fr_en, wmt16_ro_en, wmt18_cs_et_en, wmt19_ru_gu_kk_lt] + for urls in [dataset.train_urls, dataset.valid_urls, dataset.test_urls] + for url in urls + ] + urls = set(urls) + download_multi(dl_folder, extract_folder, urls, num_processes=8, debug=True) + + # check manually downlaods + to_manually_download_urls = ( + wmt17_fi_lv_tr_zh_en_manual_downloads + wmt18_cs_et_en_manual_downloads + wmt19_ru_gu_kk_lt_manual_downloads + ) + to_be_manually_dowloaded = check_need_manual_downalod(dl_folder, to_manually_download_urls) + if len(to_be_manually_dowloaded) > 0: + print('Missing files that need to be downloaded manually; stop the process now.') + exit(-1) + + completed_urls = {} + completed_extraction = {} + def work_on_wmt(directions, wmt_data): + download_and_extract( + to_data_path, + directions, + wmt_data, + to_manually_download_urls=to_manually_download_urls, + completed_urls=completed_urls, completed_extraction=completed_extraction, debug=True) + + work_on_wmt( + ['es_XX-en_XX'], + wmt13_es_en,) + work_on_wmt( + [ + 'fr_XX-en_XX', 'en_XX-fr_XX', + # 'en_XX-de_DE', 'de_DE-en_XX', + ], + wmt14_de_fr_en,) + work_on_wmt( + ['ro_RO-en_XX', 'en_XX-ro_XX'], + wmt16_ro_en,) + work_on_wmt( + [ + # 'zh_CN-en_XX', + 'lv_LV-en_XX', 
'fi_FI-en_XX', 'tr_TR-en_XX', + #in case the reversed directions have different train/valid/test data + # 'en_XX-zh_CN', + 'en_XX-lv_LV', 'en_XX-fi_FI', 'en_XX-tr_TR', + ], + wmt17_fi_lv_tr_zh_en, ) + # czeng17_script_path = download_czeng17_script(download_to, extract_to, debug=False) + # cz_username = None + work_on_wmt( + [ + # 'cs_CZ-en_XX', + 'et_EE-en_XX'], + wmt18_cs_et_en,) + work_on_wmt( + [ + # 'ru_RU-en_XX', 'en_XX-ru_RU', + 'gu_IN-en_XX', 'kk_KZ-en_XX', 'lt_LT-en_XX', + #in case the reversed directions have different train/valid/test data + 'en_XX-gu_IN', 'en_XX-kk_KZ', 'en_XX-lt_LT' + ], + wmt19_ru_gu_kk_lt,) + + not_matching = check_wmt_test_bleu( + f'{to_data_path}/raw', + [ + ('wmt13', ['es_XX-en_XX']), + ('wmt14/full', ['fr_XX-en_XX',]), + ('wmt16', ['ro_RO-en_XX',]), + # ('wmt17/improved', ['zh_CN-en_XX']), + ('wmt17', [ 'lv_LV-en_XX', 'fi_FI-en_XX', 'tr_TR-en_XX']), + ('wmt18', ['cs_CZ-en_XX', 'et_EE-en_XX']), + ('wmt19', ['gu_IN-en_XX', 'kk_KZ-en_XX', 'lt_LT-en_XX']), + #'ru_RU-en_XX', + ] + ) + if len(not_matching) > 0: + print('the following datasets do not have matching test datasets:\n\t', '\n\t'.join(not_matching)) + diff --git a/fairseq/examples/multilingual/data_scripts/download_wmt20.sh b/fairseq/examples/multilingual/data_scripts/download_wmt20.sh new file mode 100644 index 0000000000000000000000000000000000000000..31cd5c76b75081331ae03c5ea70ea7ddebaa06e1 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/download_wmt20.sh @@ -0,0 +1,547 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + + + +set -x -e + +# TODO update the workdir and dest dir name +# put fasttext model +WORKDIR=$WORKDIR_ROOT +# put intermediate files +TMP_DIR=$WORKDIR_ROOT/tmp/tmp_wmt20_lowres_download +# output {train,valid,test} files to dest +DEST=$WORKDIR_ROOT/ML50/raw + +UTILS=$PWD/utils + +# per dataset locations +COMMONCRAWL_DIR=$TMP_DIR/commoncrawl +YANDEX_CORPUS=$WORKDIR_ROOT/wmt20/official/ru/yandex/1mcorpus.zip +# unzipped +CZENG_CORPUS=$WORKDIR_ROOT/wmt20/official/cs/czeng/czeng20-train +CCMT_DIR=$WORKDIR_ROOT/wmt20/official/zh/ccmt/parallel + +download_and_select() { + SUBFOLDER=$1 + URL=$2 + UNCOMPRESS_CMD=$3 + LANG=$4 + INPUT_FILEPATH=$5 + if [[ $# -gt 5 ]]; then + LANG_COL=$6 + EN_COL=$7 + fi + + mkdir -p $SUBFOLDER + cd $SUBFOLDER + wget -nc --content-disposition $URL + $UNCOMPRESS_CMD + + if [[ $# -gt 5 ]]; then + cut -f$LANG_COL $INPUT_FILEPATH > $INPUT_FILEPATH.$LANG + cut -f$EN_COL $INPUT_FILEPATH > $INPUT_FILEPATH.en + fi + cd .. + + ln -sf $SUBFOLDER/$INPUT_FILEPATH.$LANG $SUBFOLDER.$LANG + ln -sf $SUBFOLDER/$INPUT_FILEPATH.en $SUBFOLDER.en +} + +prepare_lid() { + pip install fasttext + + # TODO specify global workdir + MODEL=$WORKDIR/fasttext/lid.176.bin + LID_MULTI=$UTILS/fasttext_multi_filter.py + + if [ ! -f "$MODEL" ]; then + echo "downloading fasttext lid model..." + mkdir -p $WORKDIR/fasttext + wget -nc https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin -O $MODEL + fi +} + +prepare_moses() { + pushd $UTILS + echo 'Cloning Moses github repository (for tokenization scripts)...' 
+ git clone https://github.com/moses-smt/mosesdecoder.git + popd +} + +lid_filter() { + # TODO specify global workdir + MODEL=$WORKDIR/fasttext/lid.176.bin + LID_MULTI=$UTILS/fasttext_multi_filter.py + + prepare_lid + + SRC=$1 + SRC_FILE=$2 + SRC_OUTPUT=$3 + TGT=$4 + TGT_FILE=$5 + TGT_OUTPUT=$6 + python $LID_MULTI --model $MODEL --inputs $SRC_FILE $TGT_FILE --langs $SRC $TGT --outputs $SRC_OUTPUT $TGT_OUTPUT +} + +prepare_ja_ted() { + mkdir -p ted + cd ted + + wget -nc https://wit3.fbk.eu/archive/2017-01-trnted//texts/en/ja/en-ja.tgz + tar -zxvf en-ja.tgz + cat en-ja/train.tags.en-ja.en | grep -v -P "^[ ]*\<" | sed 's/^[ \t]*//g' | sed 's/[ \t]*$//g' > en-ja/train.en-ja.en + cat en-ja/train.tags.en-ja.ja | grep -v -P "^[ ]*\<" | sed 's/^[ \t]*//g' | sed 's/[ \t]*$//g' > en-ja/train.en-ja.ja + + cd .. + ln -sf ted/en-ja/train.en-ja.ja ted.ja + ln -sf ted/en-ja/train.en-ja.en ted.en +} + +prepare_ja() { + OUTPUT_DIR=$TMP_DIR/ja + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://www.kecl.ntt.co.jp/icl/lirg/jparacrawl/release/2.0/bitext/en-ja.tar.gz" "tar -zxvf en-ja.tar.gz" ja en-ja/en-ja.bicleaner05.txt 4 3 & + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-ja.tsv.gz" "gunzip -f news-commentary-v15.en-ja.tsv.gz" ja news-commentary-v15.en-ja.tsv 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ja-en.tsv.gz" "gunzip -f wikititles-v2.ja-en.tsv.gz" ja wikititles-v2.ja-en.tsv 1 2 & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ja.langid.tsv.gz" "gunzip -f WikiMatrix.v1.en-ja.langid.tsv.gz" ja WikiMatrix.v1.en-ja.langid.tsv 3 2 & + download_and_select subtitle "https://nlp.stanford.edu/projects/jesc/data/split.tar.gz" "tar -zxvf split.tar.gz" ja split/train 2 1 & + download_and_select kftt "http://www.phontron.com/kftt/download/kftt-data-1.0.tar.gz" "tar -zxvf kftt-data-1.0.tar.gz" ja kftt-data-1.0/data/orig/kyoto-train & + + prepare_ja_ted & + + # ted data needs to + + wait + + # remove previous results + rm -f all.?? 
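+ # Concatenate every per-corpus *.ja / *.en pair in the same sorted order so the two
+ # sides stay line-aligned, then keep only sentence pairs whose fastText LID
+ # prediction matches; most of the prepare_* functions below follow this same pattern.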
+ find ./ -maxdepth 1 -name "*.ja" | sort -V | xargs cat > all.ja + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ja all.ja $DEST/train.ja_XX-en_XX.ja_XX en all.en $DEST/train.ja_XX-en_XX.en_XX +} + +prepare_ta() { + OUTPUT_DIR=$TMP_DIR/ta + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ta-en.tsv.gz" "gunzip -f wikititles-v2.ta-en.tsv.gz" ta wikititles-v2.ta-en.tsv 1 2 & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ta.langid.tsv.gz" "gunzip -f WikiMatrix.v1.en-ta.langid.tsv.gz" ta WikiMatrix.v1.en-ta.langid.tsv 3 2 & + download_and_select pmindia "http://data.statmt.org/pmindia/v1/parallel/pmindia.v1.ta-en.tsv" "" ta pmindia.v1.ta-en.tsv 2 1 & + download_and_select tanzil "https://object.pouta.csc.fi/OPUS-Tanzil/v1/moses/en-ta.txt.zip" "unzip en-ta.txt.zip" ta Tanzil.en-ta & + download_and_select pib "http://preon.iiit.ac.in/~jerin/resources/datasets/pib-v0.tar" "tar -xvf pib-v0.tar" ta pib/en-ta/train & + download_and_select mkb "http://preon.iiit.ac.in/~jerin/resources/datasets/mkb-v0.tar" "tar -xvf mkb-v0.tar" ta mkb/en-ta/mkb & + download_and_select ufal "http://ufal.mff.cuni.cz/~ramasamy/parallel/data/v2/en-ta-parallel-v2.tar.gz" "tar -zxvf en-ta-parallel-v2.tar.gz" ta en-ta-parallel-v2/corpus.bcn.train & + + wait + + # need special handling for nlpc + mkdir -p nlpc + cd nlpc + wget -nc https://raw.githubusercontent.com/nlpc-uom/English-Tamil-Parallel-Corpus/master/En-Ta%20Corpus/En-Ta%20English.txt + wget -nc https://github.com/nlpc-uom/English-Tamil-Parallel-Corpus/raw/master/En-Ta%20Corpus/En-Ta%20Tamil.txt + tail -n +4 "En-Ta English.txt" > en-ta.en + tail -n +4 "En-Ta Tamil.txt" > en-ta.ta + cd .. + ln -sf nlpc/en-ta.en nlpc.en + ln -sf nlpc/en-ta.ta nlpc.ta + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.ta" | sort -V | xargs cat > all.ta + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ta all.ta $DEST/train.ta_IN-en_XX.ta_IN en all.en $DEST/train.ta_IN-en_XX.en_XX +} + +prepare_iu() { + OUTPUT_DIR=$TMP_DIR/iu + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select nh "https://nrc-digital-repository.canada.ca/eng/view/dataset/?id=c7e34fa7-7629-43c2-bd6d-19b32bf64f60" "tar -zxvf Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0.1.tgz" iu Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0/NunavutHansard > /dev/null & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.iu-en.tsv.gz" "gunzip -f wikititles-v2.iu-en.tsv.gz" iu wikititles-v2.iu-en.tsv 1 2 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.iu" | sort -V | xargs cat | nh/Nunavut-Hansard-Inuktitut-English-Parallel-Corpus-3.0/scripts/normalize-iu-spelling.pl > all.iu + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + paste all.iu all.en | awk -F $'\t' '$1!=""&&$2!=""' > all.iuen + cut -f1 all.iuen > $DEST/train.iu_CA-en_XX.iu_CA + cut -f2 all.iuen > $DEST/train.iu_CA-en_XX.en_XX +} + +prepare_km() { + OUTPUT_DIR=$TMP_DIR/km + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://data.statmt.org/wmt20/translation-task/ps-km/wmt20-sent.en-km.xz" "unxz wmt20-sent.en-km.zx" km wmt20-sent.en-km 2 1 & + + # km-parallel has multiple sets, concat all of them together + mkdir -p opus + cd opus + wget -nc "http://data.statmt.org/wmt20/translation-task/ps-km/km-parallel.tgz" + tar -zxvf km-parallel.tgz + find ./km-parallel -maxdepth 1 -name "*.km" | sort -V | xargs cat > opus.km + find ./km-parallel -maxdepth 1 -name "*.en" | sort -V | xargs cat > opus.en + cd .. + ln -sf opus/opus.km . + ln -sf opus/opus.en . + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.km" | sort -V | xargs cat > all.km + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter km all.km $DEST/train.km_KH-en_XX.km_KH en all.en $DEST/train.km_KH-en_XX.en_XX +} + +prepare_ps() { + OUTPUT_DIR=$TMP_DIR/ps + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "http://data.statmt.org/wmt20/translation-task/ps-km/wmt20-sent.en-ps.xz" "unxz wmt20-sent.en-ps.xz" ps wmt20-sent.en-ps 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ps-en.tsv.gz" "gunzip -f wikititles-v2.ps-en.tsv.gz" ps wikititles-v2.ps-en.tsv 1 2 & + # ps-parallel has multiple sets, concat all of them together + mkdir -p opus + cd opus + wget -nc "http://data.statmt.org/wmt20/translation-task/ps-km/ps-parallel.tgz" + tar -zxvf ps-parallel.tgz + find ./ps-parallel -maxdepth 1 -name "*.ps" | sort -V | xargs cat > opus.ps + find ./ps-parallel -maxdepth 1 -name "*.en" | sort -V | xargs cat > opus.en + cd .. + ln -sf opus/opus.ps opus.ps + ln -sf opus/opus.en opus.en + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.ps" | sort -V | xargs cat > all.ps + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ps all.ps $DEST/train.ps_AF-en_XX.ps_AF en all.en $DEST/train.ps_AF-en_XX.en_XX +} + +download_commoncrawl() { + mkdir -p $COMMONCRAWL_DIR + cd $COMMONCRAWL_DIR + + wget -nc "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" + tar -zxvf training-parallel-commoncrawl.tgz +} +link_commoncrawl() { + LANG=$1 + ln -sf $COMMONCRAWL_DIR/commoncrawl.$LANG-en.en commoncrawl.en + ln -sf $COMMONCRAWL_DIR/commoncrawl.$LANG-en.$LANG commoncrawl.$LANG +} + +strip_xlf() { + INPUT_FILE=$1 + SRC=$2 + TGT=$3 + grep '<source xml:lang=' $INPUT_FILE | sed 's/^<[^<>]*>//g' | sed 's/<[^<>]*>$//g' > $INPUT_FILE.$SRC + grep '<target xml:lang=' $INPUT_FILE | sed 's/^<[^<>]*>//g' | sed 's/<[^<>]*>$//g' > $INPUT_FILE.$TGT +} + +download_and_process_tilde() { + URL=$1 + UNCOMPRESS_CMD=$2 + FILENAME=$3 + LANG=$4 + PROCESS_CMD=$5 + + mkdir -p tilde + cd tilde + wget -nc $URL + $UNCOMPRESS_CMD + echo "executing cmd" + echo $PROCESS_CMD + $PROCESS_CMD + cd .. 
+ ln -sf tilde/$FILENAME.$LANG tilde.$LANG + ln -sf tilde/$FILENAME.en tilde.en +} + +prepare_cs() { + OUTPUT_DIR=$TMP_DIR/cs + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + #download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.cs-en.tsv.gz" "gunzip europarl-v10.cs-en.tsv.gz" cs europarl-v10.cs-en.tsv 1 2 & + #download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-cs.txt.gz" "gunzip en-cs.txt.gz" cs en-cs.txt 2 1 & + #link_commoncrawl cs + #download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.cs-en.tsv.gz" "gunzip news-commentary-v15.cs-en.tsv.gz" cs news-commentary-v15.cs-en.tsv 1 2 & + #download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.cs-en.tsv.gz" "gunzip wikititles-v2.cs-en.tsv.gz" cs wikititles-v2.cs-en.tsv 1 2 & + #download_and_process_tilde "http://data.statmt.org/wmt20/translation-task/rapid/RAPID_2019.cs-en.xlf.gz" "gunzip RAPID_2019.cs-en.xlf.gz" RAPID_2019.cs-en.xlf cs "strip_xlf RAPID_2019.cs-en.xlf cs en" & + #download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.cs-en.langid.tsv.gz" "gunzip WikiMatrix.v1.cs-en.langid.tsv.gz" cs WikiMatrix.v1.cs-en.langid.tsv 2 3 & + + #wait + + # remove previous results + #rm -f all.?? + #find ./ -maxdepth 1 -name "*.cs" | sort -V | xargs cat > all.cs + #find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + if [ -z $CZENG_CORPUS ] ; + then + echo "Please download CZENG_CORPUS manually and place them at $CZENG_CORPUS. Exitting..." + exit + fi + cat $CZENG_CORPUS | sed '/^$/d' | cut -f5 > all.cs + cat $CZENG_CORPUS | sed '/^$/d' | cut -f6 > all.en + + lid_filter cs all.cs $DEST/train.cs_CZ-en_XX.cs_CZ en all.en $DEST/train.cs_CZ-en_XX.en_XX +} + +prepare_de() { + OUTPUT_DIR=$TMP_DIR/de + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.de-en.tsv.gz" "gunzip europarl-v10.de-en.tsv.gz" de europarl-v10.de-en.tsv 1 2 & + download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-de.txt.gz" "gunzip en-de.txt.gz" de en-de.txt 2 1 & + link_commoncrawl de + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.de-en.tsv.gz" "gunzip news-commentary-v15.de-en.tsv.gz" de news-commentary-v15.de-en.tsv 1 2 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.de-en.tsv.gz" "gunzip wikititles-v2.de-en.tsv.gz" de wikititles-v2.de-en.tsv 1 2 & + download_and_process_tilde "http://data.statmt.org/wmt20/translation-task/rapid/RAPID_2019.de-en.xlf.gz" "gunzip RAPID_2019.de-en.xlf.gz" RAPID_2019.de-en.xlf de "strip_xlf RAPID_2019.de-en.xlf de en" & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.de-en.langid.tsv.gz" "gunzip WikiMatrix.v1.de-en.langid.tsv.gz" de WikiMatrix.v1.de-en.langid.tsv 2 3 & + + wait + + # remove previous results + rm -f all.?? 
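+ # commoncrawl.de / commoncrawl.en are the symlinks created by link_commoncrawl above;
+ # they are merged here together with the downloaded corpora before LID filtering.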
+ find ./ -maxdepth 1 -name "*.de" | sort -V | xargs cat > all.de + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter de all.de $DEST/train.de_DE-en_XX.de_DE en all.en $DEST/train.de_DE-en_XX.en_XX +} + +prepare_tmx() { + TMX_FILE=$1 + git clone https://github.com/amake/TMX2Corpus $UTILS/tmx2corpus + pip install tinysegmenter + + python $UTILS/tmx2corpus/tmx2corpus.py $TMX_FILE +} + +prepare_pl() { + OUTPUT_DIR=$TMP_DIR/pl + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + # download_and_select europarl "http://www.statmt.org/europarl/v10/training/europarl-v10.pl-en.tsv.gz" "gunzip europarl-v10.pl-en.tsv.gz" pl europarl-v10.pl-en.tsv 1 2 & + # download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release5.1/en-pl.txt.gz" "gunzip en-pl.txt.gz" pl en-pl.txt 2 1 & + # download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.pl-en.tsv.gz" "gunzip wikititles-v2.pl-en.tsv.gz" pl wikititles-v2.pl-en.tsv 1 2 & + download_and_select tilde "https://tilde-model.s3-eu-west-1.amazonaws.com/rapid2019.en-pl.tmx.zip" "gunzip rapid2019.en-pl.tmx.zip" bitext pl "prepare_tmx RAPID_2019.UNIQUE.en-pl.tmx" & + # download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-pl.langid.tsv.gz" "gunzip WikiMatrix.v1.en-pl.langid.tsv.gz" pl WikiMatrix.v1.en-pl.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? + find ./ -maxdepth 1 -name "*.pl" | sort -V | xargs cat > all.pl + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter pl all.pl $DEST/train.pl_PL-en_XX.pl_PL en all.en $DEST/train.pl_PL-en_XX.en_XX +} + +prepare_uncorpus() { + $URLS=$1 + $FILES=$2 + + mkdir -p uncorpus + cd uncorpus + + for URL in $URLS; do + wget -nc $URL + done + cat $FILES > uncorpus.tar.gz + tar -zxvf uncorpus.tar.gz + + cd .. + ln -sf uncorpus/en-$LANG/UNv1.0.en-$LANG.$LANG uncorpus.$LANG + ln -sf uncorpus/en-$LANG/UNv1.0.en-$LANG.en uncorpus.en +} + +prepare_yandex() { + mkdir -p yandex + cd yandex + unzip $YANDEX_CORPUS ./ + cd .. + ln -s yandex/corpus.en_ru.1m.en yandex.en + ln -s yandex/corpus.en_ru.1m.ru yandex.ru +} + +prepare_ru() { + OUTPUT_DIR=$TMP_DIR/ru + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select paracrawl "https://s3.amazonaws.com/web-language-models/paracrawl/release1/paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" "tar -zxvf paracrawl-release1.en-ru.zipporah0-dedup-clean.tgz" ru paracrawl-release1.en-ru.zipporah0-dedup-clean & + link_commoncrawl ru + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-ru.tsv.gz" "gunzip news-commentary-v15.en-ru.tsv.gz" ru news-commentary-v15.en-ru.tsv 2 1 & + prepare_yandex & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.ru-en.tsv.gz" "gunzip wikititles-v2.ru-en.tsv.gz" ru wikititles-v2.ru-en.tsv 1 2 & + prepare_uncorpus "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" "UNv1.0.en-ru.tar.gz.00 UNv1.0.en-ru.tar.gz.01 UNv1.0.en-ru.tar.gz.02" & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-ru.langid.tsv.gz" "gunzip WikiMatrix.v1.en-ru.langid.tsv.gz" ru WikiMatrix.v1.en-ru.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.ru" | sort -V | xargs cat > all.ru + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter ru all.ru $DEST/train.ru_RU-en_XX.ru_RU en all.en $DEST/train.ru_RU-en_XX.en_XX +} + +prepare_ccmt() { + mkdir -p ccmt + cd ccmt + # assume ccmt data is already unzipped under CCMT_DIR folder + cat $CCMT_DIR/datum2017/Book*_cn.txt | sed 's/ //g' > datum2017.detok.zh + cat $CCMT_DIR/datum2017/Book*_en.txt > datum2017.detok.en + cat $CCMT_DIR/casict2011/casict-A_ch.txt $CCMT_DIR/casict2011/casict-B_ch.txt $CCMT_DIR/casict2015/casict2015_ch.txt $CCMT_DIR/datum2015/datum_ch.txt $CCMT_DIR/neu2017/NEU_cn.txt datum2017.detok.zh > ccmt.zh + cat $CCMT_DIR/casict2011/casict-A_en.txt $CCMT_DIR/casict2011/casict-B_en.txt $CCMT_DIR/casict2015/casict2015_en.txt $CCMT_DIR/datum2015/datum_en.txt $CCMT_DIR/neu2017/NEU_en.txt datum2017.detok.en > ccmt.en + cd .. + ln -sf ccmt/ccmt.zh ccmt.zh + ln -sf ccmt/ccmt.en ccmt.en +} + +prepare_zh() { + OUTPUT_DIR=$TMP_DIR/zh + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + + download_and_select newscommentary "http://data.statmt.org/news-commentary/v15/training/news-commentary-v15.en-zh.tsv.gz" "gunzip news-commentary-v15.en-zh.tsv.gz" zh news-commentary-v15.en-zh.tsv 2 1 & + download_and_select wikititles "http://data.statmt.org/wikititles/v2/wikititles-v2.zh-en.tsv.gz" "gunzip wikititles-v2.zh-en.tsv.gz" zh wikititles-v2.zh-en.tsv 1 2 & + prepare_uncorpus "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00 https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" "UNv1.0.en-zh.tar.gz.00 UNv1.0.en-zh.tar.gz.01" & + prepare_ccmt & + download_and_select wikimatrix "http://data.statmt.org/wmt20/translation-task/WikiMatrix/WikiMatrix.v1.en-zh.langid.tsv.gz" "gunzip WikiMatrix.v1.en-zh.langid.tsv.gz" zh WikiMatrix.v1.en-zh.langid.tsv 3 2 & + + wait + + # remove previous results + rm -f all.?? 
+ find ./ -maxdepth 1 -name "*.zh" | sort -V | xargs cat > all.zh + find ./ -maxdepth 1 -name "*.en" | sort -V | xargs cat > all.en + lid_filter zh all.zh $DEST/train.zh_CN-en_XX.zh_CN en all.en $DEST/train.zh_CN-en_XX.en_XX +} + +prepare_tests() { + OUTPUT_DIR=$TMP_DIR + mkdir -p $OUTPUT_DIR + cd $OUTPUT_DIR + wget -nc http://data.statmt.org/wmt20/translation-task/dev.tgz + tar -zxvf dev.tgz + cd dev + + cat newsdev2020-jaen-src.ja.sgm | $UTILS/strip_sgm.sh > newsdev2020-jaen.ja + cat newsdev2020-jaen-ref.en.sgm | $UTILS/strip_sgm.sh > newsdev2020-jaen.en + split newsdev2020-jaen.ja -a 0 -n r/1/2 > $DEST/valid.ja_XX-en_XX.ja_XX + split newsdev2020-jaen.en -a 0 -n r/1/2 > $DEST/valid.ja_XX-en_XX.en_XX + split newsdev2020-jaen.ja -a 0 -n r/2/2 > $DEST/test.ja_XX-en_XX.ja_XX + split newsdev2020-jaen.en -a 0 -n r/2/2 > $DEST/test.ja_XX-en_XX.en_XX + + cat newsdev2020-iuen-src.iu.sgm | strip_sgm.sh > newsdev2020-iuen.iu + cat newsdev2020-iuen-ref.en.sgm | strip_sgm.sh > newsdev2020-iuen.en + split newsdev2020-iuen.iu -a 0 -n r/1/2 > $DEST/valid.iu_CA-en_XX.iu_CA + split newsdev2020-iuen.en -a 0 -n r/1/2 > $DEST/valid.iu_CA-en_XX.en_XX + split newsdev2020-iuen.iu -a 0 -n r/2/2 > $DEST/test.iu_CA-en_XX.iu_CA + split newsdev2020-iuen.en -a 0 -n r/2/2 > $DEST/test.iu_CA-en_XX.en_XX + + cat newsdev2020-taen-src.ta.sgm | strip_sgm.sh > newsdev2020-taen.ta + cat newsdev2020-taen-ref.en.sgm | strip_sgm.sh > newsdev2020-taen.en + split newsdev2020-taen.ta -a 0 -n r/1/2 > $DEST/valid.ta_IN-en_XX.ta_IN + split newsdev2020-taen.en -a 0 -n r/1/2 > $DEST/valid.ta_IN-en_XX.en_XX + split newsdev2020-taen.ta -a 0 -n r/2/2 > $DEST/test.ta_IN-en_XX.ta_IN + split newsdev2020-taen.en -a 0 -n r/2/2 > $DEST/test.ta_IN-en_XX.en_XX + + cp wikipedia.dev.km-en.km $DEST/valid.km_KH-en_XX.km_KH + cp wikipedia.dev.km-en.en $DEST/valid.km_KH-en_XX.en_XX + cp wikipedia.devtest.km-en.km $DEST/test.km_KH-en_XX.km_KH + cp wikipedia.devtest.km-en.en $DEST/test.km_KH-en_XX.en_XX + + cp wikipedia.dev.ps-en.ps $DEST/valid.ps_AF-en_XX.ps_AF + cp wikipedia.dev.ps-en.en $DEST/valid.ps_AF-en_XX.en_XX + cp wikipedia.devtest.ps-en.ps $DEST/test.ps_AF-en_XX.ps_AF + cp wikipedia.devtest.ps-en.en $DEST/test.ps_AF-en_XX.en_XX + + cat newsdev2020-plen-src.pl.sgm | strip_sgm.sh > newsdev2020-plen.pl + cat newsdev2020-plen-ref.en.sgm | strip_sgm.sh > newsdev2020-plen.en + split newsdev2020-plen.pl -a 0 -n r/1/2 > $DEST/valid.pl_PL-en_XX.pl_PL + split newsdev2020-plen.en -a 0 -n r/1/2 > $DEST/valid.pl_PL-en_XX.en_XX + split newsdev2020-plen.pl -a 0 -n r/2/2 > $DEST/test.pl_PL-en_XX.pl_PL + split newsdev2020-plen.en -a 0 -n r/2/2 > $DEST/test.pl_PL-en_XX.en_XX + + cat newstest2018-encs-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-cs_CZ.en_XX + cat newstest2018-encs-ref.cs.sgm | strip_sgm.sh > $DEST/valid.en_XX-cs_CZ.cs_CZ + cat newstest2019-encs-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-cs_CZ.en_XX + cat newstest2019-encs-ref.cs.sgm | strip_sgm.sh > $DEST/test.en_XX-cs_CZ.cs_CZ + + cat newstest2018-deen-src.de.sgm | strip_sgm.sh > $DEST/valid.de_DE-en_XX.de_DE + cat newstest2018-deen-ref.en.sgm | strip_sgm.sh > $DEST/valid.de_DE-en_XX.en_XX + cat newstest2018-ende-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-de_DE.en_XX + cat newstest2018-ende-ref.de.sgm | strip_sgm.sh > $DEST/valid.en_XX-de_DE.de_DE + cat newstest2019-deen-src.de.sgm | strip_sgm.sh > $DEST/test.de_DE-en_XX.de_DE + cat newstest2019-deen-ref.en.sgm | strip_sgm.sh > $DEST/test.de_DE-en_XX.en_XX + cat newstest2019-ende-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-de_DE.en_XX + 
cat newstest2019-ende-ref.de.sgm | strip_sgm.sh > $DEST/test.en_XX-de_DE.de_DE + + cat newstest2018-ruen-src.ru.sgm | strip_sgm.sh > $DEST/valid.ru_RU-en_XX.ru_RU + cat newstest2018-ruen-ref.en.sgm | strip_sgm.sh > $DEST/valid.ru_RU-en_XX.en_XX + cat newstest2018-enru-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-ru_RU.en_XX + cat newstest2018-enru-ref.ru.sgm | strip_sgm.sh > $DEST/valid.en_XX-ru_RU.ru_RU + cat newstest2019-ruen-src.ru.sgm | strip_sgm.sh > $DEST/test.ru_RU-en_XX.ru_RU + cat newstest2019-ruen-ref.en.sgm | strip_sgm.sh > $DEST/test.ru_RU-en_XX.en_XX + cat newstest2019-enru-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-ru_RU.en_XX + cat newstest2019-enru-ref.ru.sgm | strip_sgm.sh > $DEST/test.en_XX-ru_RU.ru_RU + + cat newstest2018-zhen-src.zh.sgm | strip_sgm.sh > $DEST/valid.zh_CN-en_XX.zh_CN + cat newstest2018-zhen-ref.en.sgm | strip_sgm.sh > $DEST/valid.zh_CN-en_XX.en_XX + cat newstest2018-enzh-src.en.sgm | strip_sgm.sh > $DEST/valid.en_XX-zh_CN.en_XX + cat newstest2018-enzh-ref.zh.sgm | strip_sgm.sh > $DEST/valid.en_XX-zh_CN.zh_CN + cat newstest2019-zhen-src.zh.sgm | strip_sgm.sh > $DEST/test.zh_CN-en_XX.zh_CN + cat newstest2019-zhen-ref.en.sgm | strip_sgm.sh > $DEST/test.zh_CN-en_XX.en_XX + cat newstest2019-enzh-src.en.sgm | strip_sgm.sh > $DEST/test.en_XX-zh_CN.en_XX + cat newstest2019-enzh-ref.zh.sgm | strip_sgm.sh > $DEST/test.en_XX-zh_CN.zh_CN +} + +mkdir -p $DEST + +prepare_lid +prepare_moses +download_commoncrawl + +prepare_ja & +prepare_ta & +prepare_km & +prepare_ps & +prepare_iu & +prepare_cs & +prepare_de & +prepare_pl & +prepare_ru & +prepare_zh & + +# prepare valid/test set +prepare_tests & + +# wait + +# TODO remove intermediate files +# rm -rf $TMP_DIR diff --git a/fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh b/fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh new file mode 100644 index 0000000000000000000000000000000000000000..4655936149cab212b3cfa14f306d71153729f9d7 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +if [ -z $WORKDIR_ROOT ] ; +then + echo "please specify your working directory root in environment variable WORKDIR_ROOT. Exitting..." + exit +fi + +if [ -z $SPM_PATH ] ; +then + echo "Please install sentence piecence from https://github.com/google/sentencepiece and set SPM_PATH pointing to the installed spm_encode.py. Exitting..." 
+ exit +fi + +ML50=${WORKDIR_ROOT}/ML50 + +mkdir -p $ML50/dedup +mkdir -p $ML50/cleaned_dedup + +python ./dedup_all.py --from-folder $ML50/raw --to-folder $ML50/dedup +python ./remove_valid_test_in_train.py --from-folder $ML50/dedup --to-folder $ML50/clean +python ./binarize.py --raw-folder $ML50/clean \ No newline at end of file diff --git a/fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py b/fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py new file mode 100644 index 0000000000000000000000000000000000000000..ef618adef7c7d010f8de38fb5ebeb5a35d2d3cac --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py @@ -0,0 +1,290 @@ +import os, sys +import glob, itertools +import pandas as pd + +WORKDIR_ROOT = os.environ.get('WORKDIR_ROOT', None) + +if WORKDIR_ROOT is None or not WORKDIR_ROOT.strip(): + print('please specify your working directory root in OS environment variable WORKDIR_ROOT. Exitting..."') + sys.exit(-1) + + +def load_langs(path): + with open(path) as fr: + langs = [l.strip() for l in fr] + return langs + + + +def load_sentences(raw_data, split, direction): + src, tgt = direction.split('-') + src_path = f"{raw_data}/{split}.{direction}.{src}" + tgt_path = f"{raw_data}/{split}.{direction}.{tgt}" + if os.path.exists(src_path) and os.path.exists(tgt_path): + return [(src, open(src_path).read().splitlines()), (tgt, open(tgt_path).read().splitlines())] + else: + return [] + +def swap_direction(d): + src, tgt = d.split('-') + return f'{tgt}-{src}' + +def get_all_test_data(raw_data, directions, split='test'): + test_data = [ + x + for dd in directions + for d in [dd, swap_direction(dd)] + for x in load_sentences(raw_data, split, d) + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_data = {} + for lang, d in test_data: + for s in d: + s = s.strip() + lgs = all_test_data.get(s, set()) + lgs.add(lang) + all_test_data[s] = lgs + return all_test_data, test_data + +def check_train_sentences(raw_data, direction, all_test_data, mess_up_train={}): + src, tgt = direction.split('-') + tgt_path = f"{raw_data}/train.{direction}.{tgt}" + src_path = f"{raw_data}/train.{direction}.{src}" + print(f'check training data in {raw_data}/train.{direction}') + size = 0 + if not os.path.exists(tgt_path) or not os.path.exists(src_path): + return mess_up_train, size + with open(src_path) as f, open(tgt_path) as g: + for src_line, tgt_line in zip(f, g): + s = src_line.strip() + t = tgt_line.strip() + size += 1 + if s in all_test_data: + langs = mess_up_train.get(s, set()) + langs.add(direction) + mess_up_train[s] = langs + if t in all_test_data: + langs = mess_up_train.get(t, set()) + langs.add(direction) + mess_up_train[t] = langs + return mess_up_train, size + +def check_train_all(raw_data, directions, all_test_data): + mess_up_train = {} + data_sizes = {} + for direction in directions: + _, size = check_train_sentences(raw_data, direction, all_test_data, mess_up_train) + data_sizes[direction] = size + return mess_up_train, data_sizes + +def count_train_in_other_set(mess_up_train): + train_in_others = [(direction, s) for s, directions in mess_up_train.items() for direction in directions] + counts = {} + for direction, s in train_in_others: + counts[direction] = counts.get(direction, 0) + 1 + return counts + +def train_size_if_remove_in_otherset(data_sizes, mess_up_train): + counts_in_other = count_train_in_other_set(mess_up_train) + remain_sizes = [] + for direction, count in counts_in_other.items(): + 
remain_sizes.append((direction, data_sizes[direction] - count, data_sizes[direction], count, 100 * count / data_sizes[direction] )) + return remain_sizes + + +def remove_messed_up_sentences(raw_data, direction, mess_up_train, mess_up_train_pairs, corrected_langs): + split = 'train' + src_lang, tgt_lang = direction.split('-') + + tgt = f"{raw_data}/{split}.{direction}.{tgt_lang}" + src = f"{raw_data}/{split}.{direction}.{src_lang}" + print(f'working on {direction}: ', src, tgt) + if not os.path.exists(tgt) or not os.path.exists(src) : + return + + corrected_tgt = f"{to_folder}/{split}.{direction}.{tgt_lang}" + corrected_src = f"{to_folder}/{split}.{direction}.{src_lang}" + line_num = 0 + keep_num = 0 + with open(src, encoding='utf8',) as fsrc, \ + open(tgt, encoding='utf8',) as ftgt, \ + open(corrected_src, 'w', encoding='utf8') as fsrc_corrected, \ + open(corrected_tgt, 'w', encoding='utf8') as ftgt_corrected: + for s, t in zip(fsrc, ftgt): + s = s.strip() + t = t.strip() + if t not in mess_up_train \ + and s not in mess_up_train \ + and (s, t) not in mess_up_train_pairs \ + and (t, s) not in mess_up_train_pairs: + corrected_langs.add(direction) + print(s, file=fsrc_corrected) + print(t, file=ftgt_corrected) + keep_num += 1 + line_num += 1 + if line_num % 1000 == 0: + print(f'completed {line_num} lines', end='\r') + return line_num, keep_num + +########## + + +def merge_valid_test_messup(mess_up_train_valid, mess_up_train_test): + merged_mess = [] + for s in set(list(mess_up_train_valid.keys()) + list(mess_up_train_test.keys())): + if not s: + continue + valid = mess_up_train_valid.get(s, set()) + test = mess_up_train_test.get(s, set()) + merged_mess.append((s, valid | test)) + return dict(merged_mess) + + + +######### +def check_train_pairs(raw_data, direction, all_test_data, mess_up_train={}): + src, tgt = direction.split('-') + #a hack; TODO: check the reversed directions + path1 = f"{raw_data}/train.{src}-{tgt}.{src}" + path2 = f"{raw_data}/train.{src}-{tgt}.{tgt}" + if not os.path.exists(path1) or not os.path.exists(path2) : + return + + with open(path1) as f1, open(path2) as f2: + for src_line, tgt_line in zip(f1, f2): + s = src_line.strip() + t = tgt_line.strip() + if (s, t) in all_test_data or (t, s) in all_test_data: + langs = mess_up_train.get( (s, t), set()) + langs.add(src) + langs.add(tgt) + mess_up_train[(s, t)] = langs + + +def load_pairs(raw_data, split, direction): + src, tgt = direction.split('-') + src_f = f"{raw_data}/{split}.{direction}.{src}" + tgt_f = f"{raw_data}/{split}.{direction}.{tgt}" + if tgt != 'en_XX': + src_f, tgt_f = tgt_f, src_f + if os.path.exists(src_f) and os.path.exists(tgt_f): + return list(zip(open(src_f).read().splitlines(), + open(tgt_f).read().splitlines(), + )) + else: + return [] + +# skip_langs = ['cs_CZ', 'en_XX', 'tl_XX', 'tr_TR'] +def get_messed_up_test_pairs(split, directions): + test_pairs = [ + (d, load_pairs(raw_data, split, d)) + for d in directions + ] + # all_test_data = {s for _, d in test_data for s in d} + all_test_pairs = {} + for direction, d in test_pairs: + src, tgt = direction.split('-') + for s in d: + langs = all_test_pairs.get(s, set()) + langs.add(src) + langs.add(tgt) + all_test_pairs[s] = langs + mess_up_train_pairs = {} + for direction in directions: + check_train_pairs(raw_data, direction, all_test_pairs, mess_up_train_pairs) + return all_test_pairs, mess_up_train_pairs + + + +if __name__ == "__main__": + ####### + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + '--from-folder', + 
required=True, + type=str) + parser.add_argument( + '--to-folder', + required=True, + type=str) + parser.add_argument( + '--directions', + default=None, + type=str) + + + args = parser.parse_args() + raw_data = args.from_folder + to_folder = args.to_folder + os.makedirs(to_folder, exist_ok=True) + + if args.directions: + directions = args.directions.split(',') + else: + raw_files = itertools.chain( + glob.glob(f'{raw_data}/train*'), + glob.glob(f'{raw_data}/valid*'), + glob.glob(f'{raw_data}/test*'), + ) + directions = [os.path.split(file_path)[-1].split('.')[1] for file_path in raw_files] + print('working on directions: ', directions) + + ########## + + + + all_test_data, test_data = get_all_test_data(raw_data, directions, 'test') + print('==loaded test data==') + all_valid_data, valid_data = get_all_test_data(raw_data, directions, 'valid') + print('==loaded valid data==') + all_valid_test_data = merge_valid_test_messup(all_test_data, all_valid_data) + mess_up_train, data_sizes = check_train_all(raw_data, directions, all_valid_test_data) + print('training messing up with valid, test data:', len(mess_up_train)) + data_situation = train_size_if_remove_in_otherset(data_sizes, mess_up_train) + df = pd.DataFrame(data_situation, columns=['direction', 'train_size_after_remove', 'orig_size', 'num_to_remove', 'remove_percent']) + df.sort_values('remove_percent', ascending=False) + df.to_csv(f'{raw_data}/clean_summary.tsv', sep='\t') + print(f'projected data clean summary in: {raw_data}/clean_summary.tsv') + + # correct the dataset: + all_test_pairs, mess_up_test_train_pairs = get_messed_up_test_pairs('test', directions) + all_valid_pairs, mess_up_valid_train_pairs = get_messed_up_test_pairs('valid', directions) + + all_messed_pairs = set(mess_up_test_train_pairs.keys()).union(set(mess_up_valid_train_pairs.keys())) + corrected_directions = set() + + real_data_situation = [] + for direction in directions: + org_size, new_size = remove_messed_up_sentences(raw_data, direction, mess_up_train, all_messed_pairs, corrected_directions) + if org_size == 0: + print(f"{direction} has size 0") + continue + real_data_situation.append( + (direction, new_size, org_size, org_size - new_size, (org_size - new_size) / org_size * 100) + ) + print('corrected directions: ', corrected_directions) + df = pd.DataFrame(real_data_situation, columns=['direction', 'train_size_after_remove', 'orig_size', 'num_to_remove', 'remove_percent']) + df.sort_values('remove_percent', ascending=False) + df.to_csv(f'{raw_data}/actual_clean_summary.tsv', sep='\t') + print(f'actual data clean summary (which can be different from the projected one because of duplications) in: {raw_data}/actual_clean_summary.tsv') + + import shutil + for direction in directions: + src_lang, tgt_lang = direction.split('-') + for split in ['train', 'valid', 'test']: + # copying valid, test and uncorrected train + if direction in corrected_directions and split == 'train': + continue + tgt = f"{raw_data}/{split}.{direction}.{tgt_lang}" + src = f"{raw_data}/{split}.{direction}.{src_lang}" + if not (os.path.exists(src) and os.path.exists(tgt)): + continue + corrected_tgt = f"{to_folder}/{split}.{direction}.{tgt_lang}" + corrected_src = f"{to_folder}/{split}.{direction}.{src_lang}" + print(f'copying {src} to {corrected_src}') + shutil.copyfile(src, corrected_src) + print(f'copying {tgt} to {corrected_tgt}') + shutil.copyfile(tgt, corrected_tgt) + + print('completed') \ No newline at end of file diff --git a/fairseq/examples/multilingual/data_scripts/requirement.txt 
b/fairseq/examples/multilingual/data_scripts/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..e85d7d540e08a1407f92dfb2311972a1a5a30123 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/requirement.txt @@ -0,0 +1,2 @@ +wget +pandas \ No newline at end of file diff --git a/fairseq/examples/multilingual/data_scripts/utils/dedup.py b/fairseq/examples/multilingual/data_scripts/utils/dedup.py new file mode 100644 index 0000000000000000000000000000000000000000..d6fed8c695cf218d3502d6ed8d23015520c0e179 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/utils/dedup.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse + +def deup(src_file, tgt_file, src_file_out, tgt_file_out): + seen = set() + dup_count = 0 + with open(src_file, encoding='utf-8') as fsrc, \ + open(tgt_file, encoding='utf-8') as ftgt, \ + open(src_file_out, 'w', encoding='utf-8') as fsrc_out, \ + open(tgt_file_out, 'w', encoding='utf-8') as ftgt_out: + for s, t in zip(fsrc, ftgt): + if (s, t) not in seen: + fsrc_out.write(s) + ftgt_out.write(t) + seen.add((s, t)) + else: + dup_count += 1 + print(f'number of duplication: {dup_count}') + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--src-file", type=str, required=True, + help="src file") + parser.add_argument("--tgt-file", type=str, required=True, + help="tgt file") + parser.add_argument("--src-file-out", type=str, required=True, + help="src ouptut file") + parser.add_argument("--tgt-file-out", type=str, required=True, + help="tgt ouput file") + args = parser.parse_args() + deup(args.src_file, args.tgt_file, args.src_file_out, args.tgt_file_out) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py b/fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..41b38ba5bef20cb043921ac61820db8689189a5a --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
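+ # Joint LID filter for parallel files: a tuple of aligned lines is kept only if the
+ # fastText prediction for every stream matches the expected id given via --langs;
+ # mismatching tuples are skipped and counted. Invoked by lid_filter() in
+ # download_wmt20.sh, e.g. (illustrative file names):
+ #   python fasttext_multi_filter.py --model lid.176.bin \
+ #       --inputs all.de all.en --langs de en --outputs train.de train.en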
+ + +#!/bin/python + +import fasttext +from multiprocessing import Pool +import contextlib +import sys +import argparse +from functools import partial +import io + +model = None +def init(model_path): + global model + model = fasttext.load_model(model_path) + +def pred(lines): + return lines, [model.predict(line.strip())[0][0][9:] for line in lines] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True, + help="model to load") + parser.add_argument("--inputs", nargs="+", default=['-'], + help="input files to filter") + parser.add_argument("--langs", nargs="+", required=True, + help="lang ids of each input file") + parser.add_argument("--outputs", nargs="+", default=['-'], + help="path to save lid filtered outputs") + parser.add_argument("--num-workers", type=int, metavar="N", default=10, + help="number of processes in parallel") + args = parser.parse_args() + + assert len(args.inputs) == len(args.langs) and len(args.inputs) == len(args.outputs) + + with contextlib.ExitStack() as stack: + inputs = [ + stack.enter_context(open(input, "r", encoding="utf-8", newline="\n", errors="replace")) + if input != "-" else io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8', errors="replace") + for input in args.inputs + ] + outputs = [ + stack.enter_context(open(output, "w", encoding="utf-8", newline="\n")) + if output != "-" else sys.stdout + for output in args.outputs + ] + with Pool(args.num_workers, initializer=partial(init, args.model)) as p: + skip_cnt = 0 + for lines, preds in p.imap(pred, list(zip(*inputs)), chunksize=500): + if not all(a == b for a, b in zip(preds, args.langs)): + skip_cnt += 1 + continue + for line, output_h in zip(lines, outputs): + print(line.strip(), file=output_h) + print(f"Skipped {skip_cnt} lines.") + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh b/fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f4f61d7b1a46f51a1221de6b336cb70b5a0b8b3 --- /dev/null +++ b/fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh @@ -0,0 +1 @@ +grep "seg id" | sed 's/<seg id="[0-9]\+">//g' | sed 's/<\/seg>//g' diff --git a/fairseq/examples/multilingual/finetune_multilingual_model.sh b/fairseq/examples/multilingual/finetune_multilingual_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..25960c5dc8a02e5580b61837099770a082b4dd83 --- /dev/null +++ b/fairseq/examples/multilingual/finetune_multilingual_model.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +path_2_data=$1 # <path to data> which contains binarized data for each directions +lang_list=$2 # <path to a file which contains a list of languages separted by new lines> +lang_pairs=$3 #a list language pairs to train multilingual models, e.g. 
"en-fr,en-cs,fr-en,cs-en" +# pretrained can be an mBART pretrained model as well +pretrained_model=$4 #<path to a pretrained model> + + +fairseq-train "$path_2_data" \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --finetune-from-model "$pretrained_model" \ + --sampling-method "temperature" \ + --sampling-temperature "1.5" \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 diff --git a/fairseq/examples/multilingual/multilingual_fairseq_gen.sh b/fairseq/examples/multilingual/multilingual_fairseq_gen.sh new file mode 100644 index 0000000000000000000000000000000000000000..65aa322d7daaa428015de98abe4664a6a4164bfd --- /dev/null +++ b/fairseq/examples/multilingual/multilingual_fairseq_gen.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +lang_pairs="en-fr,en-cs,fr-en,cs-en" +path_2_data=$1 # <path to data> +lang_list=$2 # <path to a file which contains list of languages separted by new lines> +model=$3 # <path to a trained model> +source_lang=cs +target_lang=en + +fairseq-generate "$path_2_data" \ + --path "$model" \ + --task translation_multi_simple_epoch \ + --gen-subset test \ + --source-lang "$source_lang" \ + --target-lang "$target_lang" \ + --sacrebleu --remove-bpe 'sentencepiece'\ + --batch-size 32 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" diff --git a/fairseq/examples/multilingual/train_multilingual_model.sh b/fairseq/examples/multilingual/train_multilingual_model.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc050bd3f02de8a2f303737f187442d2eb80e4ef --- /dev/null +++ b/fairseq/examples/multilingual/train_multilingual_model.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +path_2_data=$1 # <path to data> which contains binarized data for each directions +lang_list=$2 # <path to a file which contains a list of languages separted by new lines> +lang_pairs=$3 #a list language pairs to train multilingual models, e.g. 
"en-fr,en-cs,fr-en,cs-en" + +fairseq-train "$path_2_data" \ + --encoder-normalize-before --decoder-normalize-before \ + --arch transformer --layernorm-embedding \ + --task translation_multi_simple_epoch \ + --sampling-method "temperature" \ + --sampling-temperature 1.5 \ + --encoder-langtok "src" \ + --decoder-langtok \ + --lang-dict "$lang_list" \ + --lang-pairs "$lang_pairs" \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ + --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ + --max-tokens 1024 --update-freq 2 \ + --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --seed 222 --log-format simple --log-interval 2 diff --git a/fairseq/examples/noisychannel/README.md b/fairseq/examples/noisychannel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9d101aa874ec36ff3bb5c1166169a4c4f38ffe2b --- /dev/null +++ b/fairseq/examples/noisychannel/README.md @@ -0,0 +1,72 @@ +# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019) +This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts. + +## Citation: +```bibtex +@inproceedings{yee2019simple, + title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation}, + author = {Kyra Yee and Yann Dauphin and Michael Auli}, + booktitle = {Conference on Empirical Methods in Natural Language Processing}, + year = {2019}, +} +``` + +## Pre-trained Models: + +Model | Description | Download +---|---|--- +`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2) +`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2) +`transformer_lm.noisychannel.en` | En Language model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2) + +Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2) + +## Example usage + +``` +mkdir rerank_example +curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example +curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example +curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example +curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example + +beam=50 +num_trials=1000 +fw_name=fw_model_ex +bw_name=bw_model_ex +lm_name=lm_ex +data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe +data_dir_name=wmt17 +lm=rerank_example/lm/checkpoint_best.pt +lm_bpe_code=rerank_example/lm/bpe32k.code +lm_dict=rerank_example/lm/dict.txt +batch_size=32 +bw=rerank_example/backward_en2de.pt +fw=rerank_example/forward_de2en.pt + +# reranking with P(T|S) P(S|T) and P(T) +python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \ + --lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \ + --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \ + -n $beam --batch-size $batch_size 
--score-model2 $fw --score-model1 $bw \ + --backwards1 --weight2 1 \ + -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \ + --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name + +# reranking with P(T|S) and P(T) +python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \ + --lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \ + --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \ + -n $beam --batch-size $batch_size --score-model1 $fw \ + -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \ + --model1-name $fw_name --gen-model-name $fw_name + +# to run with a preconfigured set of hyperparameters for the lenpen and model weights, using rerank.py instead. +python examples/noisychannel/rerank.py $data_dir \ + --lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \ + --data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \ + -n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \ + -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \ + --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name +``` + diff --git a/fairseq/examples/noisychannel/__init__.py b/fairseq/examples/noisychannel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89f1aef4f6328d25425e0bcabb42dfffd2ed35f0 --- /dev/null +++ b/fairseq/examples/noisychannel/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .rerank_options import * # noqa diff --git a/fairseq/examples/noisychannel/rerank.py b/fairseq/examples/noisychannel/rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..bb80d11a67cd75764a89f6f41915b0348ae96e92 --- /dev/null +++ b/fairseq/examples/noisychannel/rerank.py @@ -0,0 +1,428 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
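+ # Reranks n-best hypotheses by combining scores from one or two rescoring models and
+ # an optional language model under tunable weights and a length penalty (see
+ # rerank_utils.get_score), then reports BLEU of the selected hypotheses.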
+ +import math +from multiprocessing import Pool + +import numpy as np +from fairseq import options +from fairseq.data import dictionary +from fairseq.scoring import bleu + +from examples.noisychannel import ( + rerank_generate, + rerank_options, + rerank_score_bw, + rerank_score_lm, + rerank_utils, +) + + +def score_target_hypo( + args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize +): + + print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) + gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) + dict = dictionary.Dictionary() + scorer = scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) + + ordered_hypos = {} + ordered_targets = {} + + for shard_id in range(len(bitext1_lst)): + bitext1 = bitext1_lst[shard_id] + bitext2 = bitext2_lst[shard_id] + gen_output = gen_output_lst[shard_id] + lm_res = lm_res_lst[shard_id] + + total = len(bitext1.rescore_source.keys()) + source_lst = [] + hypo_lst = [] + score_lst = [] + reference_lst = [] + j = 1 + best_score = -math.inf + + for i in range(total): + # length is measured in terms of words, not bpe tokens, since models may not share the same bpe + target_len = len(bitext1.rescore_hypo[i].split()) + + if lm_res is not None: + lm_score = lm_res.score[i] + else: + lm_score = 0 + + if bitext2 is not None: + bitext2_score = bitext2.rescore_score[i] + bitext2_backwards = bitext2.backwards + else: + bitext2_score = None + bitext2_backwards = None + + score = rerank_utils.get_score( + a, + b, + c, + target_len, + bitext1.rescore_score[i], + bitext2_score, + lm_score=lm_score, + lenpen=lenpen, + src_len=bitext1.source_lengths[i], + tgt_len=bitext1.target_lengths[i], + bitext1_backwards=bitext1.backwards, + bitext2_backwards=bitext2_backwards, + normalize=normalize, + ) + + if score > best_score: + best_score = score + best_hypo = bitext1.rescore_hypo[i] + + if j == gen_output.num_hypos[i] or j == args.num_rescore: + j = 1 + hypo_lst.append(best_hypo) + score_lst.append(best_score) + source_lst.append(bitext1.rescore_source[i]) + reference_lst.append(bitext1.rescore_target[i]) + + best_score = -math.inf + best_hypo = "" + else: + j += 1 + + gen_keys = list(sorted(gen_output.no_bpe_target.keys())) + + for key in range(len(gen_keys)): + if args.prefix_len is None: + assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], ( + "pred and rescore hypo mismatch: i: " + + str(key) + + ", " + + str(hypo_lst[key]) + + str(gen_keys[key]) + + str(gen_output.no_bpe_hypo[key]) + ) + sys_tok = dict.encode_line(hypo_lst[key]) + ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]]) + scorer.add(ref_tok, sys_tok) + + else: + full_hypo = rerank_utils.get_full_from_prefix( + hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]] + ) + sys_tok = dict.encode_line(full_hypo) + ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]]) + scorer.add(ref_tok, sys_tok) + + # if only one set of hyper parameters is provided, write the predictions to a file + if write_hypos: + # recover the orinal ids from n best list generation + for key in range(len(gen_output.no_bpe_target)): + if args.prefix_len is None: + assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], ( + "pred and rescore hypo mismatch:" + + "i:" + + str(key) + + str(hypo_lst[key]) + + str(gen_output.no_bpe_hypo[key]) + ) + ordered_hypos[gen_keys[key]] = hypo_lst[key] + ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[ + gen_keys[key] + ] + + else: + full_hypo = 
rerank_utils.get_full_from_prefix( + hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]] + ) + ordered_hypos[gen_keys[key]] = full_hypo + ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[ + gen_keys[key] + ] + + # write the hypos in the original order from nbest list generation + if args.num_shards == (len(bitext1_lst)): + with open(target_outfile, "w") as t: + with open(hypo_outfile, "w") as h: + for key in range(len(ordered_hypos)): + t.write(ordered_targets[key]) + h.write(ordered_hypos[key]) + + res = scorer.result_string(4) + if write_hypos: + print(res) + score = rerank_utils.parse_bleu_scoring(res) + return score + + +def match_target_hypo(args, target_outfile, hypo_outfile): + """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file""" + if len(args.weight1) == 1: + res = score_target_hypo( + args, + args.weight1[0], + args.weight2[0], + args.weight3[0], + args.lenpen[0], + target_outfile, + hypo_outfile, + True, + args.normalize, + ) + rerank_scores = [res] + else: + print("launching pool") + with Pool(32) as p: + rerank_scores = p.starmap( + score_target_hypo, + [ + ( + args, + args.weight1[i], + args.weight2[i], + args.weight3[i], + args.lenpen[i], + target_outfile, + hypo_outfile, + False, + args.normalize, + ) + for i in range(len(args.weight1)) + ], + ) + + if len(rerank_scores) > 1: + best_index = np.argmax(rerank_scores) + best_score = rerank_scores[best_index] + print("best score", best_score) + print("best lenpen", args.lenpen[best_index]) + print("best weight1", args.weight1[best_index]) + print("best weight2", args.weight2[best_index]) + print("best weight3", args.weight3[best_index]) + return ( + args.lenpen[best_index], + args.weight1[best_index], + args.weight2[best_index], + args.weight3[best_index], + best_score, + ) + + else: + return ( + args.lenpen[0], + args.weight1[0], + args.weight2[0], + args.weight3[0], + rerank_scores[0], + ) + + +def load_score_files(args): + if args.all_shards: + shard_ids = list(range(args.num_shards)) + else: + shard_ids = [args.shard_id] + + gen_output_lst = [] + bitext1_lst = [] + bitext2_lst = [] + lm_res1_lst = [] + + for shard_id in shard_ids: + using_nbest = args.nbest_list is not None + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) + + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) + if args.language_model is not None: + lm_score_file = rerank_utils.rescore_file_name( + pre_gen, args.prefix_len, args.lm_name, lm_file=True + ) + + # get gen output + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" + if using_nbest: + 
print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, + bpe_symbol=args.post_process, + nbest=using_nbest, + prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac, + ) + + if rerank1_is_gen: + bitext1 = gen_output + else: + bitext1 = rerank_utils.BitextOutput( + score1_file, + args.backwards1, + args.right_to_left1, + args.post_process, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + if args.score_model2 is not None or args.nbest_list is not None: + if rerank2_is_gen: + bitext2 = gen_output + else: + bitext2 = rerank_utils.BitextOutput( + score2_file, + args.backwards2, + args.right_to_left2, + args.post_process, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + assert ( + bitext2.source_lengths == bitext1.source_lengths + ), "source lengths for rescoring models do not match" + assert ( + bitext2.target_lengths == bitext1.target_lengths + ), "target lengths for rescoring models do not match" + else: + if args.diff_bpe: + assert args.score_model2 is None + bitext2 = gen_output + else: + bitext2 = None + + if args.language_model is not None: + lm_res1 = rerank_utils.LMOutput( + lm_score_file, + args.lm_dict, + args.prefix_len, + args.post_process, + args.target_prefix_frac, + ) + else: + lm_res1 = None + + gen_output_lst.append(gen_output) + bitext1_lst.append(bitext1) + bitext2_lst.append(bitext2) + lm_res1_lst.append(lm_res1) + return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst + + +def rerank(args): + if type(args.lenpen) is not list: + args.lenpen = [args.lenpen] + if type(args.weight1) is not list: + args.weight1 = [args.weight1] + if type(args.weight2) is not list: + args.weight2 = [args.weight2] + if type(args.weight3) is not list: + args.weight3 = [args.weight3] + if args.all_shards: + shard_ids = list(range(args.num_shards)) + else: + shard_ids = [args.shard_id] + + for shard_id in shard_ids: + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + rerank_generate.gen_and_reprocess_nbest(args) + rerank_score_bw.score_bw(args) + rerank_score_lm.score_lm(args) + + if args.write_hypos is None: + write_targets = pre_gen + "/matched_targets" + write_hypos = pre_gen + "/matched_hypos" + else: + write_targets = args.write_hypos + "_targets" + args.gen_subset + write_hypos = args.write_hypos + "_hypos" + args.gen_subset + + if args.all_shards: + write_targets += "_all_shards" + write_hypos += "_all_shards" + + ( + best_lenpen, + best_weight1, + best_weight2, + best_weight3, + best_score, + ) = match_target_hypo(args, write_targets, write_hypos) + + return best_lenpen, best_weight1, best_weight2, best_weight3, best_score + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + rerank(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/noisychannel/rerank_generate.py b/fairseq/examples/noisychannel/rerank_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..daeeae059a677a9fcd7c370be087f1f5c189bc52 --- /dev/null +++ 
b/fairseq/examples/noisychannel/rerank_generate.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Generate n-best translations using a trained model. +""" + +import os +import subprocess +from contextlib import redirect_stdout + +from fairseq import options +from fairseq_cli import generate, preprocess + +from examples.noisychannel import rerank_options, rerank_utils + + +def gen_and_reprocess_nbest(args): + if args.score_dict_dir is None: + args.score_dict_dir = args.data + if args.prefix_len is not None: + assert ( + args.right_to_left1 is False + ), "prefix length not compatible with right to left models" + assert ( + args.right_to_left2 is False + ), "prefix length not compatible with right to left models" + + if args.nbest_list is not None: + assert args.score_model2 is None + + if args.backwards1: + scorer1_src = args.target_lang + scorer1_tgt = args.source_lang + else: + scorer1_src = args.source_lang + scorer1_tgt = args.target_lang + + store_data = ( + os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name + ) + if not os.path.exists(store_data): + os.makedirs(store_data) + + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + assert not ( + args.right_to_left1 and args.backwards1 + ), "backwards right to left not supported" + assert not ( + args.right_to_left2 and args.backwards2 + ), "backwards right to left not supported" + assert not ( + args.prefix_len is not None and args.target_prefix_frac is not None + ), "target prefix frac and target prefix len incompatible" + + # make directory to store generation results + if not os.path.exists(pre_gen): + os.makedirs(pre_gen) + + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) + + if args.nbest_list is not None: + rerank2_is_gen = True + + # make directories to store preprossed nbest list for reranking + if not os.path.exists(left_to_right_preprocessed_dir): + os.makedirs(left_to_right_preprocessed_dir) + if not os.path.exists(right_to_left_preprocessed_dir): + os.makedirs(right_to_left_preprocessed_dir) + if not os.path.exists(lm_preprocessed_dir): + os.makedirs(lm_preprocessed_dir) + if not os.path.exists(backwards_preprocessed_dir): + os.makedirs(backwards_preprocessed_dir) + + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) + + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" + + using_nbest = args.nbest_list is not None + + if using_nbest: + print("Using predefined n-best list from interactive.py") + 
predictions_bpe_file = args.nbest_list + + else: + if not os.path.isfile(predictions_bpe_file): + print("STEP 1: generate predictions using the p(T|S) model with bpe") + print(args.data) + param1 = [ + args.data, + "--path", + args.gen_model, + "--shard-id", + str(args.shard_id), + "--num-shards", + str(args.num_shards), + "--nbest", + str(args.num_rescore), + "--batch-size", + str(args.batch_size), + "--beam", + str(args.num_rescore), + "--batch-size", + str(args.num_rescore), + "--gen-subset", + args.gen_subset, + "--source-lang", + args.source_lang, + "--target-lang", + args.target_lang, + ] + if args.sampling: + param1 += ["--sampling"] + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, param1) + + print(input_args) + with open(predictions_bpe_file, "w") as f: + with redirect_stdout(f): + generate.main(input_args) + + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, + bpe_symbol=args.post_process, + nbest=using_nbest, + prefix_len=args.prefix_len, + target_prefix_frac=args.target_prefix_frac, + ) + + if args.diff_bpe: + rerank_utils.write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + pre_gen + "/source_gen_bpe." + args.source_lang, + pre_gen + "/target_gen_bpe." + args.target_lang, + pre_gen + "/reference_gen_bpe." + args.target_lang, + ) + bitext_bpe = args.rescore_bpe_code + bpe_src_param = [ + "-c", + bitext_bpe, + "--input", + pre_gen + "/source_gen_bpe." + args.source_lang, + "--output", + pre_gen + "/rescore_data." + args.source_lang, + ] + bpe_tgt_param = [ + "-c", + bitext_bpe, + "--input", + pre_gen + "/target_gen_bpe." + args.target_lang, + "--output", + pre_gen + "/rescore_data." + args.target_lang, + ] + + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_src_param, + shell=False, + ) + + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_tgt_param, + shell=False, + ) + + if (not os.path.isfile(score1_file) and not rerank1_is_gen) or ( + args.score_model2 is not None + and not os.path.isfile(score2_file) + and not rerank2_is_gen + ): + print( + "STEP 2: process the output of generate.py so we have clean text files with the translations" + ) + + rescore_file = "/rescore_data" + if args.prefix_len is not None: + prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len) + if args.target_prefix_frac is not None: + target_prefix_frac_rescore_file = ( + rescore_file + "target_prefix_frac" + str(args.target_prefix_frac) + ) + if args.source_prefix_frac is not None: + source_prefix_frac_rescore_file = ( + rescore_file + "source_prefix_frac" + str(args.source_prefix_frac) + ) + + if not args.right_to_left1 or not args.right_to_left2: + if not args.diff_bpe: + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + rescore_file + "." + args.source_lang, + pre_gen + rescore_file + "." + args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.post_process, + ) + if args.prefix_len is not None: + bw_rescore_file = prefix_len_rescore_file + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + prefix_len_rescore_file + "." + args.source_lang, + pre_gen + prefix_len_rescore_file + "." 
+ args.target_lang, + pre_gen + "/reference_file", + prefix_len=args.prefix_len, + bpe_symbol=args.post_process, + ) + elif args.target_prefix_frac is not None: + bw_rescore_file = target_prefix_frac_rescore_file + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + + target_prefix_frac_rescore_file + + "." + + args.source_lang, + pre_gen + + target_prefix_frac_rescore_file + + "." + + args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.post_process, + target_prefix_frac=args.target_prefix_frac, + ) + else: + bw_rescore_file = rescore_file + + if args.source_prefix_frac is not None: + fw_rescore_file = source_prefix_frac_rescore_file + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + + source_prefix_frac_rescore_file + + "." + + args.source_lang, + pre_gen + + source_prefix_frac_rescore_file + + "." + + args.target_lang, + pre_gen + "/reference_file", + bpe_symbol=args.post_process, + source_prefix_frac=args.source_prefix_frac, + ) + else: + fw_rescore_file = rescore_file + + if args.right_to_left1 or args.right_to_left2: + rerank_utils.write_reprocessed( + gen_output.source, + gen_output.hypo, + gen_output.target, + pre_gen + "/right_to_left_rescore_data." + args.source_lang, + pre_gen + "/right_to_left_rescore_data." + args.target_lang, + pre_gen + "/right_to_left_reference_file", + right_to_left=True, + bpe_symbol=args.post_process, + ) + + print("STEP 3: binarize the translations") + if ( + not args.right_to_left1 + or args.score_model2 is not None + and not args.right_to_left2 + or not rerank1_is_gen + ): + + if args.backwards1 or args.backwards2: + if args.backwards_score_dict_dir is not None: + bw_dict = args.backwards_score_dict_dir + else: + bw_dict = args.score_dict_dir + bw_preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + bw_rescore_file, + "--srcdict", + bw_dict + "/dict." + scorer1_src + ".txt", + "--tgtdict", + bw_dict + "/dict." + scorer1_tgt + ".txt", + "--destdir", + backwards_preprocessed_dir, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(bw_preprocess_param) + preprocess.main(input_args) + + preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + fw_rescore_file, + "--srcdict", + args.score_dict_dir + "/dict." + scorer1_src + ".txt", + "--tgtdict", + args.score_dict_dir + "/dict." + scorer1_tgt + ".txt", + "--destdir", + left_to_right_preprocessed_dir, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_param) + preprocess.main(input_args) + + if args.right_to_left1 or args.right_to_left2: + preprocess_param = [ + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + "--trainpref", + pre_gen + "/right_to_left_rescore_data", + "--srcdict", + args.score_dict_dir + "/dict." + scorer1_src + ".txt", + "--tgtdict", + args.score_dict_dir + "/dict." 
+ scorer1_tgt + ".txt", + "--destdir", + right_to_left_preprocessed_dir, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_param) + preprocess.main(input_args) + + return gen_output + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + gen_and_reprocess_nbest(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/noisychannel/rerank_options.py b/fairseq/examples/noisychannel/rerank_options.py new file mode 100644 index 0000000000000000000000000000000000000000..de91939e6635bdf33c9dc330116be07d9e8be6a2 --- /dev/null +++ b/fairseq/examples/noisychannel/rerank_options.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import options + + +def get_reranking_parser(default_task="translation"): + parser = options.get_parser("Generation and reranking", default_task) + add_reranking_args(parser) + return parser + + +def get_tuning_parser(default_task="translation"): + parser = options.get_parser("Reranking tuning", default_task) + add_reranking_args(parser) + add_tuning_args(parser) + return parser + + +def add_reranking_args(parser): + group = parser.add_argument_group("Reranking") + # fmt: off + group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True, + help='path to first model or ensemble of models for rescoring') + group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False, + help='path to second model or ensemble of models for rescoring') + group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10, + help='the number of candidate hypothesis to rescore') + group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128, + help='batch size for generating the nbest list') + group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'], + help='data subset to generate (train, valid, test)') + group.add_argument('--gen-model', default=None, metavar='FILE', + help='the model to generate translations') + group.add_argument('-b1', '--backwards1', action='store_true', + help='whether or not the first model group is backwards') + group.add_argument('-b2', '--backwards2', action='store_true', + help='whether or not the second model group is backwards') + group.add_argument('-a', '--weight1', default=1, nargs='+', type=float, + help='the weight(s) of the first model') + group.add_argument('-b', '--weight2', default=1, nargs='+', type=float, + help='the weight(s) of the second model, or the gen model if using nbest from interactive.py') + group.add_argument('-c', '--weight3', default=1, nargs='+', type=float, + help='the weight(s) of the third model') + + # lm arguments + group.add_argument('-lm', '--language-model', default=None, metavar='FILE', + help='language model for target language to rescore translations') + group.add_argument('--lm-dict', default=None, metavar='FILE', + help='the dict of the language model for the target language') + group.add_argument('--lm-name', default=None, + help='the name of the language model for the target language') + group.add_argument('--lm-bpe-code', default=None, metavar='FILE', + help='the bpe code for the language model for the target language') + group.add_argument('--data-dir-name', default=None, + help='name of data 
directory') + group.add_argument('--lenpen', default=1, nargs='+', type=float, + help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences') + group.add_argument('--score-dict-dir', default=None, + help='the directory with dictionaries for the scoring models') + group.add_argument('--right-to-left1', action='store_true', + help='whether the first model group is a right to left model') + group.add_argument('--right-to-left2', action='store_true', + help='whether the second model group is a right to left model') + group.add_argument('--post-process', '--remove-bpe', default='@@ ', + help='the bpe symbol, used for the bitext and LM') + group.add_argument('--prefix-len', default=None, type=int, + help='the length of the target prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--sampling', action='store_true', + help='use sampling instead of beam search for generating n best list') + group.add_argument('--diff-bpe', action='store_true', + help='bpe for rescoring and nbest list not the same') + group.add_argument('--rescore-bpe-code', default=None, + help='bpe code for rescoring models') + group.add_argument('--nbest-list', default=None, + help='use predefined nbest list in interactive.py format') + group.add_argument('--write-hypos', default=None, + help='filename prefix to write hypos to') + group.add_argument('--ref-translation', default=None, + help='reference translation to use with nbest list from interactive.py') + group.add_argument('--backwards-score-dict-dir', default=None, + help='the directory with dictionaries for the backwards model,' + 'if None then it is assumed the fw and backwards models share dictionaries') + + # extra scaling args + group.add_argument('--gen-model-name', default=None, + help='the name of the models that generated the nbest list') + group.add_argument('--model1-name', default=None, + help='the name of the set for model1 group ') + group.add_argument('--model2-name', default=None, + help='the name of the set for model2 group') + group.add_argument('--shard-id', default=0, type=int, + help='the id of the shard to generate') + group.add_argument('--num-shards', default=1, type=int, + help='the number of shards to generate across') + group.add_argument('--all-shards', action='store_true', + help='use all shards') + group.add_argument('--target-prefix-frac', default=None, type=float, + help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--source-prefix-frac', default=None, type=float, + help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)') + group.add_argument('--normalize', action='store_true', + help='whether to normalize by src and target len') + # fmt: on + return group + + +def add_tuning_args(parser): + group = parser.add_argument_group("Tuning") + + group.add_argument( + "--lower-bound", + default=[-0.7], + nargs="+", + type=float, + help="lower bound of search space", + ) + group.add_argument( + "--upper-bound", + default=[3], + nargs="+", + type=float, + help="upper bound of search space", + ) + group.add_argument( + "--tune-param", + default=["lenpen"], + nargs="+", + choices=["lenpen", "weight1", "weight2", "weight3"], + help="the parameter(s) to tune", + ) + group.add_argument( + "--tune-subset", + default="valid", + choices=["valid", "test", "train"], + help="the subset to tune on ", + ) + group.add_argument( + "--num-trials", + default=1000, + type=int, + help="number of trials to do for random search", + ) + 
group.add_argument( + "--share-weights", action="store_true", help="share weight2 and weight 3" + ) + return group diff --git a/fairseq/examples/noisychannel/rerank_score_bw.py b/fairseq/examples/noisychannel/rerank_score_bw.py new file mode 100644 index 0000000000000000000000000000000000000000..b0bc913651bd76667e25c214acb70f2bca19e185 --- /dev/null +++ b/fairseq/examples/noisychannel/rerank_score_bw.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from contextlib import redirect_stdout + +from fairseq import options +from fairseq_cli import generate + +from examples.noisychannel import rerank_options, rerank_utils + + +def score_bw(args): + if args.backwards1: + scorer1_src = args.target_lang + scorer1_tgt = args.source_lang + else: + scorer1_src = args.source_lang + scorer1_tgt = args.target_lang + + if args.score_model2 is not None: + if args.backwards2: + scorer2_src = args.target_lang + scorer2_tgt = args.source_lang + else: + scorer2_src = args.source_lang + scorer2_tgt = args.target_lang + + rerank1_is_gen = ( + args.gen_model == args.score_model1 and args.source_prefix_frac is None + ) + rerank2_is_gen = ( + args.gen_model == args.score_model2 and args.source_prefix_frac is None + ) + + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + score1_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model1_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards1, + ) + + if args.score_model2 is not None: + score2_file = rerank_utils.rescore_file_name( + pre_gen, + args.prefix_len, + args.model2_name, + target_prefix_frac=args.target_prefix_frac, + source_prefix_frac=args.source_prefix_frac, + backwards=args.backwards2, + ) + + if args.right_to_left1: + rerank_data1 = right_to_left_preprocessed_dir + elif args.backwards1: + rerank_data1 = backwards_preprocessed_dir + else: + rerank_data1 = left_to_right_preprocessed_dir + + gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"] + if not rerank1_is_gen and not os.path.isfile(score1_file): + print("STEP 4: score the translations for model 1") + + model_param1 = [ + "--path", + args.score_model1, + "--source-lang", + scorer1_src, + "--target-lang", + scorer1_tgt, + ] + gen_model1_param = [rerank_data1] + gen_param + model_param1 + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model1_param) + + with open(score1_file, "w") as f: + with redirect_stdout(f): + generate.main(input_args) + + if ( + args.score_model2 is not None + and not os.path.isfile(score2_file) + and not rerank2_is_gen + ): + print("STEP 4: score the translations for model 2") + + if args.right_to_left2: + rerank_data2 = right_to_left_preprocessed_dir + elif args.backwards2: + rerank_data2 = backwards_preprocessed_dir + else: + rerank_data2 = left_to_right_preprocessed_dir + + model_param2 = [ + "--path", + args.score_model2, + "--source-lang", + scorer2_src, + "--target-lang", + scorer2_tgt, + ] + 
gen_model2_param = [rerank_data2] + gen_param + model_param2 + + gen_parser = options.get_generation_parser() + input_args = options.parse_args_and_arch(gen_parser, gen_model2_param) + + with open(score2_file, "w") as f: + with redirect_stdout(f): + generate.main(input_args) + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + score_bw(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/noisychannel/rerank_score_lm.py b/fairseq/examples/noisychannel/rerank_score_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..e80948d78b02561cbd09d72c319222105f41f6bb --- /dev/null +++ b/fairseq/examples/noisychannel/rerank_score_lm.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from fairseq import options + +from examples.noisychannel import rerank_options, rerank_utils + + +def score_lm(args): + using_nbest = args.nbest_list is not None + ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) = rerank_utils.get_directories( + args.data_dir_name, + args.num_rescore, + args.gen_subset, + args.gen_model_name, + args.shard_id, + args.num_shards, + args.sampling, + args.prefix_len, + args.target_prefix_frac, + args.source_prefix_frac, + ) + + predictions_bpe_file = pre_gen + "/generate_output_bpe.txt" + if using_nbest: + print("Using predefined n-best list from interactive.py") + predictions_bpe_file = args.nbest_list + + gen_output = rerank_utils.BitextOutputFromGen( + predictions_bpe_file, bpe_symbol=args.post_process, nbest=using_nbest + ) + + if args.language_model is not None: + lm_score_file = rerank_utils.rescore_file_name( + pre_gen, args.prefix_len, args.lm_name, lm_file=True + ) + + if args.language_model is not None and not os.path.isfile(lm_score_file): + print("STEP 4.5: language modeling for P(T)") + if args.lm_bpe_code is None: + bpe_status = "no bpe" + elif args.lm_bpe_code == "shared": + bpe_status = "shared" + else: + bpe_status = "different" + + rerank_utils.lm_scoring( + lm_preprocessed_dir, + bpe_status, + gen_output, + pre_gen, + args.lm_dict, + args.lm_name, + args.language_model, + args.lm_bpe_code, + 128, + lm_score_file, + args.target_lang, + args.source_lang, + prefix_len=args.prefix_len, + ) + + +def cli_main(): + parser = rerank_options.get_reranking_parser() + args = options.parse_args_and_arch(parser) + score_lm(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/noisychannel/rerank_tune.py b/fairseq/examples/noisychannel/rerank_tune.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e8b7594a370b2462f77252d54d7ef80e290f7c --- /dev/null +++ b/fairseq/examples/noisychannel/rerank_tune.py @@ -0,0 +1,102 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
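+
+# rerank_tune.py tunes the reranking hyperparameters (length penalty and model
+# weights selected with --tune-param) by random search: each trial reranks the
+# tuning subset, and the best configuration found is then used to rerank the
+# valid and test subsets.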
+ +import argparse +import random + +import numpy as np +from fairseq import options + +from examples.noisychannel import rerank, rerank_options + + +def random_search(args): + param_values = [] + tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"] + initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3] + for i, elem in enumerate(initial_params): + if type(elem) is not list: + initial_params[i] = [elem] + else: + initial_params[i] = elem + + tune_parameters = args.tune_param.copy() + for i in range(len(args.tune_param)): + assert args.upper_bound[i] >= args.lower_bound[i] + index = tuneable_parameters.index(args.tune_param[i]) + del tuneable_parameters[index] + del initial_params[index] + + tune_parameters += tuneable_parameters + param_values += initial_params + random.seed(args.seed) + + random_params = np.array( + [ + [ + random.uniform(args.lower_bound[i], args.upper_bound[i]) + for i in range(len(args.tune_param)) + ] + for k in range(args.num_trials) + ] + ) + set_params = np.array( + [ + [initial_params[i][0] for i in range(len(tuneable_parameters))] + for k in range(args.num_trials) + ] + ) + random_params = np.concatenate((random_params, set_params), 1) + + rerank_args = vars(args).copy() + if args.nbest_list: + rerank_args["gen_subset"] = "test" + else: + rerank_args["gen_subset"] = args.tune_subset + + for k in range(len(tune_parameters)): + rerank_args[tune_parameters[k]] = list(random_params[:, k]) + + if args.share_weights: + k = tune_parameters.index("weight2") + rerank_args["weight3"] = list(random_params[:, k]) + + rerank_args = argparse.Namespace(**rerank_args) + best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank( + rerank_args + ) + rerank_args = vars(args).copy() + rerank_args["lenpen"] = [best_lenpen] + rerank_args["weight1"] = [best_weight1] + rerank_args["weight2"] = [best_weight2] + rerank_args["weight3"] = [best_weight3] + + # write the hypothesis from the valid set from the best trial + + if args.gen_subset != "valid": + rerank_args["gen_subset"] = "valid" + rerank_args = argparse.Namespace(**rerank_args) + rerank.rerank(rerank_args) + + # test with the best hyperparameters on gen subset + rerank_args = vars(args).copy() + rerank_args["gen_subset"] = args.gen_subset + rerank_args["lenpen"] = [best_lenpen] + rerank_args["weight1"] = [best_weight1] + rerank_args["weight2"] = [best_weight2] + rerank_args["weight3"] = [best_weight3] + rerank_args = argparse.Namespace(**rerank_args) + rerank.rerank(rerank_args) + + +def cli_main(): + parser = rerank_options.get_tuning_parser() + args = options.parse_args_and_arch(parser) + + random_search(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/noisychannel/rerank_utils.py b/fairseq/examples/noisychannel/rerank_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6bf1b1afbb089cf5e84f720eb7a067479fbcbc --- /dev/null +++ b/fairseq/examples/noisychannel/rerank_utils.py @@ -0,0 +1,850 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
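+
+# Shared helpers for the reranking pipeline: parsing the output of generate.py,
+# interactive.py and eval_lm, writing reprocessed n-best lists for rescoring,
+# BPE removal and prefix truncation, locating intermediate directories/files,
+# and computing the combined (optionally length-normalized) reranking score.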
+ +import math +import os +import re +import subprocess +from contextlib import redirect_stdout + +from fairseq import options +from fairseq_cli import eval_lm, preprocess + + +def reprocess(fle): + # takes in a file of generate.py translation generate_output + # returns a source dict and hypothesis dict, where keys are the ID num (as a string) + # and values and the corresponding source and translation. There may be several translations + # per source, so the values for hypothesis_dict are lists. + # parses output of generate.py + + with open(fle, "r") as f: + txt = f.read() + + """reprocess generate.py output""" + p = re.compile(r"[STHP][-]\d+\s*") + hp = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)") + source_dict = {} + hypothesis_dict = {} + score_dict = {} + target_dict = {} + pos_score_dict = {} + lines = txt.split("\n") + + for line in lines: + line += "\n" + prefix = re.search(p, line) + if prefix is not None: + assert len(prefix.group()) > 2, "prefix id not found" + _, j = prefix.span() + id_num = prefix.group()[2:] + id_num = int(id_num) + line_type = prefix.group()[0] + if line_type == "H": + h_txt = line[j:] + hypo = re.search(hp, h_txt) + assert ( + hypo is not None + ), "regular expression failed to find the hypothesis scoring" + _, i = hypo.span() + score = hypo.group() + if id_num in hypothesis_dict: + hypothesis_dict[id_num].append(h_txt[i:]) + score_dict[id_num].append(float(score)) + else: + hypothesis_dict[id_num] = [h_txt[i:]] + score_dict[id_num] = [float(score)] + + elif line_type == "S": + source_dict[id_num] = line[j:] + elif line_type == "T": + target_dict[id_num] = line[j:] + elif line_type == "P": + pos_scores = (line[j:]).split() + pos_scores = [float(x) for x in pos_scores] + if id_num in pos_score_dict: + pos_score_dict[id_num].append(pos_scores) + else: + pos_score_dict[id_num] = [pos_scores] + + return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict + + +def reprocess_nbest(fle): + """reprocess interactive.py output""" + with open(fle, "r") as f: + txt = f.read() + + source_dict = {} + hypothesis_dict = {} + score_dict = {} + target_dict = {} + pos_score_dict = {} + lines = txt.split("\n") + + hp = re.compile(r"[-]?\d+[.]?\d+") + j = -1 + + for _i, line in enumerate(lines): + line += "\n" + line_type = line[0] + + if line_type == "H": + hypo = re.search(hp, line) + _, start_index = hypo.span() + score = hypo.group() + if j in score_dict: + score_dict[j].append(float(score)) + hypothesis_dict[j].append(line[start_index:].strip("\t")) + else: + score_dict[j] = [float(score)] + hypothesis_dict[j] = [line[start_index:].strip("\t")] + elif line_type == "O": + j += 1 + source_dict[j] = line[2:] + # we don't have the targets for interactive.py + target_dict[j] = "filler" + + elif line_type == "P": + pos_scores = [float(pos_score) for pos_score in line.split()[1:]] + if j in pos_score_dict: + pos_score_dict[j].append(pos_scores) + else: + pos_score_dict[j] = [pos_scores] + + assert source_dict.keys() == hypothesis_dict.keys() + assert source_dict.keys() == pos_score_dict.keys() + assert source_dict.keys() == score_dict.keys() + + return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict + + +def write_reprocessed( + sources, + hypos, + targets, + source_outfile, + hypo_outfile, + target_outfile, + right_to_left=False, + prefix_len=None, + bpe_symbol=None, + target_prefix_frac=None, + source_prefix_frac=None, +): + + """writes nbest hypothesis for rescoring""" + assert not ( + prefix_len is not None and 
target_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + assert not ( + prefix_len is not None and source_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + assert not ( + target_prefix_frac is not None and source_prefix_frac is not None + ), "in writing reprocessed, only one type of prefix may be used" + + with open(source_outfile, "w") as source_file, open( + hypo_outfile, "w" + ) as hypo_file, open(target_outfile, "w") as target_file: + + assert len(sources) == len(hypos), "sources and hypos list length mismatch" + if right_to_left: + for i in range(len(sources)): + for j in range(len(hypos[i])): + if prefix_len is None: + hypo_file.write(make_right_to_left(hypos[i][j]) + "\n") + else: + raise NotImplementedError() + source_file.write(make_right_to_left(sources[i]) + "\n") + target_file.write(make_right_to_left(targets[i]) + "\n") + else: + for i in sorted(sources.keys()): + for j in range(len(hypos[i])): + if prefix_len is not None: + shortened = ( + get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len) + + "\n" + ) + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif target_prefix_frac is not None: + num_words, shortened, num_bpe_tokens = calc_length_from_frac( + hypos[i][j], target_prefix_frac, bpe_symbol + ) + shortened += "\n" + hypo_file.write(shortened) + source_file.write(sources[i]) + target_file.write(targets[i]) + elif source_prefix_frac is not None: + num_words, shortened, num_bpe_tokensn = calc_length_from_frac( + sources[i], source_prefix_frac, bpe_symbol + ) + shortened += "\n" + hypo_file.write(hypos[i][j]) + source_file.write(shortened) + target_file.write(targets[i]) + else: + hypo_file.write(hypos[i][j]) + source_file.write(sources[i]) + target_file.write(targets[i]) + + +def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol): + # return number of words, (not bpe tokens) that we want + no_bpe_sen = remove_bpe(bpe_sentence, bpe_symbol) + len_sen = len(no_bpe_sen.split()) + + num_words = math.ceil(len_sen * prefix_frac) + prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words) + num_bpe_tokens = len(prefix.split()) + return num_words, prefix, num_bpe_tokens + + +def get_prefix(sentence, prefix_len): + """assuming no bpe, gets the prefix of the sentence with prefix_len words""" + tokens = sentence.strip("\n").split() + if prefix_len >= len(tokens): + return sentence.strip("\n") + else: + return " ".join(tokens[:prefix_len]) + + +def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len): + if bpe_symbol is None: + return get_prefix(sentence, prefix_len) + else: + return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len)) + + +def get_prefix_from_len(sentence, bpe_symbol, prefix_len): + """get the prefix of sentence with bpe, with prefix len in terms of words, not bpe tokens""" + bpe_count = sum([bpe_symbol.strip(" ") in t for t in sentence[:prefix_len]]) + if bpe_count == 0: + return sentence[:prefix_len] + else: + return sentence[:prefix_len] + get_prefix_from_len( + sentence[prefix_len:], bpe_symbol, bpe_count + ) + + +def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len): + """given a prefix length in terms of words, return the number of bpe tokens""" + prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len) + assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len + return len(prefix.split(" ")) + + +def make_right_to_left(line): + tokens = line.split() + tokens.reverse() + 
new_line = " ".join(tokens) + return new_line + + +def remove_bpe(line, bpe_symbol): + line = line.replace("\n", "") + line = (line + " ").replace(bpe_symbol, "").rstrip() + return line + ("\n") + + +def remove_bpe_dict(pred_dict, bpe_symbol): + new_dict = {} + for i in pred_dict: + if type(pred_dict[i]) == list: + new_list = [remove_bpe(elem, bpe_symbol) for elem in pred_dict[i]] + new_dict[i] = new_list + else: + new_dict[i] = remove_bpe(pred_dict[i], bpe_symbol) + return new_dict + + +def parse_bleu_scoring(line): + p = re.compile(r"(BLEU4 = )\d+[.]\d+") + res = re.search(p, line) + assert res is not None, line + return float(res.group()[8:]) + + +def get_full_from_prefix(hypo_prefix, hypos): + """given a hypo prefix, recover the first hypo from the list of complete hypos beginning with that prefix""" + for hypo in hypos: + hypo_prefix = hypo_prefix.strip("\n") + len_prefix = len(hypo_prefix) + if hypo[:len_prefix] == hypo_prefix: + return hypo + # no match found + raise Exception() + + +def get_score( + a, + b, + c, + target_len, + bitext_score1, + bitext_score2=None, + lm_score=None, + lenpen=None, + src_len=None, + tgt_len=None, + bitext1_backwards=False, + bitext2_backwards=False, + normalize=False, +): + if bitext1_backwards: + bitext1_norm = src_len + else: + bitext1_norm = tgt_len + if bitext_score2 is not None: + if bitext2_backwards: + bitext2_norm = src_len + else: + bitext2_norm = tgt_len + else: + bitext2_norm = 1 + bitext_score2 = 0 + if normalize: + score = ( + a * bitext_score1 / bitext1_norm + + b * bitext_score2 / bitext2_norm + + c * lm_score / src_len + ) + else: + score = a * bitext_score1 + b * bitext_score2 + c * lm_score + + if lenpen is not None: + score /= (target_len) ** float(lenpen) + + return score + + +class BitextOutput(object): + def __init__( + self, + output_file, + backwards, + right_to_left, + bpe_symbol, + prefix_len=None, + target_prefix_frac=None, + source_prefix_frac=None, + ): + """process output from rescoring""" + source, hypo, score, target, pos_score = reprocess(output_file) + if backwards: + self.hypo_fracs = source_prefix_frac + else: + self.hypo_fracs = target_prefix_frac + + # remove length penalty so we can use raw scores + score, num_bpe_tokens = get_score_from_pos( + pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards + ) + source_lengths = {} + target_lengths = {} + + assert hypo.keys() == source.keys(), "key mismatch" + if backwards: + tmp = hypo + hypo = source + source = tmp + for i in source: + # since we are reranking, there should only be one hypo per source sentence + if backwards: + len_src = len(source[i][0].split()) + # record length without <eos> + if len_src == num_bpe_tokens[i][0] - 1: + source_lengths[i] = num_bpe_tokens[i][0] - 1 + else: + source_lengths[i] = num_bpe_tokens[i][0] + + target_lengths[i] = len(hypo[i].split()) + + source[i] = remove_bpe(source[i][0], bpe_symbol) + target[i] = remove_bpe(target[i], bpe_symbol) + hypo[i] = remove_bpe(hypo[i], bpe_symbol) + + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + + else: + len_tgt = len(hypo[i][0].split()) + # record length without <eos> + if len_tgt == num_bpe_tokens[i][0] - 1: + target_lengths[i] = num_bpe_tokens[i][0] - 1 + else: + target_lengths[i] = num_bpe_tokens[i][0] + + source_lengths[i] = len(source[i].split()) + + if right_to_left: + source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol) + target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol) + hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), 
bpe_symbol) + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + else: + assert ( + len(hypo[i]) == 1 + ), "expected only one hypothesis per source sentence" + source[i] = remove_bpe(source[i], bpe_symbol) + target[i] = remove_bpe(target[i], bpe_symbol) + hypo[i] = remove_bpe(hypo[i][0], bpe_symbol) + score[i] = float(score[i][0]) + pos_score[i] = pos_score[i][0] + + self.rescore_source = source + self.rescore_hypo = hypo + self.rescore_score = score + self.rescore_target = target + self.rescore_pos_score = pos_score + self.backwards = backwards + self.right_to_left = right_to_left + self.target_lengths = target_lengths + self.source_lengths = source_lengths + + +class BitextOutputFromGen(object): + def __init__( + self, + predictions_bpe_file, + bpe_symbol=None, + nbest=False, + prefix_len=None, + target_prefix_frac=None, + ): + if nbest: + ( + pred_source, + pred_hypo, + pred_score, + pred_target, + pred_pos_score, + ) = reprocess_nbest(predictions_bpe_file) + else: + pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess( + predictions_bpe_file + ) + + assert len(pred_source) == len(pred_hypo) + assert len(pred_source) == len(pred_score) + assert len(pred_source) == len(pred_target) + assert len(pred_source) == len(pred_pos_score) + + # remove length penalty so we can use raw scores + pred_score, num_bpe_tokens = get_score_from_pos( + pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False + ) + + self.source = pred_source + self.target = pred_target + self.score = pred_score + self.pos_score = pred_pos_score + self.hypo = pred_hypo + self.target_lengths = {} + self.source_lengths = {} + + self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol) + self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol) + self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol) + + # indexes to match those from the rescoring models + self.rescore_source = {} + self.rescore_target = {} + self.rescore_pos_score = {} + self.rescore_hypo = {} + self.rescore_score = {} + self.num_hypos = {} + self.backwards = False + self.right_to_left = False + + index = 0 + + for i in sorted(pred_source.keys()): + for j in range(len(pred_hypo[i])): + + self.target_lengths[index] = len(self.hypo[i][j].split()) + self.source_lengths[index] = len(self.source[i].split()) + + self.rescore_source[index] = self.no_bpe_source[i] + self.rescore_target[index] = self.no_bpe_target[i] + self.rescore_hypo[index] = self.no_bpe_hypo[i][j] + self.rescore_score[index] = float(pred_score[i][j]) + self.rescore_pos_score[index] = pred_pos_score[i][j] + self.num_hypos[index] = len(pred_hypo[i]) + index += 1 + + +def get_score_from_pos( + pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards +): + score_dict = {} + num_bpe_tokens_dict = {} + assert prefix_len is None or hypo_frac is None + for key in pos_score_dict: + score_dict[key] = [] + num_bpe_tokens_dict[key] = [] + for i in range(len(pos_score_dict[key])): + if prefix_len is not None and not backwards: + num_bpe_tokens = get_num_bpe_tokens_from_len( + hypo_dict[key][i], bpe_symbol, prefix_len + ) + score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens])) + num_bpe_tokens_dict[key].append(num_bpe_tokens) + elif hypo_frac is not None: + num_words, shortened, hypo_prefix_len = calc_length_from_frac( + hypo_dict[key][i], hypo_frac, bpe_symbol + ) + score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len])) + num_bpe_tokens_dict[key].append(hypo_prefix_len) + else: 
+ score_dict[key].append(sum(pos_score_dict[key][i])) + num_bpe_tokens_dict[key].append(len(pos_score_dict[key][i])) + return score_dict, num_bpe_tokens_dict + + +class LMOutput(object): + def __init__( + self, + lm_score_file, + lm_dict=None, + prefix_len=None, + bpe_symbol=None, + target_prefix_frac=None, + ): + ( + lm_sentences, + lm_sen_scores, + lm_sen_pos_scores, + lm_no_bpe_sentences, + lm_bpe_tokens, + ) = parse_lm( + lm_score_file, + prefix_len=prefix_len, + bpe_symbol=bpe_symbol, + target_prefix_frac=target_prefix_frac, + ) + + self.sentences = lm_sentences + self.score = lm_sen_scores + self.pos_score = lm_sen_pos_scores + self.lm_dict = lm_dict + self.no_bpe_sentences = lm_no_bpe_sentences + self.bpe_tokens = lm_bpe_tokens + + +def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None): + """parse output of eval_lm""" + with open(input_file, "r") as f: + text = f.readlines() + text = text[7:] + cleaned_text = text[:-2] + + sentences = {} + sen_scores = {} + sen_pos_scores = {} + no_bpe_sentences = {} + num_bpe_tokens_dict = {} + for _i, line in enumerate(cleaned_text): + tokens = line.split() + if tokens[0].isdigit(): + line_id = int(tokens[0]) + scores = [float(x[1:-1]) for x in tokens[2::2]] + sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n" + if bpe_symbol is not None: + # exclude <eos> symbol to match output from generate.py + bpe_sen = " ".join(tokens[1::2][:-1]) + "\n" + no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol) + no_bpe_sentences[line_id] = no_bpe_sen + + if prefix_len is not None: + num_bpe_tokens = get_num_bpe_tokens_from_len( + bpe_sen, bpe_symbol, prefix_len + ) + sen_scores[line_id] = sum(scores[:num_bpe_tokens]) + num_bpe_tokens_dict[line_id] = num_bpe_tokens + elif target_prefix_frac is not None: + num_words, shortened, target_prefix_len = calc_length_from_frac( + bpe_sen, target_prefix_frac, bpe_symbol + ) + sen_scores[line_id] = sum(scores[:target_prefix_len]) + num_bpe_tokens_dict[line_id] = target_prefix_len + else: + sen_scores[line_id] = sum(scores) + num_bpe_tokens_dict[line_id] = len(scores) + + sen_pos_scores[line_id] = scores + + return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict + + +def get_directories( + data_dir_name, + num_rescore, + gen_subset, + fw_name, + shard_id, + num_shards, + sampling=False, + prefix_len=None, + target_prefix_frac=None, + source_prefix_frac=None, +): + nbest_file_id = ( + "nbest_" + + str(num_rescore) + + "_subset_" + + gen_subset + + "_fw_name_" + + fw_name + + "_shard_" + + str(shard_id) + + "_of_" + + str(num_shards) + ) + + if sampling: + nbest_file_id += "_sampling" + + # the directory containing all information for this nbest list + pre_gen = ( + os.path.join(os.path.dirname(__file__)) + + "/rerank_data/" + + data_dir_name + + "/" + + nbest_file_id + ) + # the directory to store the preprocessed nbest list, for left to right rescoring + left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed" + if source_prefix_frac is not None: + left_to_right_preprocessed_dir = ( + left_to_right_preprocessed_dir + "/prefix_frac" + str(source_prefix_frac) + ) + # the directory to store the preprocessed nbest list, for right to left rescoring + right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed" + # the directory to store the preprocessed nbest list, for backwards rescoring + backwards_preprocessed_dir = pre_gen + "/backwards" + if target_prefix_frac is not None: + backwards_preprocessed_dir = ( + backwards_preprocessed_dir + 
"/prefix_frac" + str(target_prefix_frac) + ) + elif prefix_len is not None: + backwards_preprocessed_dir = ( + backwards_preprocessed_dir + "/prefix_" + str(prefix_len) + ) + + # the directory to store the preprocessed nbest list, for rescoring with P(T) + lm_preprocessed_dir = pre_gen + "/lm_preprocessed" + + return ( + pre_gen, + left_to_right_preprocessed_dir, + right_to_left_preprocessed_dir, + backwards_preprocessed_dir, + lm_preprocessed_dir, + ) + + +def lm_scoring( + preprocess_directory, + bpe_status, + gen_output, + pre_gen, + cur_lm_dict, + cur_lm_name, + cur_language_model, + cur_lm_bpe_code, + batch_size, + lm_score_file, + target_lang, + source_lang, + prefix_len=None, +): + if prefix_len is not None: + assert ( + bpe_status == "different" + ), "bpe status must be different to use prefix len" + if bpe_status == "no bpe": + # run lm on output without bpe + write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + pre_gen + "/rescore_data_no_bpe.de", + pre_gen + "/rescore_data_no_bpe.en", + pre_gen + "/reference_file_no_bpe", + ) + + preprocess_lm_param = [ + "--only-source", + "--trainpref", + pre_gen + "/rescore_data_no_bpe." + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_directory, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_lm_param) + preprocess.main(input_args) + + eval_lm_param = [ + preprocess_directory, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--max-tokens", + "1024", + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] + + eval_lm_parser = options.get_eval_lm_parser() + input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) + + with open(lm_score_file, "w") as f: + with redirect_stdout(f): + eval_lm.main(input_args) + + elif bpe_status == "shared": + preprocess_lm_param = [ + "--only-source", + "--trainpref", + pre_gen + "/rescore_data." + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_directory, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_lm_param) + preprocess.main(input_args) + + eval_lm_param = [ + preprocess_directory, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] + + eval_lm_parser = options.get_eval_lm_parser() + input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) + + with open(lm_score_file, "w") as f: + with redirect_stdout(f): + eval_lm.main(input_args) + + elif bpe_status == "different": + rescore_file = pre_gen + "/rescore_data_no_bpe" + rescore_bpe = pre_gen + "/rescore_data_new_bpe" + + rescore_file += "." + rescore_bpe += "." 
+ + write_reprocessed( + gen_output.no_bpe_source, + gen_output.no_bpe_hypo, + gen_output.no_bpe_target, + rescore_file + source_lang, + rescore_file + target_lang, + pre_gen + "/reference_file_no_bpe", + bpe_symbol=None, + ) + + # apply LM bpe to nbest list + bpe_src_param = [ + "-c", + cur_lm_bpe_code, + "--input", + rescore_file + target_lang, + "--output", + rescore_bpe + target_lang, + ] + subprocess.call( + [ + "python", + os.path.join( + os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py" + ), + ] + + bpe_src_param, + shell=False, + ) + # uncomment to use fastbpe instead of subword-nmt bpe + # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code] + # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False) + + preprocess_dir = preprocess_directory + + preprocess_lm_param = [ + "--only-source", + "--trainpref", + rescore_bpe + target_lang, + "--srcdict", + cur_lm_dict, + "--destdir", + preprocess_dir, + ] + preprocess_parser = options.get_preprocessing_parser() + input_args = preprocess_parser.parse_args(preprocess_lm_param) + preprocess.main(input_args) + + eval_lm_param = [ + preprocess_dir, + "--path", + cur_language_model, + "--output-word-probs", + "--batch-size", + str(batch_size), + "--max-tokens", + "1024", + "--sample-break-mode", + "eos", + "--gen-subset", + "train", + ] + + eval_lm_parser = options.get_eval_lm_parser() + input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param) + + with open(lm_score_file, "w") as f: + with redirect_stdout(f): + eval_lm.main(input_args) + + +def rescore_file_name( + nbest_dir, + prefix_len, + scorer_name, + lm_file=False, + target_prefix_frac=None, + source_prefix_frac=None, + backwards=None, +): + if lm_file: + score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt" + else: + score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt" + if backwards: + if prefix_len is not None: + score_file += "prefix_len" + str(prefix_len) + elif target_prefix_frac is not None: + score_file += "target_prefix_frac" + str(target_prefix_frac) + else: + if source_prefix_frac is not None: + score_file += "source_prefix_frac" + str(source_prefix_frac) + return score_file diff --git a/fairseq/examples/nonautoregressive_translation/README.md b/fairseq/examples/nonautoregressive_translation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8793e225c99732c42c9c19e22075cde37c73341d --- /dev/null +++ b/fairseq/examples/nonautoregressive_translation/README.md @@ -0,0 +1,146 @@ +# Non-autoregressive Neural Machine Translation (NAT) + +This page mainly includes instructions for reproducing results from the following papers +* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006). +* [Understanding Knowledge Distillation in Non-autoregressive Machine Translation (Zhou et al., 2019)](https://arxiv.org/abs/1911.02727). 
+
+We also provide our own implementations of several popular non-autoregressive models for reference:<br>
+* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)<br>
+* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)](https://arxiv.org/abs/1802.06901)<br>
+* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)](https://arxiv.org/abs/1902.03249)<br>
+* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)<br>
+* [Fast Structured Decoding for Sequence Models (Sun et al., 2019)](https://arxiv.org/abs/1910.11555)
+
+## Dataset
+
+First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#wmt14-english-to-german-convolutional).
+Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
+
+### Knowledge Distillation
+Following [Gu et al. 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
+The easiest way to perform distillation is to follow the [instructions for training a standard transformer model](../translation) on the same data, and then decode the training set to produce a distillation dataset for NAT.
+
+### Download
+We also provide the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
+
+
+## Train a model
+
+We can then train a non-autoregressive model using the `translation_lev` task and the new `nat_loss` criterion.
+Use the `--noise` flag to specify the input noise applied to the target sentences.
+By default, the task is set up for the *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts to run other models can also be found [here](./scripts.md).
+
+The following command will train a *Levenshtein Transformer* on the binarized dataset.
+
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=legacy_ddp \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch levenshtein_transformer \
+    --noise random_delete \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --stop-min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+## Translate
+
+Once a model is trained, we can generate translations using an `iterative_refinement_generator`, which starts from the model's initial output and iteratively refines the translation until (1) the model predicts the same translation for two consecutive iterations, or (2) the generator reaches the maximum number of iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual number of iterations for each sentence.
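+
+Schematically, the stopping rule can be pictured as in the short sketch below. This is our own illustration, not the actual `iterative_refinement_generator` code; the `refine_once` callback stands in for one refinement pass of the model.
+
+```python
+# Illustrative sketch of the two stopping criteria (not fairseq's implementation).
+def iterative_refinement(initial, refine_once, max_iter=9):
+    hypo = initial
+    for step in range(max_iter):
+        new_hypo = refine_once(hypo)
+        if new_hypo == hypo:      # (1) prediction unchanged for two consecutive iterations
+            return hypo, step
+        hypo = new_hypo
+    return hypo, max_iter         # (2) reached --iter-decode-max-iter
+
+# Toy "model" that stops changing its output after a few refinements.
+hypo, steps = iterative_refinement("a", lambda h: h + " b" if len(h) < 7 else h)
+print(hypo, steps)
+```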
+
+For the *Levenshtein Transformer*, it sometimes helps to apply a `--iter-decode-eos-penalty` (typically 0–3) to discourage the model from finishing generation too early and producing translations that are too short.
+
+For example, to generate with `--iter-decode-max-iter=9`:
+```bash
+fairseq-generate \
+    data-bin/wmt14_en_de_distill \
+    --gen-subset test \
+    --task translation_lev \
+    --path checkpoints/checkpoint_best.pt \
+    --iter-decode-max-iter 9 \
+    --iter-decode-eos-penalty 0 \
+    --beam 1 --remove-bpe \
+    --print-step \
+    --batch-size 400
+```
+At the end of generation, the tokenized BLEU score of the translations is printed.
+
+## Advanced Decoding Methods
+### Ensemble
+The NAT models use a special implementation of [ensembling](https://github.com/fairinternal/fairseq-py/blob/b98d88da52f2f21f1b169bab8c70c1c4ca19a768/fairseq/sequence_generator.py#L522) to support iterative refinement and the various parallel operations used by different models, while sharing the same API as standard autoregressive models:
+```bash
+fairseq-generate \
+    data-bin/wmt14_en_de_distill \
+    --gen-subset test \
+    --task translation_lev \
+    --path checkpoint_1.pt:checkpoint_2.pt:checkpoint_3.pt \
+    --iter-decode-max-iter 9 \
+    --iter-decode-eos-penalty 0 \
+    --beam 1 --remove-bpe \
+    --print-step \
+    --batch-size 400
+```
+We use ``:`` to separate multiple models. Note that not all NAT models support ensembling for now.
+
+
+### Length-beam
+For models that predict lengths before decoding (e.g. the vanilla NAT, Mask-Predict, etc.), it is possible to improve translation quality by varying the target length around the predicted value and translating the same example multiple times in parallel. The best translation can then be selected as the one with the highest score under the model's output.
+
+Note that not all models support length beams. For models that dynamically change the length (e.g. the *Insertion Transformer* and *Levenshtein Transformer*), the same trick does not apply.
+
+### Re-ranking
+If the model generates multiple translations with a length beam, we can also introduce an autoregressive model to rerank them, since scoring with an autoregressive model is much faster than decoding with one.
+
+For example, to generate translations with a length beam and reranking:
+```bash
+fairseq-generate \
+    data-bin/wmt14_en_de_distill \
+    --gen-subset test \
+    --task translation_lev \
+    --path checkpoints/checkpoint_best.pt:at_checkpoints/checkpoint_best.pt \
+    --iter-decode-max-iter 9 \
+    --iter-decode-eos-penalty 0 \
+    --iter-decode-with-beam 9 \
+    --iter-decode-with-external-reranker \
+    --beam 1 --remove-bpe \
+    --print-step \
+    --batch-size 100
+```
+Note that we need to make sure the autoregressive model shares the same vocabulary as our target non-autoregressive model.
+
+
+## Citation
+
+```bibtex
+@incollection{NIPS2019_9297,
+  title = {Levenshtein Transformer},
+  author = {Gu, Jiatao and Wang, Changhan and Zhao, Junbo},
+  booktitle = {Advances in Neural Information Processing Systems 32},
+  editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
+  pages = {11179--11189},
+  year = {2019},
+  publisher = {Curran Associates, Inc.},
+  url = {http://papers.nips.cc/paper/9297-levenshtein-transformer.pdf}
+}
+```
+```bibtex
+@article{zhou2019understanding,
+  title={Understanding Knowledge Distillation in Non-autoregressive Machine Translation},
+  author={Zhou, Chunting and Neubig, Graham and Gu, Jiatao},
+  journal={arXiv preprint arXiv:1911.02727},
+  year={2019}
+}
+```
diff --git a/fairseq/examples/nonautoregressive_translation/scripts.md b/fairseq/examples/nonautoregressive_translation/scripts.md
new file mode 100644
index 0000000000000000000000000000000000000000..9d3d7b67dc08440b5f4d1c5a7ffcd4bd6e76c14f
--- /dev/null
+++ b/fairseq/examples/nonautoregressive_translation/scripts.md
@@ -0,0 +1,179 @@
+# Examples of Training scripts for Non-autoregressive Machine Translation models
+
+### Non-autoregressive Transformer (NAT, Gu et al., 2017)
+Note that an additional module is needed to perform "length prediction" (`--length-loss-factor`) before generating the whole sequence.
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=legacy_ddp \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch nonautoregressive_transformer \
+    --noise full_mask \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --stop-min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --pred-length-offset \
+    --length-loss-factor 0.1 \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+### Fast Structured Decoding for Sequence Models (NAT-CRF, Sun et al., 2019)
+Note that we implemented a low-rank approximated CRF model by setting `--crf-lowrank-approx=32` and `--crf-beam-approx=64`, as described in the original paper. All other settings are the same as for the vanilla NAT model.
+```bash
+fairseq-train \
+    data-bin/wmt14_en_de_distill \
+    --save-dir checkpoints \
+    --ddp-backend=legacy_ddp \
+    --task translation_lev \
+    --criterion nat_loss \
+    --arch nacrf_transformer \
+    --noise full_mask \
+    --share-all-embeddings \
+    --optimizer adam --adam-betas '(0.9,0.98)' \
+    --lr 0.0005 --lr-scheduler inverse_sqrt \
+    --stop-min-lr '1e-09' --warmup-updates 10000 \
+    --warmup-init-lr '1e-07' --label-smoothing 0.1 \
+    --dropout 0.3 --weight-decay 0.01 \
+    --decoder-learned-pos \
+    --encoder-learned-pos \
+    --pred-length-offset \
+    --length-loss-factor 0.1 \
+    --word-ins-loss-factor 0.5 \
+    --crf-lowrank-approx 32 \
+    --crf-beam-approx 64 \
+    --apply-bert-init \
+    --log-format 'simple' --log-interval 100 \
+    --fixed-validation-seed 7 \
+    --max-tokens 8000 \
+    --save-interval-updates 10000 \
+    --max-update 300000
+```
+
+
+### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
+Note that `--train-step` sets how many iterations of refinement are used during training, and `--dae-ratio` controls the ratio of denoising auto-encoder training described in the original paper.
+```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=legacy_ddp \ + --task translation_lev \ + --criterion nat_loss \ + --arch iterative_nonautoregressive_transformer \ + --noise full_mask \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --pred-length-offset \ + --length-loss-factor 0.1 \ + --train-step 4 \ + --dae-ratio 0.5 \ + --stochastic-approx \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` + +### Insertion Transformer (InsT, Stern et al., 2019) +Note that we need to specify the "slot-loss" (uniform or balanced tree) described in the original paper. Here we use `--label-tau` to control the temperature. + +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=legacy_ddp \ + --task translation_lev \ + --criterion nat_loss \ + --arch insertion_transformer \ + --noise random_delete \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` + + +### Mask Predict (CMLM, Ghazvininejad et al., 2019) +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=legacy_ddp \ + --task translation_lev \ + --criterion nat_loss \ + --arch cmlm_transformer \ + --noise random_mask \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` + + + + +### Levenshtein Transformer (LevT, Gu et al., 2019) +```bash +fairseq-train \ + data-bin/wmt14_en_de_distill \ + --save-dir checkpoints \ + --ddp-backend=legacy_ddp \ + --task translation_lev \ + --criterion nat_loss \ + --arch levenshtein_transformer \ + --noise random_delete \ + --share-all-embeddings \ + --optimizer adam --adam-betas '(0.9,0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ + --warmup-init-lr '1e-07' --label-smoothing 0.1 \ + --dropout 0.3 --weight-decay 0.01 \ + --decoder-learned-pos \ + --encoder-learned-pos \ + --apply-bert-init \ + --log-format 'simple' --log-interval 100 \ + --fixed-validation-seed 7 \ + --max-tokens 8000 \ + --save-interval-updates 10000 \ + --max-update 300000 +``` diff --git a/fairseq/examples/normformer/README.md b/fairseq/examples/normformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..037b453ff1a4b62521da11e8d6c387d2045280e1 --- /dev/null +++ 
b/fairseq/examples/normformer/README.md
@@ -0,0 +1,70 @@
+### NormFormer
+This is the code for the paper ["NormFormer: Improved Transformer Pretraining with Extra Normalization"](https://arxiv.org/abs/2110.09456).
+- 2021-10-19: Commands for CLM experiments
+- Coming soon: Commands for MLM experiments
+
+If you have any issues or questions, please post a GitHub issue and tag `@sshleifer`.
+
+
+### Data
+- To preprocess language modeling data, see [here](https://github.com/pytorch/fairseq/blob/d0fbcb0baef6f6ff3425ded62d8daea0e8b12114/examples/language_model/README.md#1-preprocess-the-data).
+- The replication commands below expect `$DATA` to be the path to the binarized data directory.
+- Note that the NormFormer results in Table 2 use a much larger private dataset; to get good results you should adapt the pre-processing instructions to your own dataset and compare against a baseline trained on the same data, rather than against Table 2.
+- The code uses `FSDP`, which requires `pip install fairscale>=0.4.0`.
+
+
+### Modify an existing command
+To modify an existing `fairseq-train` command to use NormFormer, simply add the following flags:
+```bash
+fairseq-train ... \
+    --scale-attn --scale-fc --scale-heads
+```
+- You probably also want to increase your learning rate.
+- If your model is small, you may want to add `--scale-resids`.
+
+### Exact Training Commands
+
+- Note that the NormFormer results in Table 2 use a much larger private dataset; to get good results you should adapt the pre-processing instructions to your dataset.
+  The full commands are shell functions defined in `train_lm.sh`, so to run them you must first `source examples/normformer/train_lm.sh`.
+- We default to `--distributed-world-size 8`. You should adjust `--update-freq` and `--batch-size` such that the effective batch size is (1024x1024x0.5) tokens for the 125M and 355M models,
+  and (1024x1024) tokens for 1.3B parameters and above. For small models, `--update-freq` = 256/`global_bs`; for large models, `--update-freq` = 512/`global_bs`, where `global_bs` = `--batch-size` * `--distributed-world-size`. A worked example is given right after this list.
+- The small models will all train on as few as 8 GPUs.
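+
+As a quick sanity check, the update-frequency arithmetic above can be scripted. The helper below is our own illustrative sketch (not part of the repository) and assumes `--tokens-per-sample 2048`, the value used in `train_lm.sh`:
+
+```python
+# Hypothetical helper: pick --update-freq for a target effective batch size,
+# assuming --tokens-per-sample 2048 as in train_lm.sh.
+def update_freq(target_tokens, batch_size, world_size, tokens_per_sample=2048):
+    global_bs = batch_size * world_size               # sequences per step before accumulation
+    target_seqs = target_tokens // tokens_per_sample  # sequences needed per update
+    assert target_seqs % global_bs == 0, "adjust --batch-size so this divides evenly"
+    return target_seqs // global_bs
+
+# 125M/355M target: 0.5M tokens per update -> 256 sequences; --batch-size 16 on 8 GPUs -> --update-freq 2.
+print(update_freq(int(1024 * 1024 * 0.5), batch_size=16, world_size=8))  # 2
+# 1.3B+ target: 1M tokens per update -> 512 sequences; --batch-size 4 on 8 GPUs -> --update-freq 16.
+print(update_freq(1024 * 1024, batch_size=4, world_size=8))  # 16
+```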
+ +```bash +train_125M --lr 6e-4 # GPT-3 Replicated +train_125M --lr 1e-3 # stronger high-lr baseline +train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads # No scale-resids +train_125M --lr 3e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Best command +``` + +```bash +train_355M --lr 6e-4 # GPT-3 Replicated +train_355M --lr 1e-3 # stronger high-lr baseline +train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads # No scale-resids +train_355M --lr 1e-3 --scale-attn --scale-fc --scale-heads --scale-resids # Slightly better +``` + +```bash +train_1.3B --lr 2e-4 # GPT-3 Replicated +train_1.3B --lr 6e-4 # stronger high-lr baseline +train_1.3B --lr 6e-4 --scale-attn --scale-fc --scale-heads # NormFormer +``` + +```bash +train_2.7B --lr 1.6e-4 # GPT-3 Replicated +train_2.7B --lr 1.6e-4 --activation-fn relu_squared # stronger Relu^2 baseline +train_2.7B --lr 6e-4 --activation-fn relu_squared --scale-attn --scale-fc --scale-heads # NormFormer 2.7B +``` + + +### Citation +```bibtex +@misc{shleifer2021normformer, + title={NormFormer: Improved Transformer Pretraining with Extra Normalization}, + author={Sam Shleifer and Jason Weston and Myle Ott}, + year={2021}, + eprint={2110.09456}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/fairseq/examples/normformer/train_lm.sh b/fairseq/examples/normformer/train_lm.sh new file mode 100644 index 0000000000000000000000000000000000000000..b081f2ddd3acdb91d87d7fd080257cb2acf5cb1e --- /dev/null +++ b/fairseq/examples/normformer/train_lm.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash +train_common () { + fairseq-train "$DATA" \ + --combine-val \ + --train-subset train \ + --num-workers 2 \ + --validate-interval-updates 1000 \ + --save-interval-updates 1000 \ + --no-epoch-checkpoints \ + --ddp-backend fully_sharded \ + --memory-efficient-fp16 \ + --fp16-init-scale 4 \ + --checkpoint-activations \ + --arch transformer_lm_gpt \ + --activation-fn gelu \ + --share-decoder-input-output-embed \ + --task language_modeling \ + --sample-break-mode none \ + --tokens-per-sample 2048 \ + --optimizer adam --adam-betas "(0.9, 0.98)" \ + --adam-eps 1e-08 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay \ + --warmup-updates 750 \ + --dropout 0.1 \ + --attention-dropout 0.1 \ + --weight-decay 0.01 \ + --batch-size 16 \ + --update-freq 2 \ + --required-batch-size-multiple 1 \ + --total-num-update 572204 \ + --max-update 572204 \ + --seed 1 \ + --log-format json --log-interval 1 \ + --distributed-world-size 8 --distributed-port 13177 \ + "$@" +} + +train_125M () { + train_common --decoder-layers 12 \ + --decoder-embed-dim 768 \ + --decoder-ffn-embed-dim 3072 \ + --decoder-attention-heads 12 "$@" +} + +train_355M () { + train_common --decoder-layers 24 \ + --decoder-embed-dim 1024\ + --decoder-ffn-embed-dim 4096 \ + --decoder-attention-heads 16 \ + --dropout 0.0 \ + --attention-dropout 0.0 \ + "$@" +} + +train_1.3B () { + train_common --decoder-layers 24 \ + --decoder-embed-dim 2048 \ + --decoder-ffn-embed-dim 8192 \ + --decoder-attention-heads 32 \ + --batch-size 4 \ + --update-freq 16 \ + --total-num-update 286102 \ + --max-update 286102 \ + "$@" +} + +train_2.7B () { + train_common --decoder-layers 32 \ + --decoder-embed-dim 2560 \ + --decoder-ffn-embed-dim 10240 \ + --decoder-attention-heads 32 \ + --batch-size 4 \ + --update-freq 16 \ + --total-num-update 286102 \ + --max-update 286102 \ + "$@" +} diff --git a/fairseq/examples/operators/alignment_train_cpu.cpp b/fairseq/examples/operators/alignment_train_cpu.cpp new file mode 100644 
index 0000000000000000000000000000000000000000..13c015308e320d9ce14909fcdcc135300538e705 --- /dev/null +++ b/fairseq/examples/operators/alignment_train_cpu.cpp @@ -0,0 +1,166 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <torch/extension.h> // @manual=//caffe2:torch_extension +#include <algorithm> + +namespace { + +template <typename T> +void exclusiveCumprod( + const T* p_choose, + T* cumprod_1mp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len) { + // cumprod_1mp = 1 - p_choose + for (uint32_t b = 0; b < bsz; b++) { + for (uint32_t tgt = 0; tgt < tgt_len; tgt++) { + for (uint32_t src = 0; src < src_len; src++) { + uint32_t idx = b * tgt_len * src_len + tgt * src_len + src; + cumprod_1mp[idx] = 1 - p_choose[idx]; + } + } + } + + // Implementing exclusive cumprod in the innermost dimension + // cumprod_1mp = cumprod(1 - p_choose) + // There is cumprod in pytorch, however there is no exclusive mode. + // cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i] + // exclusive means + // cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i] + for (uint32_t b = 0; b < bsz; b++) { + for (uint32_t tgt = 0; tgt < tgt_len; tgt++) { + uint32_t idx_offset = b * tgt_len * src_len + tgt * src_len; + T prev = cumprod_1mp[idx_offset]; + // index [b][tgt][0] + cumprod_1mp[idx_offset] = (T)1.0; + T curr; + for (uint32_t src = 1; src < src_len; src++) { + uint32_t idx = idx_offset + src; + curr = cumprod_1mp[idx]; + cumprod_1mp[idx] = cumprod_1mp[idx - 1] * prev; + prev = curr; + } + } + } +} + +template <typename T> +void clamp( + const T* cumprod_1mp, + T* cumprod_1mp_clamp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + T min_val, + T max_val) { + for (uint32_t b = 0; b < bsz; b++) { + for (uint32_t tgt = 0; tgt < tgt_len; tgt++) { + for (uint32_t src = 0; src < src_len; src++) { + uint32_t idx = b * tgt_len * src_len + tgt * src_len + src; + if (cumprod_1mp[idx] < min_val) { + cumprod_1mp_clamp[idx] = min_val; + } else if (cumprod_1mp[idx] > max_val) { + cumprod_1mp_clamp[idx] = max_val; + } else { + cumprod_1mp_clamp[idx] = cumprod_1mp[idx]; + } + } + } + } +} + +template <typename T> +void alignmentTrainCPUImpl( + const T* p_choose, + T* alpha, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + float eps) { + // p_choose: bsz , tgt_len, src_len + // cumprod_1mp: bsz , tgt_len, src_len + // cumprod_1mp_clamp : bsz, tgt_len, src_len + // alpha: bsz + 1, tgt_len, src_len + + uint32_t elements = bsz * tgt_len * src_len; + T* cumprod_1mp = new T[elements]; + T* cumprod_1mp_clamp = new T[elements]; + + exclusiveCumprod<T>(p_choose, cumprod_1mp, bsz, tgt_len, src_len); + clamp<T>( + cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0); + + // ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) + + // Initialize alpha [:, 0, 0] + for (uint32_t b = 0; b < bsz; b++) { + alpha[b * tgt_len * src_len] = 1.0; + } + + for (uint32_t tgt = 0; tgt < tgt_len; tgt++) { + for (uint32_t b = 0; b < bsz; b++) { + uint32_t alpha_idx, inout_idx; + T prev_scan = 0, curr_scan, out; + for (uint32_t src = 0; src < src_len; src++) { + // Apply scan/cumsum + if (tgt == 0) { + // alpha index is [b][tgt][src] + alpha_idx = b * tgt_len * src_len + src; + } else { + // alpha index is [b][tgt-1][src] + alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src; + } + // input index is [b][tgt][src] + 
inout_idx = b * tgt_len * src_len + tgt * src_len + src; + curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx]; + + out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx]; + alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0); + prev_scan = curr_scan; + } + } + } + + free(cumprod_1mp); + free(cumprod_1mp_clamp); +} + +void alignmentTrainCPU( + const torch::Tensor& p_choose, + torch::Tensor& alpha, + float eps) { + uint32_t bsz = p_choose.size(0); + uint32_t tgt_len = p_choose.size(1); + uint32_t src_len = p_choose.size(2); + + AT_DISPATCH_FLOATING_TYPES_AND2( + torch::ScalarType::Half, + torch::ScalarType::BFloat16, + p_choose.scalar_type(), + "alignmentCPUImpl", + [&]() { + alignmentTrainCPUImpl<scalar_t>( + p_choose.data_ptr<scalar_t>(), + alpha.data_ptr<scalar_t>(), + bsz, + tgt_len, + src_len, + eps); + }); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "alignment_train_cpu", + &alignmentTrainCPU, + "expected_alignment_from_p_choose (CPU)"); +} + +} // namespace diff --git a/fairseq/examples/operators/alignment_train_cuda.cpp b/fairseq/examples/operators/alignment_train_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..430e04813923074c5458c11e05dbdab879659bb5 --- /dev/null +++ b/fairseq/examples/operators/alignment_train_cuda.cpp @@ -0,0 +1,31 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "alignment_train_cuda.h" +#include "utils.h" + +namespace { + +void alignmentTrainCUDA( + const torch::Tensor& p_choose, + torch::Tensor& alpha, + float eps) { + CHECK_INPUT(p_choose); + CHECK_INPUT(alpha); + + alignmentTrainCUDAWrapper(p_choose, alpha, eps); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "alignment_train_cuda", + &alignmentTrainCUDA, + "expected_alignment_from_p_choose (CUDA)"); +} + +} // namespace diff --git a/fairseq/examples/operators/alignment_train_cuda.h b/fairseq/examples/operators/alignment_train_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..8289d1a69079138de6d73993ff117a59f000df37 --- /dev/null +++ b/fairseq/examples/operators/alignment_train_cuda.h @@ -0,0 +1,16 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include <torch/extension.h> // @manual=//caffe2:torch_extension + +void alignmentTrainCUDAWrapper( + const torch::Tensor& p_choose, + torch::Tensor& alpha, + float eps); diff --git a/fairseq/examples/operators/alignment_train_kernel.cu b/fairseq/examples/operators/alignment_train_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..efae7cc76f1cce8701dbb5bb7b5ff1e402993953 --- /dev/null +++ b/fairseq/examples/operators/alignment_train_kernel.cu @@ -0,0 +1,354 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> // @manual=//caffe2/aten:ATen-cu +#include <cuda_runtime.h> +#include <algorithm> // std::min/max +#include <cub/cub.cuh> + +#include "alignment_train_cuda.h" +#include "utils.h" + +namespace { + +// The thread block length in threads along the X dimension +constexpr int BLOCK_DIM_X = 128; +// The thread block length in threads along the Y dimension +constexpr int BLOCK_DIM_Y = 8; +// The thread block length in threads for scan operation +constexpr int SCAN_BLOCK = 512; + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } + +inline void +gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { + if (code != cudaSuccess) { + fprintf( + stderr, + "\nGPUassert: %s %s %d\n", + cudaGetErrorString(code), + file, + line); + if (abort) + exit(code); + } +} + +template <typename T> +struct Prod { + /// prod operator, returns <tt>a * b</tt> + __host__ __device__ __forceinline__ T + operator()(const T& a, const T& b) const { + return a * b; + } +}; + +template <typename T> +struct BlockPrefixProdCallbackOp { + // Running prefix + T running_total; + + // Constructor + __device__ BlockPrefixProdCallbackOp(T running_total) + : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. + __device__ T operator()(const T block_aggregate) { + T old_prefix = running_total; + running_total *= block_aggregate; + return old_prefix; + } +}; + +template <typename T> +struct BlockPrefixSumCallbackOp { + // Running prefix + T running_total; + + // Constructor + __device__ BlockPrefixSumCallbackOp(T running_total) + : running_total(running_total) {} + + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide + // scan. + __device__ T operator()(const T block_aggregate) { + T old_prefix = running_total; + running_total += block_aggregate; + return old_prefix; + } +}; + +template <typename T> +__global__ void oneMinusPKernel( + const T* __restrict__ p_choose, + T* __restrict__ cumprod_1mp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len) { + for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) { + for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) { + for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) { + uint32_t idx = b * tgt_len * src_len + tgt * src_len + src; + cumprod_1mp[idx] = 1 - p_choose[idx]; + } + } + } +} + +template <typename T, int TPB> +__global__ void innermostScanKernel( + T* __restrict__ cumprod_1mp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len) { + for (uint32_t b = blockIdx.y; b < bsz; b += gridDim.y) { + for (uint32_t tgt = blockIdx.x; tgt < tgt_len; tgt += gridDim.x) { + // Specialize BlockScan for a 1D block of TPB threads on type T + typedef cub::BlockScan<T, TPB> BlockScan; + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + // Initialize running total + BlockPrefixProdCallbackOp<T> prefix_op(1); + + const uint32_t tid = threadIdx.x; + for (uint32_t block_src = 0; block_src < src_len; + block_src += blockDim.x) { + uint32_t src = block_src + tid; + uint32_t idx = b * tgt_len * src_len + tgt * src_len + src; + T thread_data = (src < src_len) ? 
cumprod_1mp[idx] : (T)0; + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage) + .ExclusiveScan(thread_data, thread_data, Prod<T>(), prefix_op); + __syncthreads(); + + // write the scanned value to output + if (src < src_len) { + cumprod_1mp[idx] = thread_data; + } + } + } + } +} + +template <typename T> +__global__ void clampKernel( + const T* __restrict__ cumprod_1mp, + T* __restrict__ cumprod_1mp_clamp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + T min_val, + T max_val) { + for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) { + for (uint32_t tgt = threadIdx.y; tgt < tgt_len; tgt += blockDim.y) { + for (uint32_t src = threadIdx.x; src < src_len; src += blockDim.x) { + uint32_t idx = b * tgt_len * src_len + tgt * src_len + src; + if (cumprod_1mp[idx] < min_val) { + cumprod_1mp_clamp[idx] = min_val; + } else if (cumprod_1mp[idx] > max_val) { + cumprod_1mp_clamp[idx] = max_val; + } else { + cumprod_1mp_clamp[idx] = cumprod_1mp[idx]; + } + } + } + } +} + +template <typename T> +__global__ void initAlphaCUDAKernel( + T* alpha, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len) { + // alpha[:, 0, 0] = 1.0 + for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) { + alpha[b * tgt_len * src_len] = (T)1.0; + } +} + +template <typename T, int TPB> +__global__ void alignmentTrainCUDAKernel( + const T* __restrict__ p_choose, + const T* __restrict__ cumprod_1mp, + const T* __restrict__ cumprod_1mp_clamp, + T* __restrict__ alpha, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + uint32_t tgt) { + for (uint32_t b = blockIdx.x; b < bsz; b += gridDim.x) { + // Specialize BlockScan for a 1D block of TPB threads on type T + typedef cub::BlockScan<T, TPB> BlockScan; + + // Allocate shared memory for BlockScan + __shared__ typename BlockScan::TempStorage temp_storage; + // Initialize running total + BlockPrefixSumCallbackOp<T> prefix_op(0); + + uint32_t b_offset = b * tgt_len * src_len; + const uint32_t tid = threadIdx.x; + for (uint32_t block_src = 0; block_src < src_len; block_src += blockDim.x) { + uint32_t src = block_src + tid; + // Obtain a segment of consecutive items that are blocked across threads + uint32_t inout_idx, alpha_idx; + if (tgt == 0) { + // both alpha and other input index is [b][0][src] + alpha_idx = b_offset + src; + } else { + // alpha index is [b][tgt-1][src] + alpha_idx = b_offset + (tgt - 1) * src_len + src; + } + inout_idx = b_offset + tgt * src_len + src; + T thread_data = (T)0; + if (src < src_len) { + thread_data = alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx]; + } + + // Collectively compute the block-wide inclusive prefix sum + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op); + __syncthreads(); + + if (src < src_len) { + T out = thread_data * p_choose[inout_idx] * cumprod_1mp[inout_idx]; + // Clamps all elements into the range [ 0, 1.0 ] + alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), (T)1.0); + } + } + } +} + +template <typename T> +void exclusiveCumprod( + const T* p_choose, + T* cumprod_1mp, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + uint32_t max_grid_x, + uint32_t max_grid_y, + cudaStream_t& stream) { + // cumprod_1mp = 1 - p_choose + dim3 grid(std::min<T>(max_grid_x, bsz), 1, 1); + dim3 block(BLOCK_DIM_X, BLOCK_DIM_Y, 1); + oneMinusPKernel<T><<<grid, block, 0, stream>>>( + p_choose, cumprod_1mp, bsz, tgt_len, src_len); + gpuErrchk(cudaGetLastError()); + + // scan on the innermost dimension of cumprod_1mp + // cumprod_1mp = cumprod(cumprod_1mp) + dim3 
grid_scan( + std::min<T>(max_grid_x, tgt_len), std::min<T>(max_grid_y, bsz), 1); + innermostScanKernel<T, SCAN_BLOCK><<<grid_scan, SCAN_BLOCK, 0, stream>>>( + cumprod_1mp, bsz, tgt_len, src_len); + gpuErrchk(cudaGetLastError()); +} + +template <typename T> +void alignmentTrainCUDAImpl( + const T* p_choose, + T* alpha, + uint32_t bsz, + uint32_t tgt_len, + uint32_t src_len, + float eps) { + // p_choose: bsz , tgt_len, src_len + // cumprod_1mp: bsz , tgt_len, src_len + // cumprod_1mp_clamp : bsz, tgt_len, src_len + // alpha: bsz, tgt_len, src_len + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + uint32_t max_grid_x = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; + uint32_t max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + + // Implementing exclusive cumprod. + // cumprod_1mp = cumprod(1 - p_choose) + // There is cumprod in pytorch, however there is no exclusive mode. + // cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i] + // exclusive means + // cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i] + uint32_t elements = bsz * tgt_len * src_len; + T* cumprod_1mp; + gpuErrchk(cudaMalloc(&cumprod_1mp, elements * sizeof(T))); + exclusiveCumprod<T>( + p_choose, + cumprod_1mp, + bsz, + tgt_len, + src_len, + max_grid_x, + max_grid_y, + stream); + + // clamp cumprod_1mp to the range [eps, 1.0] + T* cumprod_1mp_clamp; + gpuErrchk(cudaMalloc(&cumprod_1mp_clamp, elements * sizeof(T))); + dim3 grid_clamp(std::min<T>(max_grid_x, bsz), 1, 1); + dim3 block_clamp(BLOCK_DIM_X, BLOCK_DIM_Y, 1); + clampKernel<T><<<grid_clamp, block_clamp, 0, stream>>>( + cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0); + gpuErrchk(cudaGetLastError()); + + // ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) + dim3 grid_init(std::min<int>(max_grid_x, bsz), 1, 1); + initAlphaCUDAKernel<T> + <<<grid_init, 1, 0, stream>>>(alpha, bsz, tgt_len, src_len); + gpuErrchk(cudaGetLastError()); + + const int grid = std::min(bsz, max_grid_x); + + for (uint32_t i = 0; i < tgt_len; i++) { + alignmentTrainCUDAKernel<T, SCAN_BLOCK><<<grid, SCAN_BLOCK, 0, stream>>>( + p_choose, + cumprod_1mp, + cumprod_1mp_clamp, + alpha, + bsz, + tgt_len, + src_len, + i); + gpuErrchk(cudaGetLastError()); + } + + gpuErrchk(cudaFree(cumprod_1mp)); + gpuErrchk(cudaFree(cumprod_1mp_clamp)); +} + +} // namespace + +void alignmentTrainCUDAWrapper( + const torch::Tensor& p_choose, + torch::Tensor& alpha, + float eps) { + // p_choose dimension: bsz, tgt_len, src_len + uint32_t bsz = p_choose.size(0); + uint32_t tgt_len = p_choose.size(1); + uint32_t src_len = p_choose.size(2); + + cudaSetDevice(p_choose.get_device()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + torch::ScalarType::Half, + torch::ScalarType::BFloat16, + p_choose.scalar_type(), + "alignmentTrainCUDAImpl", + [&]() { + alignmentTrainCUDAImpl<scalar_t>( + p_choose.data_ptr<scalar_t>(), + alpha.data_ptr<scalar_t>(), + bsz, + tgt_len, + src_len, + eps); + }); +} diff --git a/fairseq/examples/operators/utils.h b/fairseq/examples/operators/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0ef5b4383f4a8ff05ea7b44d932a3b47e4c7f927 --- /dev/null +++ b/fairseq/examples/operators/utils.h @@ -0,0 +1,19 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include <torch/extension.h> // @manual=//caffe2:torch_extension + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) diff --git a/fairseq/examples/paraphraser/README.md b/fairseq/examples/paraphraser/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3810311f30f99f0a07fd8e5d3723bffeba9948c3 --- /dev/null +++ b/fairseq/examples/paraphraser/README.md @@ -0,0 +1,46 @@ +# Paraphrasing with round-trip translation and mixture of experts + +Machine translation models can be used to paraphrase text by translating it to +an intermediate language and back (round-trip translation). + +This example shows how to paraphrase text by first passing it to an +English-French translation model, followed by a French-English [mixture of +experts translation model](/examples/translation_moe). + +##### 0. Setup + +Clone fairseq from source and install necessary dependencies: +```bash +git clone https://github.com/pytorch/fairseq.git +cd fairseq +pip install --editable . +pip install sacremoses sentencepiece +``` + +##### 1. Download models +```bash +wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz +wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz +tar -xzvf paraphraser.en-fr.tar.gz +tar -xzvf paraphraser.fr-en.hMoEup.tar.gz +``` + +##### 2. Paraphrase +```bash +python examples/paraphraser/paraphrase.py \ + --en2fr paraphraser.en-fr \ + --fr2en paraphraser.fr-en.hMoEup +# Example input: +# The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules. +# Example outputs: +# Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule. +# The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule. +# The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule. +# The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule. +# The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule. +# The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule. +# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule. +# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule. +# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training. +# The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule. 
+``` diff --git a/fairseq/examples/paraphraser/paraphrase.py b/fairseq/examples/paraphraser/paraphrase.py new file mode 100644 index 0000000000000000000000000000000000000000..d3422fb3db9a381b73a854d2379df214ebe544a2 --- /dev/null +++ b/fairseq/examples/paraphraser/paraphrase.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 -u + +import argparse +import fileinput +import logging +import os +import sys + +from fairseq.models.transformer import TransformerModel + + +logging.getLogger().setLevel(logging.INFO) + + +def main(): + parser = argparse.ArgumentParser(description="") + parser.add_argument("--en2fr", required=True, help="path to en2fr model") + parser.add_argument( + "--fr2en", required=True, help="path to fr2en mixture of experts model" + ) + parser.add_argument( + "--user-dir", help="path to fairseq examples/translation_moe/src directory" + ) + parser.add_argument( + "--num-experts", + type=int, + default=10, + help="(keep at 10 unless using a different model)", + ) + parser.add_argument( + "files", + nargs="*", + default=["-"], + help='input files to paraphrase; "-" for stdin', + ) + args = parser.parse_args() + + if args.user_dir is None: + args.user_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/ + "translation_moe", + "src", + ) + if os.path.exists(args.user_dir): + logging.info("found user_dir:" + args.user_dir) + else: + raise RuntimeError( + "cannot find fairseq examples/translation_moe/src " + "(tried looking here: {})".format(args.user_dir) + ) + + logging.info("loading en2fr model from:" + args.en2fr) + en2fr = TransformerModel.from_pretrained( + model_name_or_path=args.en2fr, + tokenizer="moses", + bpe="sentencepiece", + ).eval() + + logging.info("loading fr2en model from:" + args.fr2en) + fr2en = TransformerModel.from_pretrained( + model_name_or_path=args.fr2en, + tokenizer="moses", + bpe="sentencepiece", + user_dir=args.user_dir, + task="translation_moe", + ).eval() + + def gen_paraphrases(en): + fr = en2fr.translate(en) + return [ + fr2en.translate(fr, inference_step_args={"expert": i}) + for i in range(args.num_experts) + ] + + logging.info("Type the input sentence and press return:") + for line in fileinput.input(args.files): + line = line.strip() + if len(line) == 0: + continue + for paraphrase in gen_paraphrases(line): + print(paraphrase) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/pay_less_attention_paper/README.md b/fairseq/examples/pay_less_attention_paper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5adab11f4dc3461f9e7126ac391b04e703616e6b --- /dev/null +++ b/fairseq/examples/pay_less_attention_paper/README.md @@ -0,0 +1,176 @@ +# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019) + +This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430). + +## Citation: +```bibtex +@inproceedings{wu2018pay, + title = {Pay Less Attention with Lightweight and Dynamic Convolutions}, + author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli}, + booktitle = {International Conference on Learning Representations}, + year = {2019}, + url = {https://arxiv.org/abs/1901.10430}, +} +``` + +## Translation + +### Pre-trained models +For some datasets we release models without GLUs which are faster at inference. 
+ +Model | Description | Dataset | Download +---|---|---|--- +`lightconv.no_glu.iwslt14.de-en` | LightConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2) +`dynamicconv.no_glu.iwslt14.de-en` | DynamicConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2) +`lightconv.no_glu.wmt16.en-de` | LightConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +`dynamicconv.no_glu.wmt16.en-de` | DynamicConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +`lightconv.glu.wmt16.en-de` | LightConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +`dynamicconv.glu.wmt16.en-de` | DynamicConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +`lightconv.glu.wmt14.en-fr` | LightConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +`dynamicconv.glu.wmt14.en-fr` | DynamicConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +`lightconv.glu.wmt17.zh-en` | LightConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: 
<br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2) +`dynamicconv.glu.wmt17.zh-en` | DynamicConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2) + +### Memory-Efficient CUDA Kernels + +Since the PyTorch implementations of Light/Dynamic conv are quite memory intensive, we have developed CUDA kernels that implement the light and dynamic convolution operator in a memory-efficient and performant manner. For large sequence lengths, these kernels save about 50% memory compared to the PyTorch equivalent. + +To install the kernels, use the commands below. Once installed, they will automatically be used in place of the PyTorch implementations whenever a light or dynamic convolution is used. + +```sh +# to install lightconv +cd fairseq/modules/lightconv_layer +python cuda_function_gen.py +python setup.py install + +# to install dynamicconv +cd fairseq/modules/dynamicconv_layer +python cuda_function_gen.py +python setup.py install +``` + +### Example usage (torch.hub) + +We require a few additional Python dependencies for preprocessing: +```bash +pip install sacremoses subword_nmt +``` + +Interactive translation via PyTorch Hub: +```python +import torch + +# List available models +torch.hub.list('pytorch/fairseq') # [..., 'lightconv.glu.wmt17.zh-en', ... ] + +# Load a transformer trained on WMT'16 En-De +zh2en = torch.hub.load('pytorch/fairseq', 'lightconv.glu.wmt17.zh-en', tokenizer='moses', bpe='subword_nmt') + +# The underlying model is available under the *models* attribute +assert isinstance(zh2en.models[0], fairseq.models.lightconv.LightConvModel) + +# Translate a sentence +zh2en.translate('你好 世界') +# 'Hello World' +``` + +Loading custom models: +```python +from fairseq.models.lightconv import LightConvModel +en2fr = LightConvModel.from_pretrained( + '/path/to/checkpoints', + checkpoint_file='checkpoint_best.pt', + data_name_or_path='data-bin/wmt14_en_fr', + bpe='subword_nmt', + bpe_codes='data-bin/wmt14_en_fr/en.code' +) +en2fr.translate('Hello world!') +# 'Bonjour le monde' +``` + +### Preprocessing the training datasets + +Please follow the instructions in [`examples/translation/README.md`](../translation/README.md) to preprocess the data. + +### Training and evaluation options: +To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`. +For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv. +For best BLEU results, lenpen may need to be manually tuned. + +To use the CUDA kernels, first install the PyTorch modules using the commands +above. Once the CUDA modules are installed, they will automatically be used +instead of the PyTorch modules. 
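+
+If you want to confirm that the kernels were built, one quick check is to try importing the compiled extensions directly. The module names below (`lightconv_cuda`, `dynamicconv_cuda`) are our assumption based on the build directories above; check the `setup.py` in each directory if they differ in your install.
+
+```python
+# Assumed extension module names -- verify against the setup.py files if unsure.
+for name in ("lightconv_cuda", "dynamicconv_cuda"):
+    try:
+        __import__(name)
+        print(f"{name}: CUDA kernel available")
+    except ImportError:
+        print(f"{name}: not found, the PyTorch implementation will be used")
+```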
+ +### IWSLT14 De-En +Training and evaluating DynamicConv (without GLU) on a GPU: +```sh +# Training +SAVE="save/dynamic_conv_iwslt" +mkdir -p $SAVE +CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \ + --clip-norm 0 --optimizer adam --lr 0.0005 \ + --source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \ + --log-interval 100 --stop-min-lr '1e-09' --weight-decay 0.0001 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --lr-scheduler inverse_sqrt \ + --ddp-backend=legacy_ddp \ + --max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \ + --adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \ + -a lightconv_iwslt_de_en --save-dir $SAVE \ + --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 0 --decoder-glu 0 +python scripts/average_checkpoints.py --inputs $SAVE \ + --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt" + +# Evaluation +CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet +``` + +### WMT16 En-De +Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine scheduler on one machine with 8 V100 GPUs: +```sh +# Training +SAVE="save/dynamic_conv_wmt16en2de" +mkdir -p $SAVE +python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ + data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \ + --max-update 30000 --share-all-embeddings --optimizer adam \ + --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --ddp-backend=legacy_ddp --max-tokens 3584 \ + --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ + --t-mult 1 --lr-period-updates 20000 \ + --arch lightconv_wmt_en_de_big --save-dir $SAVE \ + --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 1 --decoder-glu 1 + +# Evaluation +CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt +bash scripts/compound_split_bleu.sh wmt16_gen.txt +``` + +### WMT14 En-Fr +Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one machine with 8 V100 GPUs: +```sh +# Training +SAVE="save/dynamic_conv_wmt14en2fr" +mkdir -p $SAVE +python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ + data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \ + --max-update 30000 --share-all-embeddings --optimizer adam \ + --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --ddp-backend=legacy_ddp --max-tokens 3584 \ + --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ + --t-mult 1 --lr-period-updates 70000 \ + --arch lightconv_wmt_en_fr_big --save-dir $SAVE \ + --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \ + --encoder-glu 1 --decoder-glu 1 + +# Evaluation +CUDA_VISIBLE_DEVICES=0 fairseq-generate 
data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test +``` diff --git a/fairseq/examples/pointer_generator/README.md b/fairseq/examples/pointer_generator/README.md new file mode 100644 index 0000000000000000000000000000000000000000..60965708254aae2174812ea6686a9807825b7fb6 --- /dev/null +++ b/fairseq/examples/pointer_generator/README.md @@ -0,0 +1,82 @@ +# Transformer with Pointer-Generator Network + +This page describes the `transformer_pointer_generator` model that incorporates +a pointing mechanism in the Transformer model that facilitates copying of input +words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/). + +## Background + +The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368) +for RNN encoder-decoder attention models. A similar mechanism can be +incorporated in a Transformer model by reusing one of the many attention +distributions for pointing. The attention distribution over the input words is +interpolated with the normal output distribution over the vocabulary words. This +allows the model to generate words that appear in the input, even if they don't +appear in the vocabulary, helping especially with small vocabularies. + +## Implementation + +The mechanism for copying out-of-vocabulary words from the input has been +implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator) +they convey the word identities through the model in order to be able to produce +words that appear in the input sequence but not in the vocabulary. A different +approach was taken in the Fairseq implementation to keep it self-contained in +the model file, avoiding any changes to the rest of the code base. Copying +out-of-vocabulary words is possible by pre-processing the input and +post-processing the output. This is described in detail in the next section. + +## Usage + +The training and evaluation procedure is outlined below. You can also find a +more detailed example for the XSum dataset on [this page](README.xsum.md). + +##### 1. Create a vocabulary and extend it with source position markers + +The pointing mechanism is especially helpful with small vocabularies, if we are +able to recover the identities of any out-of-vocabulary words that are copied +from the input. For this purpose, the model allows extending the vocabulary with +special tokens that can be used in place of `<unk>` tokens to identify different +input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`, +etc. to the end of the vocabulary, after the normal words. Below is an example +of how to create a vocabulary of 10000 most common words and add 1000 input +position markers. + +```bash +vocab_size=10000 +position_markers=1000 +export LC_ALL=C +cat train.src train.tgt | + tr -s '[:space:]' '\n' | + sort | + uniq -c | + sort -k1,1bnr -k2 | + head -n "$((vocab_size - 4))" | + awk '{ print $2 " " $1 }' >dict.pg.txt +python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt +``` + +##### 2. Preprocess the text data + +The idea is that any `<unk>` tokens in the text are replaced with `<unk-0>` if +it appears in the first input position, `<unk-1>` if it appears in the second +input position, and so on. This can be achieved using the `preprocess.py` script +that is provided in this directory. + +##### 3. 
Train a model + +The number of these special tokens is given to the model with the +`--source-position-markers` argument—the model simply maps all of these to the +same word embedding as `<unk>`. + +The attention distribution that is used for pointing is selected using the +`--alignment-heads` and `--alignment-layer` command-line arguments in the same +way as with the `transformer_align` model. + +##### 4. Generate text and postprocess it + +When using the model to generate text, you want to preprocess the input text in +the same way that training data was processed, replacing out-of-vocabulary words +with `<unk-N>` tokens. If any of these tokens are copied to the output, the +actual words can be retrieved from the unprocessed input text. Any `<unk-N>` +token should be replaced with the word at position N in the original input +sequence. This can be achieved using the `postprocess.py` script. diff --git a/fairseq/examples/pointer_generator/preprocess.py b/fairseq/examples/pointer_generator/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f72ca7d3d97e12ab7b405dcff314bdb6c0a78755 --- /dev/null +++ b/fairseq/examples/pointer_generator/preprocess.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from itertools import zip_longest + + +def replace_oovs(source_in, target_in, vocabulary, source_out, target_out): + """Replaces out-of-vocabulary words in source and target text with <unk-N>, + where N in is the position of the word in the source sequence. + """ + + def format_unk(pos): + return "<unk-{}>".format(pos) + + if target_in is None: + target_in = [] + + for seq_num, (source_seq, target_seq) in enumerate( + zip_longest(source_in, target_in) + ): + source_seq_out = [] + target_seq_out = [] + + word_to_pos = dict() + for position, token in enumerate(source_seq.strip().split()): + if token in vocabulary: + token_out = token + else: + if token in word_to_pos: + oov_pos = word_to_pos[token] + else: + word_to_pos[token] = position + oov_pos = position + token_out = format_unk(oov_pos) + source_seq_out.append(token_out) + source_out.write(" ".join(source_seq_out) + "\n") + + if target_seq is not None: + for token in target_seq.strip().split(): + if token in word_to_pos: + token_out = format_unk(word_to_pos[token]) + else: + token_out = token + target_seq_out.append(token_out) + if target_out is not None: + target_out.write(" ".join(target_seq_out) + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Replaces out-of-vocabulary words in both source and target " + "sequences with tokens that indicate the position of the word " + "in the source sequence." 
+ ) + parser.add_argument( + "--source", type=str, help="text file with source sequences", required=True + ) + parser.add_argument( + "--target", type=str, help="text file with target sequences", default=None + ) + parser.add_argument("--vocab", type=str, help="vocabulary file", required=True) + parser.add_argument( + "--source-out", + type=str, + help="where to write source sequences with <unk-N> entries", + required=True, + ) + parser.add_argument( + "--target-out", + type=str, + help="where to write target sequences with <unk-N> entries", + default=None, + ) + args = parser.parse_args() + + with open(args.vocab, encoding="utf-8") as vocab: + vocabulary = vocab.read().splitlines() + + target_in = ( + open(args.target, "r", encoding="utf-8") if args.target is not None else None + ) + target_out = ( + open(args.target_out, "w", encoding="utf-8") + if args.target_out is not None + else None + ) + with open(args.source, "r", encoding="utf-8") as source_in, open( + args.source_out, "w", encoding="utf-8" + ) as source_out: + replace_oovs(source_in, target_in, vocabulary, source_out, target_out) + if target_in is not None: + target_in.close() + if target_out is not None: + target_out.close() + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/quant_noise/README.md b/fairseq/examples/quant_noise/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a04d7e4e8a077f11c9f63cfa3d1f20e2b899be8c --- /dev/null +++ b/fairseq/examples/quant_noise/README.md @@ -0,0 +1,298 @@ +# Training with Quantization Noise for Extreme Model Compression ({Fan\*, Stock\*} *et al.*, 2020) +This page contains information on how to train and quantize models with Quantization Noise, for both scalar quantization like `int8` and Iterative Product Quantization. +Check out our paper [here](https://arxiv.org/abs/2004.07320). + +Looking for pretrained models? They will be added shortly. +Looking for code to train vision models? We are working on open sourcing our code as part of ClassyVision. Please check back, but note that both the Scalar and Iterative Product Quantization counterparts of the `nn.Conv2d` module are already included in this release. + +**Contents**: +- [Walk through of code](#walk-through-the-code) +- [Reproduce NLP Results](#looking-to-reproduce-the-nlp-results-in-the-paper) +- [Reproduce Vision Results](#looking-to-reproduce-the-vision-results-in-the-paper) + + +## Citation +```bibtex +@article{fan2020training, + title={Training with Quantization Noise for Extreme Model Compression}, + author={Angela Fan* and Pierre Stock* and Benjamin Graham and Edouard Grave and Remi Gribonval and Herve Jegou and Armand Joulin}, + year={2020}, + eprint={2004.07320}, + archivePrefix={arXiv}, + primaryClass={cs.ML} +} +``` + +## Walk through the code + +Training a model with Quant-Noise improves the performance of subsequent inference-time quantization by training the model to be robust to quantization. This technique is useful for both scalar and product quantization methods, as well as multiple domains. Below we detail how to train and quantize models, and how to integrate our code to quantize your favorite models. + +### Scalar Quantization + +Unlike the [Iterative Product Quantization](#iterative-product-quantization) section, which gives state-of-the-art compression, this section showcases the usefulness of our approach for simple scalar quantization baselines such as int8 using on-GPU Fake Quantization.
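For intuition about what on-GPU Fake Quantization does before going through the training flags below, here is a minimal, self-contained sketch of emulated int8 quantization of a single weight tensor. The `fake_quantize_int8` helper is purely illustrative and is not the fairseq scalar-quantization code: the weights are mapped to 8-bit integer levels and immediately mapped back to floats, so training stays in floating point while the network only ever sees int8-representable values.

```python
# Illustration only: emulated ("fake") int8 quantization of a weight tensor.
# Not the fairseq implementation; it just shows the quantize/de-quantize round trip.
import torch


def fake_quantize_int8(w: torch.Tensor) -> torch.Tensor:
    # Symmetric per-tensor quantization to the int8 range [-127, 127].
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    w_int = torch.clamp(torch.round(w / scale), -127, 127)
    # De-quantize immediately: the result is float again, but restricted
    # to the grid of values representable in int8.
    return w_int * scale


w = torch.randn(512, 512)
print((w - fake_quantize_int8(w)).abs().max())  # error introduced by int8
```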
+ +#### Training + +Scalar quantization with Quant-Noise consists in randomly quantizing a proportion `p` of the weights during training. Scalar quantization is implemented [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/scalar) in the form of Fake Quantization, meaning that we emulate int8 on GPU by quantizing and de-quantizing both the weights and the activations. We rely on PyTorch's [quantization primitives](https://github.com/pytorch/pytorch/tree/master/torch/quantization). + +To train a model with Quant-Noise, add the following flag: +``` +--quant-noise-scalar 0.5 +``` +Large values of noise make the network easier to quantize but may result in higher non-quantized test and validation perplexities. + +#### Quantization + +When evaluating a network, all quantized modules and activation hooks automatically switch to `p=1`, so the validation accuracy reported by Fairseq is actually the quantized one; there is nothing more to do. + + +#### Integration with your own code + +Looking to quantize your own models with Quant-Noise + Scalar Quantization? +- Use the function `quantize_model_` implemented [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/scalar/utils.py) to (1) replace all your modules by their quantized counterparts and (2) add hooks to those modules to quantize the activations. +- Then, perform your training as usual. Note that in `eval()` mode, the network is always fully quantized (weights and activations) by default (`p=1`). + + + +### Iterative Product Quantization + + +Iterative Product Quantization with Quant-Noise proceeds in two steps. First, a model must be trained uncompressed with Quant-Noise. Second, the model must be quantized with iPQ. Note that we implement here the simplest form of noise, which consists in randomly dropping a proportion `p` of blocks; this worked as well as assigning those blocks to their current centroid. + +#### Training + +To train a model with Quant-Noise, add the following flags: +``` +--quant-noise-pq 0.1 --quant-noise-pq-block-size 8 +``` +`quant-noise-pq` controls how much dropout is applied to the blocks of the weight matrix. `quant-noise-pq-block-size` controls the size of the weight matrix blocks. +We recommend training with 0.05 to 0.2 Quant-Noise, a range that worked well in our experiments. For the block size, we recommend training with a block size of 8. Note that `input_features` must be a multiple of the block size; see the size checks [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py). Large block sizes result in a higher compression ratio but may induce a loss in accuracy. + +We currently support training Transformer-based models, such as sequence-to-sequence, language models, and BERT architectures. The `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py) wraps a module: it splits a weight matrix into blocks and applies random dropout to these blocks (see the sketch below). +In the Transformer architectures, quant-noise is applied to the input and output embeddings, the attention, and the FFN. + +Quant-Noise can also be combined with **LayerDrop** (see [here](https://github.com/pytorch/fairseq/tree/main/examples/layerdrop)) to add its pruning effect to the quantized model and make the model even smaller. We recommend training with LayerDrop 0.1 or 0.2.
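If you want to plug Quant-Noise into your own modules rather than the provided Transformer architectures, the sketch below shows what the wrapping described above might look like. It assumes the `quant_noise` helper linked above takes `(module, p, block_size)` and that each wrapped layer's input features are a multiple of the block size; `TinyModel` is a hypothetical stand-in for your own model.

```python
# Sketch: wrapping the layers of a custom model with Quant-Noise before training.
# Assumes fairseq is installed and that the helper linked above has the
# signature quant_noise(module, p, block_size).
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq.modules.quant_noise import quant_noise

P, BLOCK_SIZE = 0.1, 8  # proportion of blocks dropped, block size of the weight matrix


class TinyModel(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        # quant_noise returns the module with a hook that randomly drops blocks
        # of the weight matrix during training (it does nothing at eval time).
        self.fc1 = quant_noise(nn.Linear(dim, 4 * dim), P, BLOCK_SIZE)
        self.fc2 = quant_noise(nn.Linear(4 * dim, dim), P, BLOCK_SIZE)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


model = TinyModel()
out = model(torch.randn(8, 512))  # train as usual, then quantize with iPQ
```

When training the provided Transformer architectures with `fairseq-train`, the `--quant-noise-pq` flags apply this wrapping for you; the sketch is only relevant when integrating Quant-Noise into your own model code.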
+ +#### Quantization + +We implement an improved version of product quantization from Stock et al., **iPQ**, described [here](https://arxiv.org/abs/1907.05686); see the code with the old API [here](https://github.com/facebookresearch/kill-the-bits). Note that we improved the iPQ API in terms of both compute speed and usability, as described below. + +For the particular case of PQ, quantization is performed sequentially. We recommend first quantizing the FFNs, then the EMBs, and finally the ATTNs. Quantization is done in two sub-steps: +- First, perform `n` steps of Product Quantization (generally `n=20` is enough). +- Then, finetune the obtained centroids. + +#### Integration with your own code + +Looking to quantize your own models with Quant-Noise + iPQ? +- First, wrap your modules with the `quant_noise` function [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quant_noise.py), which is module-agnostic, and train your favorite model. +- Then, quantize your trained model using the code [here](https://github.com/pytorch/fairseq/tree/main/fairseq/modules/quantization/pq). This can be done *without any changes to your training loop*. Below is example code for integration. +Note that we tried our approach only on Transformers and various convolutional models such as EfficientNets. + +```python +from fairseq.modules.quantization.pq import quantize_model_, SizeTracker + +# get configuration parameters +n_centroids_config = config["n_centroids"] +block_sizes_config = config["block_sizes"] +layers_to_quantize = config["layers_to_quantize"] + +# size tracker for keeping track of assignments, centroids and non-compressed sizes +size_tracker = SizeTracker(model) + +# Quantize model by stages +for step in range(len(layers_to_quantize)): + + # quantize model in-place + quantized_layers = quantize_model_( + model, + size_tracker, + layers_to_quantize, + block_sizes_config, + n_centroids_config, + step=step, + ) + logger.info(f"Finetuning stage {step}, quantized layers: {quantized_layers}") + logger.info(f"{size_tracker}") + + # Don't forget to re-create/update trainer/optimizer since model parameters have changed + optimizer = ... + + # Finetune the centroids with your usual training loop for a few epochs + trainer.train_epoch() +``` + + +## Looking to reproduce the NLP results in the paper? + +We detail below how to reproduce the state-of-the-art results reported in the paper for Quant-Noise + Iterative Product Quantization. + +### Training with Quant-Noise + +To **train** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/roberta).
+The following command can be used to train a RoBERTa Base + QuantNoise model: + +```bash +TOTAL_UPDATES=125000 +WARMUP_UPDATES=10000 +PEAK_LR=0.0005 +TOKENS_PER_SAMPLE=512 +MAX_POSITIONS=512 +MAX_SENTENCES=16 +UPDATE_FREQ=2 +DATA_DIR=/path/to/data/here + +fairseq-train $DATA_DIR \ + --task masked_lm --criterion masked_lm --arch roberta_base \ + --sample-break-mode complete \ + --tokens-per-sample $TOKENS_PER_SAMPLE --max-positions $MAX_POSITIONS \ + --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $PEAK_LR \ + --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.01 \ + --batch-size $MAX_SENTENCES \ + --update-freq $UPDATE_FREQ --max-update $TOTAL_UPDATES \ + --save-dir checkpoint/roberta \ + --ddp-backend legacy_ddp --encoder-layerdrop 0.2 \ + --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 --untie-weights-roberta +``` + +To **finetune** RoBERTa + QuantNoise, we followed this setting [here](https://github.com/pytorch/fairseq/blob/main/examples/roberta/README.glue.md). +The following command can be used to finetune a RoBERTa Base + QuantNoise model on the RTE dataset: + +```bash +TOTAL_NUM_UPDATES=2036 +WARMUP_UPDATES=122 +LR=2e-05 +NUM_CLASSES=2 +MAX_SENTENCES=16 +ROBERTA_PATH=/path/to/roberta_quantnoise/model.pt + +fairseq-train /path/to/rte/data/ \ + --restore-file $ROBERTA_PATH \ + --max-positions 512 \ + --batch-size $MAX_SENTENCES \ + --max-tokens 4400 \ + --task sentence_prediction \ + --reset-optimizer --reset-dataloader --reset-meters \ + --required-batch-size-multiple 1 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 \ + --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --max-epoch 10 \ + --find-unused-parameters \ + --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \ + --ddp-backend legacy_ddp \ + --quant-noise-pq 0.2 --quant-noise-pq-block-size 8 +``` + +To **train** Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/main/examples/language_model). 
+The following command can be used to train a Transformer + QuantNoise model on Wikitext-103: + +```bash +fairseq-train --task language_modeling /path/to/wikitext-103/data \ + --save-dir checkpoints/transformer_wikitext-103 \ + --adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \ + --adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \ + --tie-adaptive-proj --tie-adaptive-weights \ + --arch transformer_lm_gbw \ + --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ + --clip-norm 0.1 --criterion adaptive_loss \ + --ddp-backend legacy_ddp \ + --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \ + --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 1.0 --t-mult 2.0 \ + --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \ + --sample-break-mode none --update-freq 3 \ + --warmup-init-lr 1e-07 --warmup-updates 16000 \ + --weight-decay 0 --seed 1 --stop-min-lr 1e-09 \ + --quant-noise-pq 0.05 --quant-noise-pq-block-size 8 +``` + +To **evaluate** this model, note you need to use the `eval.py` script. The following command can be used to evaluate: + +```bash +fairseq-eval-lm /path/to/wikitext-103/data --path /path/to/model/checkpoint \ + --sample-break-mode complete \ + --max-tokens 3072 \ + --context-window 2560 \ + --softmax-batch 1024 \ + --gen-subset valid +``` +and change the `--gen-subset` to `test` if you would like to evaluate on the test set instead. + + +### Iterative Product Quantization + +To quantize the finetuned RoBERTa model, we use this command on 1 GPU. This should run in a day. +```bash +TOTAL_NUM_UPDATES=6108 # 2036 updates for each iteration +WARMUP_UPDATES=122 +LR=2e-05 +NUM_CLASSES=2 +MAX_SENTENCES=16 +fairseq-train --task sentence_prediction /path/to/data/ \ + --restore-file $ROBERTA_PATH \ + --save-dir checkpoints/roberta_finetuned \ + --max-positions 512 \ + --batch-size $MAX_SENTENCES \ + --max-tokens 4400 \ + --init-token 0 --separator-token 2 \ + --arch roberta_large \ + --criterion sentence_prediction \ + --num-classes $NUM_CLASSES \ + --dropout 0.1 --attention-dropout 0.1 \ + --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \ + --clip-norm 0.0 --lr-scheduler polynomial_decay \ + --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \ + --no-progress-bar --skip-invalid-size-inputs-valid-test --ddp-backend legacy_ddp \ + --quantization-config-path /path/to/config/yaml +``` + +To quantize the trained Language Model, we use this command on 8 V100 23GB GPUs. This should run in a couple of hours. 
+```bash +fairseq-train --task language_modeling /path/to/wikitext-103/data \ + --save-dir checkpoints/transformer_wikitext-103 \ + --adaptive-input --adaptive-input-cutoff 20000,60000 --adaptive-input-factor 4 \ + --adaptive-softmax-cutoff 20000,60000 --adaptive-softmax-dropout 0.2 --adaptive-softmax-factor 4.0 \ + --arch transformer_lm_gbw \ + --attention-dropout 0.1 --dropout 0.2 --relu-dropout 0.1 \ + --bucket-cap-mb 25 --char-embedder-highway-layers 2 --character-embedding-dim 4 \ + --clip-norm 0.1 --criterion adaptive_loss \ + --ddp-backend legacy_ddp \ + --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ + --fp16 --keep-last-epochs -1 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 0.05 --stop-min-lr 1e-09 \ + --max-tokens 2944 --tokens-per-sample 2944\ + --momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \ + --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \ + --tie-adaptive-proj --tie-adaptive-weights --update-freq 3 --weight-decay 0 --seed 1 \ + --log-interval 100 --no-progress-bar --skip-invalid-size-inputs-valid-test \ + --restore-file path/to/trained/lm/with/quant/noise \ + --max-update 13500 --quantization-config-path /path/to/config/yaml +``` +If you have less capacity or if your distributed training freezes, try reducing `--max-tokens` and `--tokens-per-sample` (this may reduce the quantized accuracy a bit). + +### Remarks + +We try to keep the open-sourced code as readable and as easy-to-plug as possible. Therefore, we did not test it for the following cases: +- Scalar quantization with RoBERTa. +- Quantization with iPQ and `int8` combined. + +If you have trouble adapting it, we will be more than happy to help! + +## Looking to reproduce the Vision results in the paper? + +We are working on open sourcing our code as part of ClassyVision. Please check back. + + +## Having an issue or have a question? + +Please open an issue in this repository with the details of your question. Thanks!