Tuning / app.py
hackergeek's picture
Update app.py
777e328 verified
raw
history blame
3.69 kB
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
device = "cpu" # ✅ اجباراً فقط روی CPU اجرا شود
def train_model(dataset_url, model_url, epochs):
try:
# 🚀 بارگیری مدل و توکنایزر با اجازه اجرای کد سفارشی و بدون نیاز به GPU
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_url,
trust_remote_code=True,
torch_dtype=torch.float32, # ✅ تغییر به float32 برای CPU
device_map="cpu" # ✅ اجباری کردن اجرای روی CPU
)
# ✅ تنظیم LoRA برای کاهش مصرف حافظه
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj"]
)
model = get_peft_model(model, lora_config)
model.to(device)
# ✅ بارگیری دیتاست
dataset = load_dataset(dataset_url)
# ✅ توکنایز کردن داده‌ها
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"]
# ✅ تنظیمات ترینینگ
training_args = TrainingArguments(
output_dir="./deepseek_lora_cpu",
evaluation_strategy="epoch",
learning_rate=5e-4,
per_device_train_batch_size=1, # کاهش مصرف RAM
per_device_eval_batch_size=1,
num_train_epochs=int(epochs),
save_strategy="epoch",
save_total_limit=2,
logging_dir="./logs",
logging_steps=10,
fp16=False, # ❌ غیرفعال‌سازی FP16 چون روی CPU اجرا می‌شود
gradient_checkpointing=True,
optim="adamw_torch",
report_to="none"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
# 🚀 شروع ترینینگ (قفل شده تا پایان)
trainer.train()
trainer.save_model("./deepseek_lora_finetuned")
tokenizer.save_pretrained("./deepseek_lora_finetuned")
return "✅ ترینینگ کامل شد! مدل ذخیره شد."
except Exception as e:
return f"❌ خطا: {str(e)}"
# ✅ رابط کاربری Gradio
with gr.Blocks() as app:
gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - (بدون توقف تا پایان)")
dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")
train_button = gr.Button("شروع ترینینگ", interactive=True)
output_text = gr.Textbox(label="وضعیت ترینینگ")
def disable_button(*args):
train_button.interactive = False
return train_model(*args)
train_button.click(disable_button, inputs=[dataset_url, model_url, epochs], outputs=output_text)
app.queue()
app.launch(server_name="0.0.0.0", server_port=7860, share=True)