from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, prepare_model_for_kbit_training
import torch

# Configure quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
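# The settings above follow the QLoRA recipe: weights are stored in 4-bit NF4,
# the quantization constants are themselves quantized ("double quant"), and
# matrix multiplications are computed in float16.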

# Load model and tokenizer
model_name = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.config.use_cache = False

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
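# The call above freezes the quantized base weights, casts a few non-quantized
# layers (e.g. layer norms) to full precision for stability, and enables input
# gradients so the LoRA adapters configured below can train on top of them.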

# Configure LoRA
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense"]
)
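# Rough meaning of the LoRA settings above:
#   r=16           -> rank of the low-rank update matrices
#   lora_alpha=32  -> scaling factor (effective scale is lora_alpha / r)
#   lora_dropout   -> dropout applied inside the adapter layers
#   target_modules -> phi-2's attention projections (q_proj/k_proj/v_proj) and
#                     its attention output layer ("dense"); all other weights
#                     stay frozen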

# Load and preprocess dataset
ds = load_dataset("OpenAssistant/oasst1")
train_dataset = ds['train']

def format_conversation(example):
    """Format one conversation thread into Human/Assistant turns for SFT.

    Uses the global `message_children` map built below; since that map keeps
    only one reply per message, a single branch of each conversation tree is
    reconstructed.
    """
    # Only process root messages (start of conversations)
    if example["role"] == "prompter" and example["parent_id"] is None:
        conversation = []
        current_msg = example
        conversation.append(("Human", current_msg["text"]))
        
        # Follow the conversation thread
        current_id = current_msg["message_id"]
        while current_id in message_children:
            # Get the next message in conversation
            next_msg = message_children[current_id]
            if next_msg["role"] == "assistant":
                conversation.append(("Assistant", next_msg["text"]))
            elif next_msg["role"] == "prompter":
                conversation.append(("Human", next_msg["text"]))
            current_id = next_msg["message_id"]
            
        if len(conversation) >= 2:  # At least one exchange (human->assistant)
            formatted_text = ""
            for speaker, text in conversation:
                formatted_text += f"{speaker}: {text}\n\n"
            return {"text": formatted_text.strip()}
    return {"text": None}

# Build message relationships: map each message id to one of its replies.
# Note: when a message has several replies, only the last one iterated is
# kept, so each conversation tree contributes a single branch.
print("Building conversation threads...")
message_children = {}
for example in train_dataset:
    if example["parent_id"] is not None:
        message_children[example["parent_id"]] = example

# Format complete conversations
print("\nFormatting conversations...")
processed_dataset = []
for example in train_dataset:
    result = format_conversation(example)
    if result["text"] is not None:
        processed_dataset.append(result)
        if len(processed_dataset) % 100 == 0:
            print(f"Found {len(processed_dataset)} valid conversations")

print(f"Final dataset size: {len(processed_dataset)} conversations")

# Convert to Dataset format
train_dataset = Dataset.from_list(processed_dataset)
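# Optional sanity check: print the start of one formatted conversation to
# confirm the Human/Assistant layout the trainer will see.
print(train_dataset[0]["text"][:500])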

# Configure SFT parameters
# Note: max_steps=500 takes precedence over num_train_epochs, so training
# stops after 500 optimizer steps.
sft_config = SFTConfig(
    output_dir="phi2-finetuned",
    num_train_epochs=1,
    max_steps=500,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    weight_decay=0.001,
    logging_steps=1,
    logging_strategy="steps",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    push_to_hub=False,
    max_seq_length=512,
    dataset_text_field="text",  # train on the formatted "text" column
    report_to="none",
)

# Initialize trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,  # Changed from dataset to train_dataset
    peft_config=peft_config,
    args=sft_config,
)

# Train the model
trainer.train()

# Save the trained model in Hugging Face format (for a PEFT model this writes
# the LoRA adapter weights and adapter config to the directory)
trainer.save_model("phi2-finetuned-final")

# Additionally save a raw PyTorch checkpoint alongside the adapter
model_save_path = "phi2-finetuned-final/model.pt"
torch.save({
    'model_state_dict': trainer.model.state_dict(),
    'config': trainer.model.config,
    'peft_config': peft_config,
}, model_save_path)
print(f"Model saved in PyTorch format at: {model_save_path}")