|
|
|
import torch |
|
import os |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed |
|
from datasets import load_dataset |
|
from peft import LoraConfig, TaskType |
|
from trl import SFTConfig, SFTTrainer |
|
from enum import Enum |
|
|
|
|
|
# Select the compute device: Apple-Silicon GPU (MPS) when available, else CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

print(f"Using device: {device}")

# Seed python/numpy/torch RNGs for reproducible runs.
set_seed(42)

# Base model and fine-tuning dataset identifiers on the Hugging Face Hub.
model_name = "google/gemma-2-2b-it"

dataset_name = "Jofthomas/hermes-function-calling-thinking-V1"

# NOTE: the tokenizer is loaded further down with the ChatML special tokens
# registered; the redundant plain AutoTokenizer.from_pretrained() call that
# used to live here was immediately overwritten and has been removed.
|
|
|
|
|
class ChatmlSpecialTokens(str, Enum):
    """Special marker tokens for the function-calling / thinking chat format.

    Members are plain strings (str mixin) so they can be passed directly to
    the tokenizer as special-token values.
    """

    tools = "<tools>"
    eotools = "</tools>"
    think = "<think>"
    eothink = "</think>"
    tool_call = "<tool_call>"
    eotool_call = "</tool_call>"
    tool_response = "<tool_response>"
    eotool_response = "</tool_response>"
    pad_token = "<pad>"
    eos_token = "<eos>"

    @classmethod
    def list(cls):
        """Return every token string, in declaration order."""
        return [member.value for member in cls]
|
|
|
# Load the tokenizer with the ChatML-style special tokens registered so each
# marker tokenizes to a single id. `token=True` (Hub auth from the local
# cache) is passed here for consistency with the model load below — the
# original call omitted it, unlike every other from_pretrained in this file.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=True,
    pad_token=ChatmlSpecialTokens.pad_token.value,
    additional_special_tokens=ChatmlSpecialTokens.list(),
)

# "eager" attention is the implementation recommended for Gemma-2 training.
model = AutoModelForCausalLM.from_pretrained(model_name, token=True, attn_implementation="eager")
# Grow the embedding matrix to cover the newly added special tokens.
model.resize_token_embeddings(len(tokenizer))
model.to(device)
|
|
|
|
|
def preprocess(sample):
    """Convert one raw conversation row into a chat-templated training string.

    A leading system message is folded into the first user turn (Gemma's chat
    template has no system role), a thinking-process instruction is injected,
    and the result is rendered with the tokenizer's chat template.

    Args:
        sample: dataset row with a "messages" list of {"role", "content"} dicts.

    Returns:
        {"text": rendered_conversation} or {"text": ""} for rows that are
        empty, malformed, or cannot be rendered.
    """
    messages = sample["messages"]

    # Guard against missing/empty/non-list conversations.
    if not messages or not isinstance(messages, list):
        return {"text": ""}

    # Work on a shallow copy so the dataset's cached rows are never mutated
    # in place (the original popped/overwrote the sample's own structures).
    messages = list(messages)

    first_message = messages[0]

    # Fold a leading system prompt into the first user turn, prepending the
    # instruction to plan tool calls inside <think>...</think> blocks.
    # .get("role") instead of ["role"] so malformed rows degrade gracefully
    # rather than raising KeyError inside datasets.map().
    if first_message.get("role") == "system":
        system_message_content = first_message.get("content", "")
        if len(messages) > 1 and messages[1].get("role") == "user":
            merged_content = (
                system_message_content
                + "\n\nAlso, before making a call to a function, take the time to plan the function to take. "
                + "Make that thinking process between <think>{your thoughts}</think>\n\n"
                + messages[1].get("content", "")
            )
            # Replace the message dict rather than editing it in place.
            messages[1] = {**messages[1], "content": merged_content}
        messages = messages[1:]

    # Keep only non-empty user/assistant turns.
    valid_roles = ("user", "assistant")
    cleaned_messages = [
        msg for msg in messages if msg.get("role") in valid_roles and msg.get("content")
    ]

    # The chat template requires the conversation to open with a user turn.
    if not cleaned_messages or cleaned_messages[0]["role"] != "user":
        return {"text": ""}

    try:
        formatted_text = tokenizer.apply_chat_template(cleaned_messages, tokenize=False)
        return {"text": formatted_text}
    except Exception as e:
        # Best-effort: drop rows the template cannot render instead of
        # crashing the whole map() pass.
        print(f"Error processing message: {e}")
        return {"text": ""}
|
|
|
|
|
# Download the dataset (cached under /tmp) and rename the conversation column
# to the "messages" key that preprocess() expects.
dataset = load_dataset(dataset_name, cache_dir="/tmp")

dataset = dataset.rename_column("conversations", "messages")

# Render every conversation to a single "text" field; drop the raw messages.
dataset = dataset.map(preprocess, remove_columns=["messages"])

# Hold out 10% of the train split for evaluation.
# NOTE(review): no explicit seed is passed to train_test_split; the split's
# reproducibility presumably relies on set_seed(42) above — confirm.
dataset = dataset["train"].train_test_split(0.1)

print(f"Training dataset size: {len(dataset['train'])} samples")

print(f"Evaluation dataset size: {len(dataset['test'])} samples")
|
|
|
|
|
# Modules that receive LoRA adapters: every attention and MLP projection,
# plus embed_tokens and lm_head (presumably included so the newly added
# special tokens can be learned — the embedding matrix was resized above).
_lora_target_modules = [
    "gate_proj",
    "q_proj",
    "lm_head",
    "o_proj",
    "k_proj",
    "embed_tokens",
    "down_proj",
    "up_proj",
    "v_proj",
]

# Rank-16 LoRA with alpha 64 (scaling factor alpha/r = 4) and light dropout.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    target_modules=_lora_target_modules,
)
|
|
|
|
|
# Optimisation schedule. NOTE: when max_steps is set (> 0) it takes
# precedence over num_train_epochs — training stops after 1000 optimiser
# steps regardless of epoch count.
num_train_epochs = 5
max_steps = 1000
learning_rate = 5e-5

training_arguments = SFTConfig(
    output_dir="gemma-2-2B-it-macM3",
    # Effective batch size = 2 per device * 4 accumulation steps = 8.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    save_strategy="epoch",
    save_total_limit=2,
    save_safetensors=False,
    # Fixed: `evaluation_strategy` was renamed to `eval_strategy`
    # (deprecated in transformers 4.41, removed in 4.46); this script
    # already targets the newer API (SFTTrainer's `processing_class`).
    eval_strategy="epoch",
    logging_steps=5,
    learning_rate=learning_rate,
    max_grad_norm=1.0,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    # NOTE(review): bf16 on the MPS backend needs a recent macOS/PyTorch —
    # confirm this machine supports it, otherwise drop to fp32.
    bf16=True,
    push_to_hub=False,
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,
    # Non-reentrant checkpointing is the recommended variant and is
    # compatible with PEFT-wrapped models.
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Pack multiple short samples into each 1500-token sequence for throughput.
    packing=True,
    max_seq_length=1500,
)
|
|
|
|
|
# Build the supervised fine-tuning trainer. LoRA adapters are attached
# internally from peft_config; the tokenizer (with the added special tokens)
# is passed as the processing class so checkpoints save it alongside weights.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
    peft_config=peft_config,
)

# Run fine-tuning, then write the final model/tokenizer to output_dir.
trainer.train()

trainer.save_model()

print("Training complete! π Model saved successfully.")