# Ensure Apple Metal (MPS) is enabled
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset
from peft import LoraConfig, TaskType
from trl import SFTConfig, SFTTrainer
from enum import Enum
# ✅ Set device to Metal Performance Shaders (MPS) for Mac M3
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
# ✅ Set seed for reproducibility
set_seed(42)
# ✅ Model and dataset
model_name = "google/gemma-2-2b-it"
dataset_name = "Jofthomas/hermes-function-calling-thinking-V1"
# ✅ Define ChatML-style special tokens and load the tokenizer with them
class ChatmlSpecialTokens(str, Enum):
tools = "<tools>"
eotools = "</tools>"
think = "<think>"
eothink = "</think>"
    tool_call = "<tool_call>"
    eotool_call = "</tool_call>"
    tool_response = "<tool_response>"
    eotool_response = "</tool_response>"
pad_token = "<pad>"
eos_token = "<eos>"
@classmethod
def list(cls):
return [c.value for c in cls]
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=True,  # Gemma-2 is gated on the Hub; reuse the stored Hugging Face token
    pad_token=ChatmlSpecialTokens.pad_token.value,
    additional_special_tokens=ChatmlSpecialTokens.list(),
)
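# Optional sanity check (illustrative addition): confirm the ChatML-style tokens are now part
# of the vocabulary before the embedding matrix is resized below.
assert all(tok in tokenizer.get_vocab() for tok in ChatmlSpecialTokens.list())
print(f"Tokenizer size after adding special tokens: {len(tokenizer)}")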
# ✅ Load model and move it to MPS ("eager" attention is the recommended implementation for training Gemma-2)
model = AutoModelForCausalLM.from_pretrained(model_name, token=True, attn_implementation="eager")
model.resize_token_embeddings(len(tokenizer))
model.to(device)
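# Optional sanity check (illustrative addition): after resize_token_embeddings, the input
# embedding matrix should have one row per tokenizer entry.
print(f"Embedding rows: {model.get_input_embeddings().weight.shape[0]} | tokenizer size: {len(tokenizer)}")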
# ✅ Data preprocessing function
def preprocess(sample):
messages = sample["messages"]
if not messages or not isinstance(messages, list):
return {"text": ""} # Return empty text if messages are missing
first_message = messages[0]
# Ensure system messages are merged with the first user message
if first_message["role"] == "system":
system_message_content = first_message.get("content", "")
if len(messages) > 1 and messages[1]["role"] == "user":
messages[1]["content"] = (
system_message_content
+ "\n\nAlso, before making a call to a function, take the time to plan the function to take. "
+ "Make that thinking process between <think>{your thoughts}</think>\n\n"
+ messages[1].get("content", "")
)
messages.pop(0) # Remove system message
# Ensure the conversation alternates between "user" and "assistant"
valid_roles = ["user", "assistant"]
cleaned_messages = [
msg for msg in messages if msg.get("role") in valid_roles and msg.get("content")
]
# Check if messages are empty after cleanup
if not cleaned_messages or cleaned_messages[0]["role"] != "user":
return {"text": ""} # Ensure the first message is always from the user
# Apply chat template
try:
formatted_text = tokenizer.apply_chat_template(cleaned_messages, tokenize=False)
return {"text": formatted_text}
except Exception as e:
print(f"Error processing message: {e}")
return {"text": ""}
# ✅ Load dataset
dataset = load_dataset(dataset_name, cache_dir="/tmp")
dataset = dataset.rename_column("conversations", "messages")
dataset = dataset.map(preprocess, remove_columns=["messages"])
dataset = dataset["train"].train_test_split(0.1)
# ✅ Print dataset size before training
print(f"Training dataset size: {len(dataset['train'])} samples")
print(f"Evaluation dataset size: {len(dataset['test'])} samples")
# ✅ LoRA configuration
peft_config = LoraConfig(
r=16,
lora_alpha=64,
lora_dropout=0.05,
target_modules=["gate_proj", "q_proj", "lm_head", "o_proj", "k_proj", "embed_tokens", "down_proj", "up_proj", "v_proj"],
task_type=TaskType.CAUSAL_LM,
bias="none",
)
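# To see how many parameters LoRA will actually train before launching the trainer, the model
# can be wrapped manually (illustrative sketch, shown commented out because SFTTrainer applies
# the same peft_config internally and wrapping twice should be avoided):
# from peft import get_peft_model
# get_peft_model(model, peft_config).print_trainable_parameters()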
# ✅ Training configuration (tuned for a Mac M3 Max)
num_train_epochs = 5     # Upper bound on epochs; max_steps below takes precedence once reached
max_steps = 1000         # Caps training at 1000 optimizer steps, overriding num_train_epochs
learning_rate = 5e-5     # Conservative learning rate for more stable LoRA fine-tuning
training_arguments = SFTConfig(
output_dir="gemma-2-2B-it-macM3",
    per_device_train_batch_size=2,   # Keep small on MPS to limit memory pressure
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,   # Effective batch size = 2 * 4 = 8 without extra memory
save_strategy="epoch",
save_total_limit=2,
save_safetensors=False,
    eval_strategy="epoch",  # "evaluation_strategy" was renamed to "eval_strategy" in recent transformers
logging_steps=5,
learning_rate=learning_rate,
max_grad_norm=1.0,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
report_to="tensorboard",
    bf16=True,  # bf16 on MPS needs a recent PyTorch/accelerate; set bf16=False (fp32) if this errors
push_to_hub=False,
num_train_epochs=num_train_epochs,
    max_steps=max_steps,  # Hard cap: training stops after 1000 optimizer steps even if epochs remain
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
packing=True,
max_seq_length=1500,
)
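# Rough training budget under the settings above (arithmetic for orientation only):
# effective batch = 2 * 4 = 8 packed sequences per optimizer step, so max_steps=1000 covers
# about 8,000 packed sequences, i.e. roughly 8,000 * 1,500 = 12M tokens.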
# ✅ Trainer setup
trainer = SFTTrainer(
model=model,
args=training_arguments,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
processing_class=tokenizer,
peft_config=peft_config,
)
# ✅ Start training
trainer.train()
trainer.save_model()
print("Training complete! πŸš€ Model saved successfully.")