""" | |
fine_tune.py | |
------------ | |
This script fine-tunes the selected models using the prepared dataset. | |
""" | |
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import pandas as pd
import torch
# Load dataset
df = pd.read_csv("../dataset/commands_dataset.csv")
commands = df["command"].tolist()
intents = df["intent"].tolist()
methods = df["effective_method"].tolist()
checks = df["safety_check"].tolist()
permissions = df["permissions"].tolist()
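# Optional sanity check (not in the original script): fail early if the prepared CSV is
# empty or has missing annotations, rather than failing part-way through training.
required = ["command", "intent", "effective_method", "safety_check", "permissions"]
if df.empty or df[required].isnull().any().any():
    raise ValueError("commands_dataset.csv is empty or has missing values in required columns")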
class CommandsDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, commands, intents, methods, checks, permissions, max_length=512):
        self.tokenizer = tokenizer
        self.commands = commands
        self.intents = intents
        self.methods = methods
        self.checks = checks
        self.permissions = permissions
        self.max_length = max_length

    def __len__(self):
        return len(self.commands)

    def __getitem__(self, idx):
        # Serialize one annotated command into a single prompt string
        input_text = (
            f"Command: {self.commands[idx]} Intent: {self.intents[idx]} "
            f"Method: {self.methods[idx]} Safety Check: {self.checks[idx]} "
            f"Permissions: {self.permissions[idx]}"
        )
        encoding = self.tokenizer(input_text, truncation=True, padding="max_length",
                                  max_length=self.max_length, return_tensors="pt")
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        # Causal-LM training needs labels for the loss; copy the input ids and
        # mask padding positions with -100 so they are ignored by the loss.
        labels = item["input_ids"].clone()
        labels[item["attention_mask"] == 0] = -100
        item["labels"] = labels
        return item
# Load tokenizer and model (the checkpoint is published under "togethercomputer" on the Hugging Face Hub)
model_name = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# The tokenizer has no pad token by default; reuse EOS so max-length padding works
tokenizer.pad_token = tokenizer.eos_token
# Prepare dataset
dataset = CommandsDataset(tokenizer, commands, intents, methods, checks, permissions)
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir="./logs",
)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
# Fine-tune the model
trainer.train()
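
# Optional follow-up (not part of the original script): persist the fine-tuned weights so
# they can be reloaded for inference. The "./fine_tuned_model" path is an arbitrary,
# hypothetical choice and is not referenced elsewhere in this repo.
trainer.save_model("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")

# Quick smoke test (assumed, not from the original script): generate a short continuation
# for the first training command to confirm the tuned model runs end to end.
sample = f"Command: {commands[0]} Intent:"
inputs = tokenizer(sample, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))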