# Fine-tune google/gemma-2-2b-it for function calling on a Mac (Apple Metal / MPS)
import os

# Optional: fall back to CPU for ops not yet implemented on MPS (slower, but
# avoids hard failures). This generally must be set before torch is imported.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset
from peft import LoraConfig, TaskType
from trl import SFTConfig, SFTTrainer
from enum import Enum

# βœ… Set device to Metal Performance Shaders (MPS) for Mac M3
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# βœ… Set seed for reproducibility
set_seed(42)

# βœ… Model and dataset
model_name = "google/gemma-2-2b-it"
dataset_name = "Jofthomas/hermes-function-calling-thinking-V1"

# βœ… Adjust tokenizer with special tokens
class ChatmlSpecialTokens(str, Enum):
    tools = "<tools>"
    eotools = "</tools>"
    think = "<think>"
    eothink = "</think>"
    tool_call = "<tool_call>"
    eotool_call = "</tool_call>"
    tool_response = "<tool_response>"
    eotool_response = "</tool_response>"
    pad_token = "<pad>"
    eos_token = "<eos>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=True,
    pad_token=ChatmlSpecialTokens.pad_token.value,
    additional_special_tokens=ChatmlSpecialTokens.list()
)
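
# Optional sanity check (illustrative, not part of the original flow): each
# special token should map to exactly one token id after registration.
for tok in ChatmlSpecialTokens.list():
    ids = tokenizer.encode(tok, add_special_tokens=False)
    assert len(ids) == 1, f"{tok} was split into {len(ids)} pieces"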

# βœ… Load model and move it to MPS
# Gemma-2 uses attention logit soft-capping, so "eager" is the attention
# implementation recommended for training this architecture.
model = AutoModelForCausalLM.from_pretrained(model_name, token=True, attn_implementation="eager")
model.resize_token_embeddings(len(tokenizer))
model.to(device)

# βœ… Data preprocessing function
def preprocess(sample):
    messages = sample["messages"]

    if not messages or not isinstance(messages, list):
        return {"text": ""}  # Return empty text if messages are missing

    first_message = messages[0]

    # Ensure system messages are merged with the first user message
    if first_message["role"] == "system":
        system_message_content = first_message.get("content", "")
        if len(messages) > 1 and messages[1]["role"] == "user":
            messages[1]["content"] = (
                    system_message_content
                    + "\n\nAlso, before making a call to a function, take the time to plan the function to take. "
                    + "Make that thinking process between <think>{your thoughts}</think>\n\n"
                    + messages[1].get("content", "")
            )
            messages.pop(0)  # Remove system message

    # Keep only "user" and "assistant" turns that have non-empty content
    valid_roles = ["user", "assistant"]
    cleaned_messages = [
        msg for msg in messages if msg.get("role") in valid_roles and msg.get("content")
    ]

    # Check if messages are empty after cleanup
    if not cleaned_messages or cleaned_messages[0]["role"] != "user":
        return {"text": ""}  # Ensure the first message is always from the user

    # Apply chat template
    try:
        formatted_text = tokenizer.apply_chat_template(cleaned_messages, tokenize=False)
        return {"text": formatted_text}
    except Exception as e:
        print(f"Error processing message: {e}")
        return {"text": ""}

# βœ… Load dataset
dataset = load_dataset(dataset_name, cache_dir="/tmp")
dataset = dataset.rename_column("conversations", "messages")
dataset = dataset.map(preprocess, remove_columns=["messages"])
dataset = dataset["train"].train_test_split(0.1)

# βœ… Print dataset size before training
print(f"Training dataset size: {len(dataset['train'])} samples")
print(f"Evaluation dataset size: {len(dataset['test'])} samples")

# βœ… LoRA configuration
# embed_tokens and lm_head are included as targets because new special tokens
# were added above, so the resized embeddings need to be trained as well.
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["gate_proj", "q_proj", "lm_head", "o_proj", "k_proj", "embed_tokens", "down_proj", "up_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    bias="none",
)

# βœ… Training configuration (adjusted for performance on Mac M3 Max)
num_train_epochs = 5  # βœ… Target number of epochs
max_steps = 1000  # βœ… Hard cap on optimizer steps; when set, this overrides num_train_epochs
learning_rate = 5e-5  # βœ… Conservative learning rate for stable fine-tuning

training_arguments = SFTConfig(
    output_dir="gemma-2-2B-it-macM3",
    per_device_train_batch_size=2,  # βœ… Keep small when training on MPS
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,  # βœ… Effective batch size of 8 without extra memory
    save_strategy="epoch",
    save_total_limit=2,
    save_safetensors=False,
    eval_strategy="epoch",  # βœ… Named evaluation_strategy in older transformers releases
    logging_steps=5,
    learning_rate=learning_rate,
    max_grad_norm=1.0,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    bf16=True,  # βœ… Mixed precision; bf16 on MPS needs a recent PyTorch, set False if unsupported
    push_to_hub=False,
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,  # βœ… Caps training at 1000 optimizer steps (overrides num_train_epochs)
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    packing=True,
    max_seq_length=1500,  # βœ… Renamed to max_length in newer TRL releases
)

# βœ… Trainer setup
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
    peft_config=peft_config,
)
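
# Optional (illustrative): with peft_config supplied, SFTTrainer wraps the
# model in a PeftModel, so we can report how many parameters LoRA actually trains.
trainer.model.print_trainable_parameters()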

# βœ… Start training
trainer.train()
trainer.save_model()  # Saves the LoRA adapter (not the full model) to output_dir

print("Training complete! πŸš€ Model saved successfully.")