File size: 715 Bytes
eca6215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from trl import SFTTrainer
from transformers import TrainingArguments

def train_model(model, tokenizer, train_dataset, dataset_text_field, max_seq_length, dataset_num_proc, packing, training_args):
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        dataset_text_field=dataset_text_field,
        max_seq_length=max_seq_length,
        dataset_num_proc=dataset_num_proc,
        packing=packing,
        args=TrainingArguments(**training_args),
    )
    
    # Train the model
    train_results = trainer.train()
    
    # Optionally, you can return more specific training information if necessary
    return train_results