from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, prepare_model_for_kbit_training
import torch
# Configure 4-bit quantization (QLoRA-style NF4)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
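# NF4 stores the frozen base weights in 4-bit normal-float buckets, double quantization
# also compresses the quantization constants, and the matmuls run in float16. This is the
# usual QLoRA recipe and should let phi-2 (~2.7B parameters) train on a single consumer GPU.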
# Load the base model with the quantization config
model_name = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing during training
# Load tokenizer; phi-2 has no pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
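# prepare_model_for_kbit_training casts the non-quantized parameters (layer norms, LM head)
# to float32 for stability and, in recent peft versions, enables gradient checkpointing by default.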
# Configure LoRA
peft_config = LoraConfig(
    r=16,                  # adapter rank
    lora_alpha=32,         # scaling factor (effective scale is lora_alpha / r)
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "dense"]  # phi-2 attention q/k/v and output projections
)
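# Optional sanity check (sketch): inspect the trainable-parameter footprint of this config.
# get_peft_model injects adapters in place and SFTTrainer below applies peft_config itself,
# so only run this on a throwaway copy of the model / in a separate session.
# from peft import get_peft_model
# get_peft_model(model, peft_config).print_trainable_parameters()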
# Load and preprocess dataset
ds = load_dataset("OpenAssistant/oasst1")
train_dataset = ds['train']
def format_conversation(example):
    """Format the conversation for instruction fine-tuning"""
    # Only process root messages (start of conversations)
    if example["role"] == "prompter" and example["parent_id"] is None:
        conversation = []
        current_msg = example
        conversation.append(("Human", current_msg["text"]))
        # Follow the conversation thread
        current_id = current_msg["message_id"]
        while current_id in message_children:
            # Get the next message in the conversation
            next_msg = message_children[current_id]
            if next_msg["role"] == "assistant":
                conversation.append(("Assistant", next_msg["text"]))
            elif next_msg["role"] == "prompter":
                conversation.append(("Human", next_msg["text"]))
            current_id = next_msg["message_id"]
        if len(conversation) >= 2:  # At least one exchange (human -> assistant)
            formatted_text = ""
            for speaker, text in conversation:
                formatted_text += f"{speaker}: {text}\n\n"
            return {"text": formatted_text.strip()}
    return {"text": None}
# Build message relationships (parent_id -> child message, used by format_conversation above)
# Note: a plain dict keeps only the last child seen per parent, so a single reply branch is followed
print("Building conversation threads...")
message_children = {}
for example in train_dataset:
    if example["parent_id"] is not None:
        message_children[example["parent_id"]] = example
# Format complete conversations
print("\nFormatting conversations...")
processed_dataset = []
for example in train_dataset:
    result = format_conversation(example)
    if result["text"] is not None:
        processed_dataset.append(result)
        if len(processed_dataset) % 100 == 0 and len(processed_dataset) > 0:
            print(f"Found {len(processed_dataset)} valid conversations")
print(f"Final dataset size: {len(processed_dataset)} conversations")
# Convert the processed examples to a Dataset for training
train_dataset = Dataset.from_list(processed_dataset)
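# Quick sanity check: print the start of one formatted example to verify the
# Human:/Assistant: layout before training
print(train_dataset[0]["text"][:300])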
# Configure SFT parameters
sft_config = SFTConfig(
    output_dir="phi2-finetuned",
    num_train_epochs=1,
    max_steps=500,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    weight_decay=0.001,
    logging_steps=1,
    logging_strategy="steps",
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    push_to_hub=False,
    max_seq_length=512,
    report_to="none",
)
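# Note: with max_steps=500 set, training stops after 500 optimizer steps regardless of
# num_train_epochs (a positive max_steps takes precedence in the Trainer). The effective
# batch size here is per_device_train_batch_size * gradient_accumulation_steps = 4.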
# Initialize trainer (applies the LoRA config to the quantized model)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    args=sft_config,
    processing_class=tokenizer,  # pass the configured tokenizer (older TRL versions use tokenizer= instead)
)
# Train the model
trainer.train()
# Save the trained model in Hugging Face format (with peft_config set, this writes the LoRA adapter weights)
trainer.save_model("phi2-finetuned-final")
# Also save a raw PyTorch checkpoint of the state dict and configs
model_save_path = "phi2-finetuned-final/model.pt"
torch.save({
    'model_state_dict': trainer.model.state_dict(),
    'config': trainer.model.config,
    'peft_config': peft_config,
}, model_save_path)
print(f"Model saved in PyTorch format at: {model_save_path}")