import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import gradio as gr
import spaces
# === 1️⃣ MODEL AND TOKENIZER LOADING ===
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Mistral's tokenizer ships without a pad token; reuse EOS so the
# data collator below can pad batches during training.
tokenizer.pad_token = tokenizer.eos_token
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype, device_map="auto")
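# Loading in float16 roughly halves GPU memory for the 7B weights; on
# CPU-only machines the dtype check above falls back to full float32.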
# === 2️⃣ LoRA SETTINGS ===
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "v_proj"],
)
# device_map="auto" has already placed the base model, so no explicit .to()
# here: moving an Accelerate-dispatched model again can break its device map.
model = get_peft_model(model, lora_config)
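# Optional sanity check: PEFT can report how small the trainable slice is
# (typically well under 1% of the 7B parameters for r=8 on q_proj/v_proj).
model.print_trainable_parameters()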
# === 3️⃣ DATASET ===
@spaces.GPU
def load_and_prepare_dataset():
    dataset = load_dataset("oscar", "unshuffled_deduplicated_tr", trust_remote_code=True)
    subset = dataset["train"].shuffle(seed=42).select(range(10000))

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=512)

    # Drop the raw columns so only token ids reach the Trainer.
    tokenized_datasets = subset.map(tokenize_function, batched=True, remove_columns=subset.column_names)
    return tokenized_datasets

tokenized_dataset = load_and_prepare_dataset()
# === 4️⃣ TRAINING SETTINGS ===
batch_size = 1
grad_accum_steps = 16
num_epochs = 1
# max_steps counts optimizer steps, each of which consumes
# batch_size * grad_accum_steps examples.
max_steps = (len(tokenized_dataset) // (batch_size * grad_accum_steps)) * num_epochs
training_args = TrainingArguments(
    output_dir="./mistral_lora",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum_steps,
    learning_rate=5e-4,
    num_train_epochs=num_epochs,
    max_steps=max_steps,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    optim="adamw_torch",
    fp16=torch.cuda.is_available(),
)
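# Note: the effective batch size is per_device_train_batch_size *
# gradient_accumulation_steps = 16 examples per optimizer step, which is why
# max_steps above divides the dataset length by batch_size * grad_accum_steps.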
# DataCollatorForLanguageModeling with mlm=False pads each batch and copies
# input_ids into labels, which the causal-LM loss requires; without it,
# Trainer would receive batches with no labels and fail.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
@spaces.GPU
def train_model():
    trainer.train()
# === 5️⃣ CHAT INTERFACE ===
@spaces.GPU
def slow_echo(message, history):
    response = "The model has not been trained yet. Please start training."
    if model:
        response = f"You typed: {message}"
    return response
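# A minimal sketch (not wired into the UI) of how the echo above could be
# swapped for real generation once fine-tuning has run; the helper name
# generate_reply and the sampling settings are illustrative assumptions.
def generate_reply(message, max_new_tokens=128):
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7)
    # Drop the prompt tokens so only the newly generated continuation remains.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)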
demo = gr.ChatInterface(
    slow_echo,
    type="messages",  # ChatInterface accepts "messages" or "tuples"; "text" is not valid
    flagging_mode="manual",
    flagging_options=["Like", "Spam", "Inappropriate", "Other"],
    save_history=True,
)
if __name__ == "__main__":
    train_model()
    demo.launch(share=True)