# Ensure Apple Metal (MPS) is enabled
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset
from peft import LoraConfig, TaskType
from trl import SFTConfig, SFTTrainer
from enum import Enum
# Set device to Metal Performance Shaders (MPS) for Mac M3
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")
# Set seed for reproducibility
set_seed(42)
# Model and dataset
model_name = "google/gemma-2-2b-it"
dataset_name = "Jofthomas/hermes-function-calling-thinking-V1"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
# Adjust tokenizer with special tokens
class ChatmlSpecialTokens(str, Enum):
    tools = "<tools>"
    eotools = "</tools>"
    think = "<think>"
    eothink = "</think>"
    tool_call = "<tool_call>"
    eotool_call = "</tool_call>"
    tool_response = "<tool_response>"
    eotool_response = "</tool_response>"
    pad_token = "<pad>"
    eos_token = "<eos>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    pad_token=ChatmlSpecialTokens.pad_token.value,
    additional_special_tokens=ChatmlSpecialTokens.list(),
)
# Load model and move it to MPS
model = AutoModelForCausalLM.from_pretrained(model_name, token=True, attn_implementation="eager")
model.resize_token_embeddings(len(tokenizer))
model.to(device)
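
# Optional sanity check: confirm the special tokens were registered and that the
# resized embedding matrix matches the new vocabulary size before training starts.
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Additional special tokens: {tokenizer.additional_special_tokens}")
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer), (
    "Embedding matrix size does not match tokenizer vocabulary size"
)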
# Data preprocessing function
def preprocess(sample):
    messages = sample["messages"]
    if not messages or not isinstance(messages, list):
        return {"text": ""}  # Return empty text if messages are missing

    first_message = messages[0]

    # Merge a leading system message into the first user message
    if first_message["role"] == "system":
        system_message_content = first_message.get("content", "")
        if len(messages) > 1 and messages[1]["role"] == "user":
            messages[1]["content"] = (
                system_message_content
                + "\n\nAlso, before making a call to a function, take the time to plan the function to take. "
                + "Make that thinking process between <think>{your thoughts}</think>\n\n"
                + messages[1].get("content", "")
            )
        messages.pop(0)  # Remove system message

    # Keep only user/assistant messages that have non-empty content
    valid_roles = ["user", "assistant"]
    cleaned_messages = [
        msg for msg in messages if msg.get("role") in valid_roles and msg.get("content")
    ]

    # Discard samples that are empty after cleanup or that do not start with a user turn
    if not cleaned_messages or cleaned_messages[0]["role"] != "user":
        return {"text": ""}

    # Apply the chat template
    try:
        formatted_text = tokenizer.apply_chat_template(cleaned_messages, tokenize=False)
        return {"text": formatted_text}
    except Exception as e:
        print(f"Error processing message: {e}")
        return {"text": ""}
# Load dataset
dataset = load_dataset(dataset_name, cache_dir="/tmp")
dataset = dataset.rename_column("conversations", "messages")
dataset = dataset.map(preprocess, remove_columns=["messages"])
dataset = dataset["train"].train_test_split(0.1)
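
# Optional check: print the start of one preprocessed sample to verify that the
# chat template and the <think>/<tool_call> special tokens appear where expected.
print(dataset["train"][0]["text"][:500])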
# Print dataset size before training
print(f"Training dataset size: {len(dataset['train'])} samples")
print(f"Evaluation dataset size: {len(dataset['test'])} samples")
# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=[
        "gate_proj", "q_proj", "lm_head", "o_proj", "k_proj",
        "embed_tokens", "down_proj", "up_proj", "v_proj",
    ],
    task_type=TaskType.CAUSAL_LM,
    bias="none",
)
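
# Optional check (assumes Gemma-2 module naming): verify that every LoRA target
# module name actually occurs in the model, so a typo does not silently leave
# part of the network untrained.
model_module_names = {name.split(".")[-1] for name, _ in model.named_modules()}
missing_targets = [m for m in peft_config.target_modules if m not in model_module_names]
print(f"LoRA target modules not found in the model: {missing_targets or 'none'}")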
# Training configuration (adjusted for performance on a Mac M3 Max)
num_train_epochs = 5   # Target number of epochs (overridden by max_steps below)
max_steps = 1000       # Hard cap: training stops after 1,000 optimizer steps
learning_rate = 5e-5   # Lower learning rate to limit overfitting
training_arguments = SFTConfig(
    output_dir="gemma-2-2B-it-macM3",
    per_device_train_batch_size=2,  # Keep small when training on MPS
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,  # Effective batch size: 2 * 4 = 8 sequences per optimizer step
    save_strategy="epoch",
    save_total_limit=2,
    save_safetensors=False,
    evaluation_strategy="epoch",
    logging_steps=5,
    learning_rate=learning_rate,
    max_grad_norm=1.0,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    bf16=True,  # Mixed-precision training (bf16 on MPS requires a recent PyTorch/macOS)
    push_to_hub=False,
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,  # Overrides num_train_epochs: training runs for exactly 1,000 steps
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    packing=True,
    max_seq_length=1500,
)
# Trainer setup
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
    peft_config=peft_config,
)
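
# Optional: report how many parameters LoRA actually trains. SFTTrainer wraps the
# model in a PEFT model when peft_config is supplied; the guard keeps this a no-op
# if that wrapping ever changes.
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()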
# Start training (should work efficiently on Mac M3 Max)
trainer.train()
trainer.save_model()
print("Training complete! π Model saved successfully.") |