"""
Script for preparing the SFT data for fine-tuning the AMD-OLMo model.

Modified from https://github.com/allenai/OLMo/blob/main/scripts/prepare_tulu_data.py
"""

import logging
import random
from argparse import ArgumentParser
from functools import partial
from pathlib import Path

import datasets as ds
import numpy as np
from rich.progress import track
from tqdm import tqdm

from olmo.tokenizer import Tokenizer
from olmo.util import prepare_cli_environment

log = logging.getLogger(__name__)


def convert_code_feedback_to_tulu_format(dataset, mix=False):
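    """Wrap Code-Feedback samples in the Tulu chat format.

    Returns a plain list of dicts when `mix=True` (so sources can be mixed
    before shuffling), otherwise a `datasets.Dataset`.
    """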
    log.info("Converting code_feedback ...")
    y_all = []
    for i, sample in enumerate(dataset):
        y = {
            "dataset": "code_feedback",
            "id": "code_feedback_{}".format(i),
            "messages": sample["messages"],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset


def convert_OpenHermes_to_tulu_format(dataset, mix=False):
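    """Convert OpenHermes ShareGPT-style conversations to the Tulu chat format,
    mapping each turn's "from"/"value" fields to "role"/"content"."""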
    log.info("Converting OpenHermes ...")
    role_map = {"human": "user", "gpt": "assistant", "system": "system"}
    y_all = []
    for i, sample in enumerate(dataset):
        y = {
            "dataset": "OpenHermes",
            "id": "OpenHermes_{}".format(i),
            "messages": [
                {"role": role_map[mssg["from"]], "content": mssg["value"]}
                for mssg in sample["conversations"]
            ],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset


def convert_WebInstructSub_to_tulu_format(dataset, mix=False):
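    """Convert WebInstructSub question/answer pairs into single-turn
    user/assistant conversations in the Tulu chat format."""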
    log.info("Converting WebInstructSub ...")
    y_all = []
    for i, sample in tqdm(enumerate(dataset)):
        y = {
            "dataset": "WebInstructSub",
            "id": "WebInstructSub_{}".format(i),
            "messages": [
                {"role": "user", "content": sample["question"]},
                {"role": "assistant", "content": sample["answer"]},
            ],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset


def main(opts) -> None:
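    """Load the requested dataset, tokenize it into fixed-length sequences,
    and write input_ids.npy / label_mask.npy memmaps to `opts.output_dir`."""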
    tokenizer: Tokenizer
    if Path(opts.tokenizer).is_file():
        tokenizer = Tokenizer.from_file(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
    else:
        tokenizer = Tokenizer.from_pretrained(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)

    if opts.dataset == "tulu":
        dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")
    elif opts.dataset == "2nd-phase":
        datasets = ["code-feedback", "OpenHermes", "WebInstructSub"]
        combined_datasets = []
        for dataset_name in datasets:
            if dataset_name == "code-feedback":
                dataset = ds.load_dataset("m-a-p/Code-Feedback", split="train")
                dataset = convert_code_feedback_to_tulu_format(dataset, mix=True)
            elif dataset_name == "OpenHermes":
                dataset = ds.load_dataset("teknium/OpenHermes-2.5", split="train")
                dataset = convert_OpenHermes_to_tulu_format(dataset, mix=True)
            elif dataset_name == "WebInstructSub":
                dataset = ds.load_dataset("TIGER-Lab/WebInstructSub", split="train")
                dataset = convert_WebInstructSub_to_tulu_format(dataset, mix=True)

            combined_datasets += dataset

        # Shuffle the mixed sources with a fixed seed so runs are reproducible.
        random.seed(42)
        random.shuffle(combined_datasets)
        log.info(f"In total {len(combined_datasets)} samples")
        dataset = ds.Dataset.from_list(combined_datasets)
    else:
        raise ValueError(f"Unknown dataset {opts.dataset!r}; expected 'tulu' or '2nd-phase'")

    log.info("Tokenizing dataset...")
    dataset = dataset.map(
        partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
        batched=False,
        remove_columns=["dataset", "id", "messages"],
        num_proc=opts.num_proc,
    )

    log.info("Filtering dataset...")
    n = len(dataset)
    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)
    log.info(f"Filtered out {n - len(dataset):,d} examples")

    log.info("Counting tokens...")
    total_tokens = 0
    for ex in track(dataset):
        assert len(ex["input_ids"]) == opts.seq_len
        total_tokens += len(ex["input_ids"])
    log.info(f"Total tokens: {total_tokens:,d}")

    log.info(f"Saving results to '{opts.output_dir}'...")
    output_dir = Path(opts.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # uint16 is enough for the default GPT-NeoX-20B vocabulary (~50k token IDs).
    input_ids_file = np.memmap(
        str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,)
    )
    label_mask_file = np.memmap(
        str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
    )
    offset = 0
    for ex in track(dataset):
        ex_len = len(ex["input_ids"])
        input_ids_file[offset : offset + ex_len] = ex["input_ids"]
        label_mask_file[offset : offset + ex_len] = ex["label_mask"]
        offset += ex_len
    input_ids_file.flush()
    label_mask_file.flush()

    log.info("Done!")


def filter(example):
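    """Keep only examples that contain at least one loss-bearing label token."""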
    return example["n_labels"] > 0


def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
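    """Tokenize one chat example into a fixed-length training sequence.

    The token layout is roughly (newlines omitted):

        <EOS> <|user|> {prompt} <|assistant|> {response} <EOS> ...

    Only assistant content tokens (including the closing EOS, but not the
    trailing newline) are marked True in `label_mask`, so the loss is computed
    on responses only. Sequences are truncated and right-padded to
    `max_seq_len`.
    """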
    input_ids = [tokenizer.eos_token_id]
    label_mask = [False]

    for msg in example["messages"]:
        role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
        label_mask += [False] * len(role_tokens)
        input_ids += role_tokens

        if msg["role"] == "assistant":
            content_tokens = tokenizer.encode(
                msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
            )
            label_mask += [True] * len(content_tokens)
            # Mask out the trailing newline after the EOS token; the EOS itself
            # stays in the loss so the model learns where to stop.
            assert content_tokens[-2] == tokenizer.eos_token_id
            label_mask[-1] = False
        else:
            content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
            label_mask += [False] * len(content_tokens)
        input_ids += content_tokens

    # Truncate to max_seq_len, then right-pad so every example has a fixed length.
    input_ids = input_ids[:max_seq_len]
    label_mask = label_mask[:max_seq_len]

    if len(input_ids) < max_seq_len:
        pad_len = max_seq_len - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * pad_len
        label_mask += [False] * pad_len

    assert len(input_ids) == len(label_mask)
    n_labels = sum(label_mask)

    return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}


def get_parser() -> ArgumentParser:
    parser = ArgumentParser(description="Prepare SFT dataset")
    parser.add_argument("--output_dir", type=str, help="""Directory to save the results to.""")
    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        help="""Tokenizer path or identifier.""",
        default=Path(__file__).parent / "tokenizers" / "allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
    )
    parser.add_argument(
        "-ds", "--dataset", type=str, help="""Dataset to process: 'tulu' or '2nd-phase'.""", default="tulu"
    )
    parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048)
    parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=50279)
    parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1)
    parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8)
    return parser


if __name__ == "__main__":
    prepare_cli_environment()
    opts = get_parser().parse_args()
    main(opts)