# HarmonyAI / scripts / fine_tune.py
# stagbrook-tech's picture
# Initial commit
# ea7fd90
"""
fine_tune.py
------------
This script fine-tunes a pretrained causal language model on the prepared
commands dataset.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import pandas as pd
import torch
# Load dataset
# The path is relative to the current working directory, so the script is
# expected to be launched from the directory that contains it.
#
# Every column is read as text with pandas' default NA detection disabled.
# Otherwise cell values such as "None", "NA", "N/A", "null" or an empty cell
# are converted to NaN and later rendered as the literal string "nan" when the
# training prompt is formatted.
df = pd.read_csv(
    "../dataset/commands_dataset.csv",
    dtype=str,
    keep_default_na=False,
)
commands = df["command"].tolist()
intents = df["intent"].tolist()
methods = df["effective_method"].tolist()
checks = df["safety_check"].tolist()
permissions = df["permissions"].tolist()
class CommandsDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns command records into causal-LM examples.

    Each record is formatted as a single prompt string
    ("Command: ... Intent: ... Method: ... Safety Check: ... Permissions: ...")
    and tokenized to a fixed length of ``max_length`` tokens.

    Args:
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors="pt"`` and must
            return a mapping of tensors with a leading batch dimension of 1.
        commands, intents, methods, checks, permissions: Parallel sequences,
            one entry per record.
        max_length: Length every example is padded/truncated to.

    Raises:
        ValueError: If the five sequences do not all have the same length.
    """

    def __init__(self, tokenizer, commands, intents, methods, checks, permissions, max_length=512):
        columns = (commands, intents, methods, checks, permissions)
        # Fail early instead of raising IndexError in the middle of training.
        if len({len(column) for column in columns}) > 1:
            raise ValueError(
                "commands, intents, methods, checks and permissions must all "
                "have the same length"
            )
        self.tokenizer = tokenizer
        self.commands = commands
        self.intents = intents
        self.methods = methods
        self.checks = checks
        self.permissions = permissions
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.commands)

    def __getitem__(self, idx):
        """Return the tokenized example for record ``idx``.

        The result holds every tensor produced by the tokenizer (batch
        dimension removed) plus ``labels``: a copy of ``input_ids`` in which
        padded positions are set to -100.
        """
        input_text = f"Command: {self.commands[idx]} Intent: {self.intents[idx]} Method: {self.methods[idx]} Safety Check: {self.checks[idx]} Permissions: {self.permissions[idx]}"
        encoding = self.tokenizer(
            input_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the batch axis; a bare squeeze() would also drop the
        # sequence axis when max_length == 1.
        item = {key: val.squeeze(0) for key, val in encoding.items()}

        # Causal-LM training needs labels, otherwise the model returns no loss.
        # -100 is the ignore index of the cross-entropy loss, so padding does
        # not contribute to the training signal.
        labels = item["input_ids"].clone()
        attention_mask = item.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        item["labels"] = labels
        return item
# Load tokenizer and model
# NOTE(review): confirm this repository id resolves on the Hugging Face Hub
# before running; it is kept exactly as originally written.
model_name = "redpajama/incite-chat-3b-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# CommandsDataset pads every example with padding="max_length", which raises
# if the tokenizer defines no pad token. Reuse the end-of-sequence token in
# that case; tokenizers that already have a pad token are left untouched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prepare dataset
dataset = CommandsDataset(tokenizer, commands, intents, methods, checks, permissions)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir="./logs",
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Fine-tune the model
trainer.train()

# Persist the final weights and the tokenizer explicitly, so the result does
# not depend on a periodic checkpoint happening to land on the last step.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)