# Hugging Face Spaces status: Runtime error
import time

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
# === 1️⃣ LOAD MODEL AND TOKENIZER ===
# Base checkpoint that the LoRA adapters are attached to further down.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half precision is only chosen when CUDA is available; otherwise float32.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# device_map="auto" lets transformers/accelerate decide weight placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch_dtype,
)
# === 2️⃣ LoRA SETTINGS ===
# Rank-8 LoRA adapters attached to the modules named q_proj and v_proj.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    # Makes get_peft_model return the causal-LM wrapper rather than the
    # generic PeftModel.
    task_type="CAUSAL_LM",
)

# The base model was already placed on its device(s) by device_map="auto"
# when it was loaded, so no extra .to(...) call is made here: moving an
# accelerate-dispatched model is redundant and can fail when its layers are
# split across devices or offloaded.
model = get_peft_model(model, lora_config)
# === 3️⃣ DATASET ===
def load_and_prepare_dataset(num_samples=10000, max_length=512, seed=42):
    """Load a shuffled sample of the OSCAR ``unshuffled_deduplicated_tr`` split and tokenize it.

    Args:
        num_samples: Number of rows to keep after shuffling the "train"
            split. Clamped to the split size so a smaller split does not
            raise in ``select``.
        max_length: Token length at which each text is truncated.
        seed: Shuffle seed, so the same subset is drawn on every run.

    Returns:
        A tokenized ``datasets.Dataset`` containing only the tokenizer
        outputs (e.g. ``input_ids`` / ``attention_mask``); the original raw
        columns are removed.
    """
    # NOTE(review): without streaming=True, load_dataset prepares the whole
    # train split before the subset is drawn from it — consider streaming if
    # download time or disk space is a problem.
    dataset = load_dataset("oscar", "unshuffled_deduplicated_tr", trust_remote_code=True)
    train_split = dataset["train"]
    sample_size = min(num_samples, len(train_split))
    subset = train_split.shuffle(seed=seed).select(range(sample_size))

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=max_length)

    # Drop the raw columns ("text" and any others) so only model inputs are
    # left for the data collator.
    return subset.map(
        tokenize_function,
        batched=True,
        remove_columns=subset.column_names,
    )


tokenized_dataset = load_and_prepare_dataset()
# === 4️⃣ TRAINING SETTINGS ===
batch_size = 1
num_epochs = 1
gradient_accumulation_steps = 16

# ``max_steps`` counts optimizer updates, and each update consumes
# batch_size * gradient_accumulation_steps examples. Ignoring the
# accumulation factor (10 000 rows -> 10 000 steps at accumulation 16) would
# mean 16 passes over the data instead of ``num_epochs``.
_examples_per_step = batch_size * gradient_accumulation_steps
# Ceiling division so a final partial accumulation window is still counted.
max_steps = (
    (len(tokenized_dataset) + _examples_per_step - 1) // _examples_per_step
) * num_epochs

# The collator pads through the tokenizer, which requires a pad token; fall
# back to EOS if the checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# mlm=False builds causal-LM ``labels`` from ``input_ids`` (padding masked
# out). Without labels the model returns no loss and Trainer.train() fails.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./mistral_lora",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=5e-4,
    num_train_epochs=num_epochs,
    max_steps=max_steps,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    optim="adamw_torch",
    # Mixed precision only when a GPU is available.
    fp16=torch.cuda.is_available(),
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)


def train_model():
    """Fine-tune the LoRA adapters with the module-level ``trainer``.

    Returns:
        The object returned by ``trainer.train()`` (training summary);
        existing callers may ignore it.
    """
    return trainer.train()
# === 5️⃣ CHAT INTERFACE ===
def slow_echo(message, history):
    """Produce the chat reply for the Gradio ChatInterface.

    Before any optimizer step has run, replies with a notice (in Turkish)
    that the model has not been trained yet. Afterwards, echoes the message
    back. The fine-tuned model itself is not queried, and ``history`` is not
    used.

    Args:
        message: The user's latest chat message.
        history: Previous turns supplied by Gradio (ignored).

    Returns:
        The reply string displayed in the chat.
    """
    # Training progress is read from the Trainer's step counter, not from
    # ``model``: a loaded model object is always truthy, so testing it cannot
    # tell whether training has happened.
    if trainer.state.global_step > 0:
        return f"You typed: {message}"
    return "Model henüz eğitilmedi. Lütfen eğitimi başlatın."
# Chat UI wired to slow_echo, with manual flagging and browser-saved history.
demo = gr.ChatInterface(
    slow_echo,
    # ChatInterface's ``type`` only accepts "messages" or "tuples"; "text" is
    # rejected. slow_echo ignores history, so this format choice does not
    # change any reply.
    type="messages",
    flagging_mode="manual",
    flagging_options=["Like", "Spam", "Inappropriate", "Other"],
    save_history=True,
)
def main():
    """Fine-tune the LoRA adapters, then launch the chat UI with ``share=True``.

    Training runs to completion before the UI is started.
    """
    train_model()
    demo.launch(share=True)


if __name__ == "__main__":
    main()