# Phi2_FineTuned / model_utils.py
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

import config

def load_model(quantization_config):
    """Load the base causal LM named in config, quantized with the given config."""
    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_NAME,
        quantization_config=quantization_config,
        trust_remote_code=config.TRUST_REMOTE_CODE,
    )
    # KV caching is controlled from config; it is usually disabled while
    # training with gradient checkpointing.
    model.config.use_cache = config.ENABLE_MODEL_CONFIG_CACHE
    return model
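
# Usage sketch (not part of the original upload): one way to build the 4-bit
# quantization config that load_model() expects. The field values below are
# illustrative assumptions; the project's real settings would live in config.py.
def example_quantization_config():
    import torch
    from transformers import BitsAndBytesConfig

    return BitsAndBytesConfig(
        load_in_4bit=True,                     # quantize weights to 4-bit on load
        bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
        bnb_4bit_compute_dtype=torch.float16,  # run compute in fp16
    )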

def load_tokenizers():
    """Load the tokenizer that matches the base model."""
    tokenizer = AutoTokenizer.from_pretrained(
        config.MODEL_NAME,
        trust_remote_code=config.TRUST_REMOTE_CODE,
    )
    return tokenizer
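
# Usage sketch (an assumption, not shown in this file): Phi-2's tokenizer ships
# without a pad token, so SFT pipelines commonly reuse the EOS token for padding
# before batching. Whether this project does so elsewhere is not visible here.
def example_prepare_tokenizer():
    tokenizer = load_tokenizers()
    tokenizer.pad_token = tokenizer.eos_token  # pad with EOS so batching works
    return tokenizer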

def load_training_arguments():
    """Build the Hugging Face TrainingArguments from the values in config."""
    training_arguments = TrainingArguments(
        output_dir=config.MODEL_OUTPUT_DIR,
        per_device_train_batch_size=config.PER_DEVICE_TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
        optim=config.OPTIM,
        save_steps=config.SAVE_STEPS,
        logging_steps=config.LOGGING_STEPS,
        learning_rate=config.LEARNING_RATE,
        fp16=config.ENABLE_FP_16,
        max_grad_norm=config.MAX_GRAD_NORM,
        max_steps=config.MAX_STEPS,
        warmup_ratio=config.WARMUP_RATIO,
        gradient_checkpointing=config.ENABLE_GRADIENT_CHECKPOINTING,
    )
    return training_arguments
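
# Note (illustrative values, not taken from config.py): the effective batch size
# per device is PER_DEVICE_TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS,
# e.g. 4 * 4 = 16 examples per optimizer step.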

def load_trainer(model, training_dataset, peft_config, tokenizer, training_arguments):
    """Assemble the TRL SFTTrainer that wraps the model with the PEFT (LoRA) config."""
    trainer = SFTTrainer(
        model=model,
        train_dataset=training_dataset,
        peft_config=peft_config,
        dataset_text_field=config.DATASET_TEXT_FIELD,
        max_seq_length=config.MAX_SEQ_LENGTH,
        tokenizer=tokenizer,
        args=training_arguments,
    )
    return trainer
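
# End-to-end sketch (not part of the original upload): how the helpers above could
# be wired together in a typical QLoRA fine-tuning run. config.DATASET_NAME and the
# LoRA hyperparameters below are hypothetical; substitute the project's real values.
if __name__ == "__main__":
    import torch
    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import BitsAndBytesConfig

    # 4-bit quantization config (values are illustrative).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # LoRA adapter config (rank and dropout are illustrative, not the project's).
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = load_model(bnb_config)
    tokenizer = load_tokenizers()
    training_arguments = load_training_arguments()

    # config.DATASET_NAME is hypothetical; point this at the dataset actually used.
    dataset = load_dataset(config.DATASET_NAME, split="train")

    trainer = load_trainer(model, dataset, lora_config, tokenizer, training_arguments)
    trainer.train()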