from transformers.integrations import TensorBoardCallback
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer import TRAINING_ARGS_NAME
from torch.utils.tensorboard import SummaryWriter
import datasets
import torch
import os
import re
import sys
import wandb
import argparse
from datetime import datetime
from functools import partial
from tqdm import tqdm
from utils import *

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)

# Weights & Biases configuration; set your own key here or export WANDB_API_KEY in the shell instead.
os.environ['WANDB_API_KEY'] = 'your-wandb-api-key'
os.environ['WANDB_PROJECT'] = 'fingpt-forecaster'
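

# Callback that runs free-form generation on a small held-out sample at each evaluation
# and logs text-level metrics (computed by calc_metrics from utils) to Weights & Biases.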
class GenerationEvalCallback(TrainerCallback):

    def __init__(self, eval_dataset, ignore_until_epoch=0):
        self.eval_dataset = eval_dataset
        self.ignore_until_epoch = ignore_until_epoch

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        # Skip the (slow) generation-based evaluation during the early epochs.
        if state.epoch is None or state.epoch + 1 < self.ignore_until_epoch:
            return

        if state.is_local_process_zero:
            model = kwargs['model']
            tokenizer = kwargs['tokenizer']
            generated_texts, reference_texts = [], []

            for feature in tqdm(self.eval_dataset):
                prompt = feature['prompt']
                gt = feature['answer']
                inputs = tokenizer(
                    prompt, return_tensors='pt',
                    padding=False, max_length=4096,
                    truncation=True  # max_length has no effect unless truncation is enabled
                )
                inputs = {key: value.to(model.device) for key, value in inputs.items()}

                res = model.generate(
                    **inputs,
                    use_cache=True
                )
                output = tokenizer.decode(res[0], skip_special_tokens=True)
                # Keep only the completion after the [/INST] tag of the Llama-2 chat template.
                answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)

                generated_texts.append(answer)
                reference_texts.append(gt)

            metrics = calc_metrics(reference_texts, generated_texts)

            if wandb.run is None:
                wandb.init()

            wandb.log(metrics, step=state.global_step)
            torch.cuda.empty_cache()


def main(args):

    model_name = parse_model_name(args.base_model, args.from_remote)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    if args.local_rank == 0:
        print(model)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Llama-2 style checkpoints ship without a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
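
    # Load and prepare the datasets; load_dataset and tokenize come from utils (wildcard import above).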
    dataset_list = load_dataset(args.dataset, args.from_remote)
    dataset_train = datasets.concatenate_datasets([d['train'] for d in dataset_list]).shuffle(seed=42)

    if args.test_dataset:
        dataset_list = load_dataset(args.test_dataset, args.from_remote)
    dataset_test = datasets.concatenate_datasets([d['test'] for d in dataset_list])

    original_dataset = datasets.DatasetDict({'train': dataset_train, 'test': dataset_test})

    # Keep 50 raw (untokenized) samples for the generation-based evaluation callback
    eval_dataset = original_dataset['test'].shuffle(seed=42).select(range(50))

    dataset = original_dataset.map(partial(tokenize, args, tokenizer))
    print('original dataset length: ', len(dataset['train']))
    dataset = dataset.filter(lambda x: not x['exceed_max_length'])
    print('filtered dataset length: ', len(dataset['train']))
    dataset = dataset.remove_columns(
        ['prompt', 'answer', 'label', 'symbol', 'period', 'exceed_max_length']
    )

    current_time = datetime.now()
    formatted_time = current_time.strftime('%Y%m%d%H%M')
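
    # Training configuration: fp16 is always on; DeepSpeed config and evaluation cadence come from CLI flags.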
    training_args = TrainingArguments(
        output_dir=f'finetuned_models/{args.run_name}_{formatted_time}',
        logging_steps=args.log_interval,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        dataloader_num_workers=args.num_workers,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.scheduler,
        save_steps=args.eval_steps,
        eval_steps=args.eval_steps,
        fp16=True,
        deepspeed=args.ds_config,
        evaluation_strategy=args.evaluation_strategy,
        remove_unused_columns=False,
        report_to='wandb',
        run_name=args.run_name
    )
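
    # Gradient checkpointing trades compute for memory; the KV cache must stay off while training.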
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.is_parallelizable = True
    model.model_parallel = True
    model.model.config.use_cache = False
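
    # Wrap the base model with LoRA adapters; target modules are looked up per base model via lora_module_dict from utils.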
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=lora_module_dict[args.base_model],
        bias='none',
    )
    model = get_peft_model(model, peft_config)
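
    # DataCollatorForSeq2Seq pads input_ids and labels to the longest sequence in each batch;
    # the GenerationEvalCallback only kicks in after roughly 30% of the epochs have run.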
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, padding=True,
            return_tensors="pt"
        ),
        callbacks=[
            GenerationEvalCallback(
                eval_dataset=eval_dataset,
                ignore_until_epoch=round(0.3 * args.num_epochs)
            )
        ]
    )

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    torch.cuda.empty_cache()
    trainer.train()

    # Saves only the LoRA adapter weights, not the full base model
    model.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--run_name", default='local-test', type=str)
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument("--test_dataset", type=str)
    parser.add_argument("--base_model", required=True, type=str, choices=['chatglm2', 'llama2'])
    parser.add_argument("--max_length", default=512, type=int)
    parser.add_argument("--batch_size", default=4, type=int, help="The train batch size per device")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The learning rate")
    parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay")
    parser.add_argument("--num_epochs", default=8, type=float, help="The training epochs")
    parser.add_argument("--num_workers", default=8, type=int, help="dataloader workers")
    parser.add_argument("--log_interval", default=20, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=8, type=int)
    parser.add_argument("--warmup_ratio", default=0.05, type=float)
    parser.add_argument("--ds_config", default='./config_new.json', type=str)
    parser.add_argument("--scheduler", default='linear', type=str)
    parser.add_argument("--instruct_template", default='default')
    parser.add_argument("--evaluation_strategy", default='steps', type=str)
    parser.add_argument("--eval_steps", default=0.1, type=float)
    # argparse's type=bool treats any non-empty string as True; parse the flag explicitly instead
    parser.add_argument("--from_remote", default=False,
                        type=lambda x: str(x).lower() in ('true', '1', 'yes'))
    args = parser.parse_args()

    wandb.login()
    main(args)