hackergeek committed
Commit 006af89 · verified · 1 parent: e88f543

Update app.py

Files changed (1): app.py (+40, −20)
app.py CHANGED
@@ -2,7 +2,6 @@ import torch
 import gradio as gr
 import multiprocessing
 import os
-import time
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
 from peft import get_peft_model, LoraConfig, TaskType
 from datasets import load_dataset
@@ -11,34 +10,56 @@ device = "cpu"
 training_process = None
 log_file = "training_status.log"
 
+# Logging function
 def log_status(message):
     with open(log_file, "w") as f:
         f.write(message)
 
+# Read training status
 def read_status():
     if os.path.exists(log_file):
         with open(log_file, "r") as f:
             return f.read()
     return "⏳ در انتظار شروع ترینینگ..."
 
+# Function to find the text column dynamically
+def find_text_column(dataset):
+    sample = dataset["train"][0]  # Get the first row of the training dataset
+    for column in sample.keys():
+        if isinstance(sample[column], str):  # Find the first text-like column
+            return column
+    return None  # No valid text column found
+
+# Model training function
 def train_model(dataset_url, model_url, epochs):
     try:
         log_status("🚀 در حال بارگیری مدل...")
+
         tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
             model_url, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
         )
 
         lora_config = LoraConfig(
-            task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"]
+            task_type=TaskType.CAUSAL_LM,
+            r=8,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            target_modules=["q_proj", "v_proj"]
         )
-
         model = get_peft_model(model, lora_config)
         model.to(device)
 
         dataset = load_dataset(dataset_url)
+
+        # Automatically detect the correct text column
+        text_column = find_text_column(dataset)
+        if not text_column:
+            log_status("❌ خطا: ستون متنی در دیتاست یافت نشد!")
+            return
+
         def tokenize_function(examples):
-            return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)
+            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=256)
 
         tokenized_datasets = dataset.map(tokenize_function, batched=True)
         train_dataset = tokenized_datasets["train"]

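The core change in this hunk is the new `find_text_column` helper, which replaces the hard-coded `examples["text"]` lookup so that datasets with a differently named text field still tokenize. A minimal, self-contained sketch of how the helper behaves on a toy `DatasetDict` (the toy data and column names below are illustrative, not taken from the app):

```python
from datasets import Dataset, DatasetDict

def find_text_column(dataset):
    sample = dataset["train"][0]             # first row of the training split
    for column in sample.keys():
        if isinstance(sample[column], str):  # first string-valued column wins
            return column
    return None                              # no text-like column found

# Toy dataset: "id" is an integer column, "dialogue" is the first string column.
toy = DatasetDict({"train": Dataset.from_dict({"id": [1, 2], "dialogue": ["hi", "bye"]})})
print(find_text_column(toy))  # -> "dialogue"
```

Note that the helper returns the first string column in schema order, so a dataset whose first string field is not the training text (for example, a string ID) would still need manual handling.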
@@ -61,8 +82,8 @@ def train_model(dataset_url, model_url, epochs):
         )
 
         trainer = Trainer(
-            model=model,
-            args=training_args,
+            model=model,
+            args=training_args,
             train_dataset=train_dataset
         )
 
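The `training_args` passed here are built from `TrainingArguments` earlier in `train_model`, outside this hunk's context. As a rough, self-contained sketch of the kind of CPU-oriented configuration such a `Trainer` call expects (all values below are assumptions, not the app's actual settings):

```python
from transformers import TrainingArguments

# Hypothetical values; the real configuration is defined outside this hunk.
training_args = TrainingArguments(
    output_dir="./results",          # where checkpoints would be written
    num_train_epochs=3,              # in the app this would come from the epochs input
    per_device_train_batch_size=1,   # small batches keep CPU memory usage low
    logging_steps=10,
)
print(training_args.num_train_epochs)  # -> 3
```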
 
@@ -78,6 +99,7 @@ def train_model(dataset_url, model_url, epochs):
     except Exception as e:
         log_status(f"❌ خطا: {str(e)}")
 
+# Start training in a separate process
 def start_training(dataset_url, model_url, epochs):
     global training_process
     if training_process is None or not training_process.is_alive():
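Only a label comment is added to `start_training` here; the body that actually spawns the worker sits outside the diff context. A self-contained sketch of the underlying pattern the app relies on, running the work in a `multiprocessing.Process` and reporting progress through the shared status file, with a dummy worker standing in for `train_model`:

```python
import multiprocessing
import os
import time

log_file = "training_status.log"

def dummy_train(epochs):
    # Stand-in for train_model: report progress via the shared status file.
    with open(log_file, "w") as f:
        f.write(f"training for {epochs} epochs...")
    time.sleep(1)
    with open(log_file, "w") as f:
        f.write("done")

if __name__ == "__main__":
    worker = multiprocessing.Process(target=dummy_train, args=(3,))
    worker.start()
    while worker.is_alive():   # the Gradio refresh button plays this polling role
        time.sleep(0.2)
    with open(log_file) as f:
        print(f.read())        # -> "done"
    os.remove(log_file)
```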
@@ -87,26 +109,24 @@ def start_training(dataset_url, model_url, epochs):
     else:
         return "⚠ ترینینگ در حال اجرا است!"
 
+# Function to update the status
 def update_status():
     return read_status()
 
+# Gradio UI
 with gr.Blocks() as app:
     gr.Markdown("# 🚀 AutoTrain DeepSeek R1 (CPU) - نمایش وضعیت لحظه‌ای")
 
-    dataset_url = gr.Textbox(label="Dataset URL (Hugging Face)", placeholder="مثال: samsum")
-    model_url = gr.Textbox(label="Model URL (Hugging Face)", placeholder="مثال: deepseek-ai/deepseek-r1")
-    epochs = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="تعداد Epochs")
-
-    train_button = gr.Button("شروع ترینینگ")
-    output_text = gr.Textbox(label="وضعیت ترینینگ")
-
-    train_button.click(start_training, inputs=[dataset_url, model_url, epochs], outputs=output_text)
+    with gr.Row():
+        dataset_input = gr.Textbox(label="📂 لینک دیتاست (Hugging Face)")
+        model_input = gr.Textbox(label="🤖 مدل پایه (Hugging Face)")
+        epochs_input = gr.Number(label="🔄 تعداد Epochs", value=3)
 
-    # نمایش وضعیت لحظه‌ای ترینینگ
-    status_box = gr.Textbox(label="مرحله فعلی ترینینگ", interactive=False)
-    refresh_button = gr.Button("🔄 به‌روزرسانی وضعیت")
+    start_button = gr.Button("🚀 شروع ترینینگ")
+    status_output = gr.Textbox(label="📢 وضعیت ترینینگ", interactive=False)
 
-    refresh_button.click(update_status, inputs=[], outputs=status_box)
+    start_button.click(start_training, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
+    status_button = gr.Button("🔄 بروزرسانی وضعیت")
+    status_button.click(update_status, outputs=status_output)
 
-app.queue()
-app.launch(server_name="0.0.0.0", server_port=7860, share=True)
+app.launch()
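For reference, a minimal runnable sketch of the Blocks layout the new UI follows: a `gr.Row` of inputs, a start button, and a status box wired up through two `.click` handlers. The stub functions below stand in for `start_training` and `update_status`, and the labels are simplified English placeholders:

```python
import gradio as gr

def fake_start(dataset_url, model_url, epochs):
    # Stand-in for start_training.
    return f"Would train {model_url} on {dataset_url} for {int(epochs)} epochs."

def fake_status():
    # Stand-in for update_status / read_status.
    return "⏳ waiting for training to start..."

with gr.Blocks() as demo:
    gr.Markdown("# Training demo (sketch)")
    with gr.Row():
        dataset_input = gr.Textbox(label="Dataset (Hugging Face)")
        model_input = gr.Textbox(label="Base model (Hugging Face)")
        epochs_input = gr.Number(label="Epochs", value=3)
    start_button = gr.Button("Start training")
    status_output = gr.Textbox(label="Training status", interactive=False)
    start_button.click(fake_start, inputs=[dataset_input, model_input, epochs_input], outputs=status_output)
    refresh_button = gr.Button("Refresh status")
    refresh_button.click(fake_status, outputs=status_output)

if __name__ == "__main__":
    demo.launch()
```

Note that the commit also drops `app.queue()` and the explicit `server_name`, `server_port`, and `share=True` arguments, so the new version falls back to Gradio's `launch()` defaults.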