File size: 3,892 Bytes
ffe5a2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import os
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import DPOTrainer, DPOConfig
import torch


# Model loading function
def load_model():
    """Load the base checkpoint in 4-bit and attach a LoRA adapter.

    Returns:
        tuple: ``(model, tokenizer)`` ready for DPO training.
    """
    print("Initializing model loading...")
    model_name = "outputs_sample_code/checkpoint-200"
    max_seq_length = 512
    # Fix: pass max_seq_length at load time so unsloth configures RoPE
    # scaling and its internal buffers for the intended context window,
    # instead of only patching model.config after the fact.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name,
        max_seq_length=max_seq_length,
        dtype=None,  # None -> auto-select (bf16 on Ampere+, else fp16)
        load_in_4bit=True
    )
    print("Model and tokenizer loaded successfully.")
    print(f"Model type: {type(model)}, Tokenizer type: {type(tokenizer)}")

    # Kept as a belt-and-braces measure for any downstream code that
    # reads model.config.max_seq_length directly.
    if hasattr(model, 'config'):
        print("Setting max_seq_length in model.config")
        model.config.max_seq_length = max_seq_length
    else:
        print("Error: model.config does not exist!")

    model = FastLanguageModel.get_peft_model(
        model,
        r=32,  # LoRA rank
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=32,  # alpha == r -> effective LoRA scaling of 1.0
        lora_dropout=0.05,
        bias="none",
        use_gradient_checkpointing="unsloth",  # unsloth's memory-efficient variant
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
        max_seq_length=max_seq_length
    )
    print("PEFT model configured.")
    return model, tokenizer

# Dataset loading function
def load_dataset():
    """Download the preference dataset and map it to DPO's schema.

    Returns:
        Dataset: HF dataset with "prompt", "chosen" and "rejected"
        columns, as expected by DPOTrainer.
    """
    print("Loading dataset...")
    dataset_name = "cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental"

    # Aliased local import: the bare name would shadow this function's
    # own name inside its body, which is confusing to readers.
    from datasets import load_dataset as hf_load_dataset
    raw = hf_load_dataset(dataset_name)

    # Map source columns onto the prompt/chosen/rejected schema; absent
    # fields fall back to "" rather than raising KeyError.
    formatted_data = [
        {
            "prompt": item.get("prompt", ""),
            "chosen": item.get("response_winner", ""),
            "rejected": item.get("response_loser", ""),
        }
        for item in raw["train"]
    ]

    print(f"Formatted data: {len(formatted_data)} items")
    # from_list builds the columnar Dataset in one pass, avoiding the
    # second re-pivot into parallel lists.
    return Dataset.from_list(formatted_data)

# DPO training function
def train_dpo(model, tokenizer, dataset):
    """Run one epoch of DPO fine-tuning on `model`.

    Args:
        model: PEFT-wrapped model returned by load_model().
        tokenizer: Tokenizer matching `model`.
        dataset: Dataset with "prompt"/"chosen"/"rejected" columns. It is
            used for both training and evaluation — NOTE(review): eval on
            the training split only measures fit, not generalization;
            consider a held-out split.
    """
    print("Configuring training arguments...")

    training_args = DPOConfig(
        output_dir="./dpo_trained_model_1216",
        overwrite_output_dir=True,
        per_device_train_batch_size=8,
        # Effective batch size per device = 8 * 128 = 1024 samples/step.
        gradient_accumulation_steps=128,
        per_device_eval_batch_size=8,
        learning_rate=1e-5,
        weight_decay=0.01,
        num_train_epochs=1,
        lr_scheduler_type="constant_with_warmup",
        warmup_steps=10,
        fp16=True,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=32,  # multiple of eval_steps, required by load_best_model_at_end
        logging_steps=8,
        eval_steps=8,
        load_best_model_at_end=True,
        save_safetensors=False,
        save_only_model=True,
        remove_unused_columns=False,
    )
    print("Training arguments configured.")

    print("Initializing DPOTrainer...")
    dpo_trainer = DPOTrainer(
        model=model,
        args=training_args,
        beta=0.3,
        train_dataset=dataset,
        eval_dataset=dataset,
        tokenizer=tokenizer,
        max_prompt_length=162,
        max_length=512,
        loss_type="sigmoid",
        label_smoothing=0.0,
    )
    print("DPOTrainer initialized.")

    print("Starting training...")

    original_forward = model.forward

    def new_forward(*args, **kwargs):
        # Force token ids to int64 before the embedding lookup.
        # Fix: the previous version only handled the keyword form and
        # silently missed `input_ids` passed positionally.
        if kwargs.get("input_ids") is not None:
            kwargs["input_ids"] = kwargs["input_ids"].long()
        elif args and torch.is_tensor(args[0]):
            args = (args[0].long(),) + args[1:]
        return original_forward(*args, **kwargs)

    model.forward = new_forward

    dpo_trainer.train()
    print("Training completed.")

def _main():
    """Drive the end-to-end DPO pipeline: model, data, then training."""
    print("Loading model...")
    model, tokenizer = load_model()

    print("Loading dataset...")
    dataset = load_dataset()

    print("Starting DPO training...")
    train_dpo(model, tokenizer, dataset)

    print("Training complete. Model saved.")


if __name__ == "__main__":
    _main()