# Spaces:
# Runtime error
# Runtime error
import json
import os
from argparse import ArgumentParser
from typing import Dict, List

from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_sft_config
def get_positive_score(scores):
    """Extract the value associated with a positive sentiment from a pipeline's output.

    Args:
        scores: Iterable of per-label dicts for one sample. The values of each
            dict are taken, in insertion order, as a ``(label, score)`` pair,
            so every dict must hold exactly two values. Presumably these are
            the ``{"label": ..., "score": ...}`` entries returned by the
            sentiment pipeline -- confirm against the caller.

    Returns:
        The value paired with the ``"POSITIVE"`` label.

    Raises:
        KeyError: If no entry carries the ``"POSITIVE"`` label.
        ValueError: If an entry does not contain exactly two values.
    """
    # Build a label -> score mapping from the dicts' values, then pick POSITIVE.
    label_to_score = dict(tuple(entry.values()) for entry in scores)
    return label_to_score["POSITIVE"]
def preprocess(instruction: str, input: str, output: str):
    """Build an Alpaca prompt and output pair from an instruction and input/output example.

    Args:
        instruction: Task description placed under ``### Instruction:``.
        input: Extra context for the task. When truthy it is placed under an
            ``### Input:`` section and the preamble mentions the paired input;
            when empty the section is omitted entirely.
        output: Target completion; returned unchanged.

    Returns:
        A two-element list ``[prompt, output]``.
    """
    # NOTE: ``input`` shadows the builtin, but it is part of the public
    # signature, so the name is kept for keyword-argument callers.
    if input:
        prefix = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request."
        )
        input_section = f"### Input:\n{input}\n\n"
    else:
        prefix = (
            "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        )
        input_section = ""

    # Single assembly point: the optional input section sits between the
    # instruction and the response header.
    prompt = f"{prefix}\n\n### Instruction:\n{instruction}\n\n{input_section}### Response:\n"
    return [prompt, output]
def main(hparams=None, model_name="EleutherAI/gpt-j-6B", dataset="tatsu-lab/alpaca"):
    """Run supervised fine-tuning of ``model_name`` on ``dataset`` and save the result.

    The training samples are Alpaca-style ``[prompt, output]`` pairs built with
    ``preprocess``. Evaluation prompts ask the model to rewrite 256 negative
    IMDB test reviews into positive ones, and the metric is the positive
    sentiment score of the generated outputs.

    Args:
        hparams: Optional dict of config overrides merged over the defaults
            via ``TRLConfig.update``. ``None`` (the default) means no overrides.
        model_name: Model identifier used for both the model and tokenizer paths.
        dataset: Dataset identifier passed to ``load_dataset``; its ``train``
            split must provide ``instruction``, ``input`` and ``output`` fields.
    """
    # Avoid a shared mutable default: create a fresh dict per call.
    if hparams is None:
        hparams = {}

    config = default_sft_config()
    config = config.evolve(
        train=dict(
            total_steps=2400,
            batch_size=4,
            seq_length=1024,
        ),
        model=dict(
            model_path=model_name,
        ),
        tokenizer=dict(
            tokenizer_path=model_name,
        ),
        optimizer=dict(kwargs=dict(lr=2e-5)),
        scheduler=dict(kwargs=dict(eta_min=2e-5)),
        method=dict(
            gen_kwargs=dict(
                max_new_tokens=256,
            )
        ),
    )

    # Merge sweep config with default config if given
    config = TRLConfig.update(config.to_dict(), hparams)

    alpaca = load_dataset(dataset, split="train")
    alpaca = [preprocess(x["instruction"], x["input"], x["output"]) for x in alpaca]

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        # Device 0 only when LOCAL_RANK is 0 (or unset); every other rank gets -1.
        device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
    )

    def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> Dict[str, List[float]]:
        """Score each generated output by its positive-sentiment value."""
        sentiments = list(map(get_positive_score, sentiment_fn(outputs)))
        return {"sentiments": sentiments}

    imdb = load_dataset("imdb", split="test")
    # Keep the first 256 reviews whose label is 0.
    bad_reviews = imdb.filter(lambda sample: sample["label"] == 0).select(range(256))
    # Review text is cut to 1024 characters; only the prompt half of the pair is used.
    zs_rewrite = [preprocess("Rewrite the input into a positive review.", x["text"][:1024], "")[0] for x in bad_reviews]

    trainer = trlx.train(
        samples=alpaca,
        eval_prompts=zs_rewrite,
        metric_fn=metric_fn,
        config=config,
    )

    # Output directory name: "<model basename>-<dataset basename>-sft".
    slug = f"{model_name.split('/')[-1]}-{dataset.split('/')[-1]}"
    trainer.save_pretrained(f"{slug}-sft")
if __name__ == "__main__":
    parser = ArgumentParser()
    # Optional positional argument: a JSON object of config overrides
    # (defaults to "{}", i.e. no overrides).
    parser.add_argument("override_hparams", type=str, default="{}", nargs="?")
    parser.add_argument("--model_name", type=str, default="EleutherAI/gpt-j-6B")
    parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca")
    args = parser.parse_args()

    # Decode the override string before handing it to main().
    hparams = json.loads(args.override_hparams)
    main(hparams, args.model_name, args.dataset)