cstr committed
Commit 1fd9eb4 · verified · 1 Parent(s): b066bb6

Delete main.py with huggingface_hub

Files changed (1)
  1. main.py +0 -203
main.py DELETED
@@ -1,203 +0,0 @@
-
-import os
-import time
-import wandb
-import torch
-import argparse
-from datasets import load_dataset
-from typing import List, Dict, Union
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    TrainingArguments,
-    DataCollatorForLanguageModeling
-)
-
-from src.args import default_args
-from src.orpo_trainer import ORPOTrainer
-from src.utils import preprocess_logits_for_metrics, dataset_split_selector
-
-class ORPO(object):
-    def __init__(self, args) -> None:
-        self.start = time.gmtime()
-        self.args = args
-
-        # Load Tokenizer
-        print(">>> 1. Loading Tokenizer")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, cache_dir=self.args.cache_dir)
-        if self.tokenizer.chat_template is None:
-            self.tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
-            print(" 1-1. Chat Template Applied (<|user|> <|assistant|>)")
-        else:
-            pass
-        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-
-        # Load Model
-        print(">>> 2. Loading Model")
-        if self.args.flash_attention_2:
-            self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
-                                                              cache_dir=self.args.cache_dir,
-                                                              torch_dtype=torch.bfloat16,
-                                                              attn_implementation="flash_attention_2")
-        else:
-            self.model = AutoModelForCausalLM.from_pretrained(self.args.model_name,
-                                                              cache_dir=self.args.cache_dir,
-                                                              torch_dtype=torch.bfloat16)
-
-        # Load Dataset
-        print(">>> 3. Loading Dataset")
-        self.data = load_dataset(self.args.data_name, cache_dir=self.args.cache_dir)
-
-        # Preprocess Dataset
-        print(">>> 4. Filtering and Preprocessing Dataset")
-        data_split = dataset_split_selector(self.data)
-
-        if len(data_split) == 1:
-            self.is_test = False
-            train_split = data_split[0]
-            print(f" >>> Test Set = {self.is_test}")
-        else:
-            self.is_test = True
-            train_split = data_split[0]
-            test_split = data_split[1]
-
-            test = self.data[test_split].filter(self.filter_dataset)
-            self.test = test.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[test_split].column_names)
-
-        train = self.data[train_split].filter(self.filter_dataset)
-        print(f"\n\n>>> {len(train)} / {len(self.data[train_split])} rows left after filtering by prompt length.")
-        self.train = train.map(self.preprocess_dataset, batched=True, num_proc=self.args.num_proc, remove_columns=self.data[train_split].column_names)
-
-        # Set WANDB & Logging Configurations
-        self.run_name = f"{self.args.model_name.split('/')[-1]}-{self.args.data_name.split('/')[-1]}-lambda{self.args.alpha}-ORPO-{self.start.tm_mday}-{self.start.tm_hour}-{self.start.tm_min}"
-        self.save_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}")
-        self.log_dir = os.path.join('./checkpoints/', f"{self.args.data_name.split('/')[-1]}/{self.run_name}/logs")
-
-        os.makedirs(self.save_dir, exist_ok=True)
-        os.makedirs(self.log_dir, exist_ok=True)
-
-    def preprocess_dataset(self, examples: Union[List, Dict]):
-        if ('instruction' in examples.keys()) or ('question' in examples.keys()):
-            prompt_key = 'instruction' if 'instruction' in examples.keys() else 'question'
-            prompt = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item}], tokenize=False, add_generation_prompt=True) for item in examples[prompt_key]]
-            chosen = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_chosen}], tokenize=False) for item_prompt, item_chosen in zip(examples[prompt_key], examples['chosen'])]
-            rejected = [self.tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}, {'role': 'assistant', 'content': item_rejected}], tokenize=False) for item_prompt, item_rejected in zip(examples[prompt_key], examples['rejected'])]
-        else:
-            prompt = [self.tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True) for item in examples['chosen']]
-            chosen = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['chosen']]
-            rejected = [self.tokenizer.apply_chat_template(item, tokenize=False) for item in examples['rejected']]
-
-        model_inputs = self.tokenizer(prompt,
-                                      max_length=self.args.response_max_length,
-                                      padding='max_length',
-                                      truncation=True,
-                                      return_tensors='pt')
-        pos_labels = self.tokenizer(chosen,
-                                    max_length=self.args.response_max_length,
-                                    padding='max_length',
-                                    truncation=True,
-                                    return_tensors='pt')
-        neg_labels = self.tokenizer(rejected,
-                                    max_length=self.args.response_max_length,
-                                    padding='max_length',
-                                    truncation=True,
-                                    return_tensors='pt')
-
-        model_inputs['positive_input_ids'] = pos_labels['input_ids']
-        model_inputs['positive_attention_mask'] = pos_labels['attention_mask']
-
-        model_inputs['negative_input_ids'] = neg_labels['input_ids']
-        model_inputs['negative_attention_mask'] = neg_labels['attention_mask']
-
-        return model_inputs
-
-    def filter_dataset(self, examples: Union[List, Dict]):
-        if 'instruction' in examples.keys():
-            query = examples['instruction']
-            prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
-        elif 'question' in examples.keys():
-            query = examples['question']
-            prompt_length = self.tokenizer.apply_chat_template([{'content': query, 'role': 'user'}], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
-        else:
-            prompt_length = self.tokenizer.apply_chat_template([examples['chosen'][0]], tokenize=True, add_generation_prompt=True, return_tensors='pt').size(-1)
-
-        if prompt_length < self.args.prompt_max_length:
-            return True
-        else:
-            return False
-
-    def prepare_trainer(self):
-        wandb.init(name=self.run_name)
-        arguments = TrainingArguments(
-            output_dir=self.save_dir,  # The output directory
-            logging_dir=self.log_dir,
-            logging_steps=50,
-            learning_rate=self.args.lr,
-            overwrite_output_dir=True,  # overwrite the content of the output directory
-            num_train_epochs=self.args.num_train_epochs,  # number of training epochs
-            per_device_train_batch_size=self.args.per_device_train_batch_size,  # batch size for training
-            per_device_eval_batch_size=self.args.per_device_eval_batch_size,  # batch size for evaluation
-            evaluation_strategy=self.args.evaluation_strategy if self.is_test else 'no',  # evaluate only when a test split exists
-            save_strategy=self.args.evaluation_strategy,
-            optim=self.args.optim,
-            warmup_steps=self.args.warmup_steps,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-            gradient_checkpointing=True,  # if ('llama' in self.args.model_name.lower()) or ('mistral' in self.args.model_name.lower()) else False,
-            gradient_checkpointing_kwargs={'use_reentrant': True},
-            load_best_model_at_end=self.is_test,
-            do_train=True,
-            do_eval=self.is_test,
-            lr_scheduler_type=self.args.lr_scheduler_type,
-            remove_unused_columns=False,
-            report_to='wandb',
-            run_name=self.run_name,
-            bf16=True
-        )
-
-        data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
-
-        self.trainer = ORPOTrainer(
-            model=self.model,
-            alpha=self.args.alpha,
-            pad=self.tokenizer.pad_token_id,
-            args=arguments,
-            train_dataset=self.train,
-            eval_dataset=self.test if self.is_test else None,
-            data_collator=data_collator,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics
-        )
-
-    def run(self):
-        print(">>> 5. Preparing ORPOTrainer")
-        self.prepare_trainer()
-        self.trainer.train()
-
-        # Saving code for FSDP
-        if self.trainer.is_fsdp_enabled:
-            self.trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-        self.trainer.save_model()
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser("ORPO")
-    args = default_args(parser)
-
-    # Set WANDB configurations
-    if args.wandb_entity is not None and args.wandb_project_name is not None:
-        os.environ["WANDB_ENTITY"] = args.wandb_entity
-        os.environ["WANDB_PROJECT"] = args.wandb_project_name
-    else:
-        pass
-    os.environ["TOKENIZERS_PARALLELISM"] = 'false'
-
-    print("================================================================================================\n")
-    print(f">>> Fine-tuning {args.model_name} with ORPO on {args.data_name}\n")
-    print("================================================================================================")
-    print("\n\n>>> Summary:")
-    print(f" - Lambda : {args.alpha}")
-    print(f" - Training Epochs : {args.num_train_epochs}")
-    print(f" - Prompt Max Length : {args.prompt_max_length}")
-    print(f" - Response Max Length : {args.response_max_length}")
-
-    item = ORPO(args=args)
-    item.run()
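
Note: the deleted script delegates the actual objective to src/orpo_trainer.py, which this diff does not include. As context for the alpha (lambda) weight passed to ORPOTrainer above, a minimal sketch of the odds-ratio penalty from the ORPO paper, which such a trainer typically adds to the standard NLL loss on the chosen response, might look like the following. The function name, tensor shapes, and length normalization here are illustrative assumptions, not the repo's API.

import torch
import torch.nn.functional as F

def orpo_penalty(logp_chosen: torch.Tensor,
                 logp_rejected: torch.Tensor,
                 alpha: float) -> torch.Tensor:
    # Hypothetical helper, not taken from this repo. logp_* are
    # length-normalized log-probabilities of the chosen/rejected
    # responses under the model, each of shape (batch,).
    # log odds(y) = log p(y) - log(1 - p(y)); compare chosen vs. rejected:
    log_odds = (logp_chosen - logp_rejected) - (
        torch.log1p(-torch.exp(logp_chosen)) - torch.log1p(-torch.exp(logp_rejected))
    )
    # Odds-ratio term: -log sigmoid(log-odds-ratio), averaged over the
    # batch and scaled by alpha; added to the NLL loss on the chosen response.
    return -alpha * F.logsigmoid(log_odds).mean()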