|
from config import HAS_CUDA, MODEL, DEVICE_MAP, TRAINING_PARAMS, LORA_TRAINING_PARAMS, GENERATION_PARAMS

import os
import gc
import torch
import transformers
import peft
import datasets
from contextlib import nullcontext
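# NOTE: version assumption, not stated in the original file: this code targets
# the transformers/peft releases where `load_in_8bit=True` and
# `peft.prepare_model_for_int8_training` were the supported spellings; newer
# releases route 8-bit loading through `BitsAndBytesConfig` and rename the
# peft helper to `prepare_model_for_kbit_training`.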
|
class Trainer:
    def __init__(self):
        # Currently loaded base model and its name.
        self.model = None
        self.model_name = None

        # Active LoRA adapter, plus a cache of adapters already attached
        # to the model in this session.
        self.lora_name = None
        self.loras = {}

        self.tokenizer = None
        self.trainer = None
|
    def unload_model(self):
        # Drop all references to the model and tokenizer so they can be
        # garbage collected, then release any cached GPU memory.
        del self.model
        del self.tokenizer

        self.model = None
        self.model_name = None
        self.tokenizer = None

        if HAS_CUDA:
            with torch.no_grad():
                torch.cuda.empty_cache()

        gc.collect()
|
    def load_model(self, model_name, force=False, **kwargs):
        assert model_name is not None

        # Skip reloading if the requested model is already loaded.
        if model_name == self.model_name and not force:
            return

        if self.model is not None:
            self.unload_model()

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=DEVICE_MAP,
            load_in_8bit=True,
            torch_dtype=torch.float16,
        )

        # The decapoda-research LLaMA checkpoints ship a tokenizer config
        # that AutoTokenizer mis-resolves, so load LlamaTokenizer directly.
        if model_name.startswith('decapoda-research/llama'):
            self.tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
        else:
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

        # Use token id 0 for padding; models like LLaMA define no pad token.
        self.tokenizer.pad_token_id = 0
        self.model_name = model_name
|
    def load_lora(self, lora_name, replace_model=True):
        assert self.model is not None
        assert lora_name is not None

        if lora_name == self.lora_name:
            return

        # Adapter was loaded earlier in this session; just switch to it.
        if lora_name in self.loras:
            self.lora_name = lora_name
            self.model.set_adapter(lora_name)
            return

        peft_config = peft.PeftConfig.from_pretrained(lora_name)
        if not replace_model:
            assert peft_config.base_model_name_or_path == self.model_name

        # If the adapter was trained on a different base model, load that
        # model and forget any adapters attached to the previous one.
        if peft_config.base_model_name_or_path != self.model_name:
            self.load_model(peft_config.base_model_name_or_path)
            self.loras = {}

        assert self.model_name is not None
        assert self.model is not None

        # A model already wrapped in a PeftModel exposes load_adapter;
        # otherwise wrap the base model first.
        if hasattr(self.model, 'load_adapter'):
            self.model.load_adapter(lora_name, adapter_name=lora_name)
        else:
            self.model = peft.PeftModel.from_pretrained(self.model, lora_name, adapter_name=lora_name)

        self.model.set_adapter(lora_name)
        if self.model_name.startswith('cerebras'):
            self.model.half()

        self.lora_name = lora_name
        self.loras[lora_name] = True
|
    def unload_lora(self):
        # Clearing the name is enough: generate() disables any attached
        # adapter whenever no LoRA is selected.
        self.lora_name = None
|
    def generate(self, prompt, **kwargs):
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        # Per-call kwargs override the defaults from GENERATION_PARAMS.
        kwargs = {**GENERATION_PARAMS, **kwargs}

        inputs = self.tokenizer(str(prompt), return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)

        if self.model.config.pad_token_id is None:
            kwargs['pad_token_id'] = self.model.config.eos_token_id

        # Beam search and sampling are mutually exclusive here; drop the
        # beam setting when sampling is requested.
        if kwargs.get('do_sample'):
            kwargs.pop('num_beams', None)

        generation_config = transformers.GenerationConfig(
            use_cache=False,
            **kwargs,
        )

        # When no LoRA is selected, temporarily disable any adapter that is
        # still attached to the model.
        disable_lora = nullcontext()
        if self.lora_name is None and hasattr(self.model, 'disable_adapter'):
            disable_lora = self.model.disable_adapter()

        with torch.no_grad(), disable_lora:
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                generation_config=generation_config,
            )[0]

        return self.tokenizer.decode(output, skip_special_tokens=True).strip()
|
    def tokenize_sample(self, item, max_seq_length, add_eos_token=True):
        assert self.tokenizer is not None
        result = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=max_seq_length,
            padding="max_length",
        )

        # Trim the last token so there is room to append an EOS token below.
        result = {
            "input_ids": result["input_ids"][:-1],
            "attention_mask": result["attention_mask"][:-1],
        }

        if (
            result["input_ids"][-1] != self.tokenizer.eos_token_id
            and len(result["input_ids"]) < max_seq_length
            and add_eos_token
        ):
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        return result
|
    def tokenize_training_text(self, training_text, max_seq_length, separator="\n\n\n", **kwargs):
        # Split the raw text into one sample per separator-delimited block
        # (with the default separator, samples are divided by two blank lines).
        samples = training_text.split(separator)
        samples = [{'text': x.strip()} for x in samples]

        training_dataset = datasets.Dataset.from_list(samples)
        training_dataset = training_dataset.shuffle().map(
            lambda x: self.tokenize_sample(x, max_seq_length),
            batched=False
        )

        return training_dataset
|
    def train(self, training_text=None, new_peft_model_name=None, **kwargs):
        assert self.model is not None
        assert self.model_name is not None
        assert self.tokenizer is not None

        kwargs = {**TRAINING_PARAMS, **LORA_TRAINING_PARAMS, **kwargs}

        self.lora_name = None
        self.loras = {}

        train_dataset = self.tokenize_training_text(training_text, **kwargs)

        # If a PEFT adapter is still attached from a previous session,
        # reload a clean copy of the base model before wrapping it again.
        if hasattr(self.model, 'disable_adapter'):
            self.load_model(self.model_name, force=True)

        self.model = peft.prepare_model_for_int8_training(self.model)
        self.model = peft.get_peft_model(self.model, peft.LoraConfig(
            r=kwargs['lora_r'],
            lora_alpha=kwargs['lora_alpha'],
            lora_dropout=kwargs['lora_dropout'],
            bias="none",
            task_type="CAUSAL_LM",
        ))

        if not os.path.exists('lora'):
            os.makedirs('lora')

        sanitized_model_name = self.model_name.replace('/', '_').replace('.', '_')
        output_dir = f"lora/{sanitized_model_name}_{new_peft_model_name}"

        training_args = transformers.TrainingArguments(
            per_device_train_batch_size=kwargs['micro_batch_size'],
            gradient_accumulation_steps=kwargs['gradient_accumulation_steps'],
            num_train_epochs=kwargs['epochs'],
            learning_rate=kwargs['learning_rate'],
            fp16=True,
            optim='adamw_torch',
            logging_steps=20,
            save_total_limit=3,
            output_dir=output_dir,
        )

        self.trainer = transformers.Trainer(
            model=self.model,
            train_dataset=train_dataset,
            args=training_args,
            # mlm=False makes the collator build causal-LM labels from
            # the input_ids.
            data_collator=transformers.DataCollatorForLanguageModeling(
                self.tokenizer,
                mlm=False,
            ),
        )

        # Disable the KV cache while training (it conflicts with gradient
        # checkpointing).
        self.model.config.use_cache = False
        result = self.trainer.train(resume_from_checkpoint=False)

        # Saves only the LoRA adapter weights, not the full model.
        self.model.save_pretrained(output_dir)

        return result
|
if __name__ == '__main__':
    t = Trainer()
    t.load_model(MODEL)

    prompt = "Human: How is cheese made?\n\nAssistant:"
    print(t.generate(prompt))

    t.load_lora('lora/melon-mango-orange')
    print(t.generate(prompt))

    t.unload_lora()
    print(t.generate(prompt))
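
    # Hypothetical training example (not part of the original demo): the file
    # path and adapter name below are placeholders. The training text is a
    # plain file whose samples are separated by two blank lines, matching the
    # default "\n\n\n" separator in tokenize_training_text.
    # with open('train.txt') as f:
    #     t.train(training_text=f.read(), new_peft_model_name='my-adapter')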