Sakalti committed on
Commit
1e482da
·
verified ·
1 Parent(s): 704cc94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -2
app.py CHANGED
@@ -2,8 +2,27 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
  from datasets import load_dataset, Dataset, DatasetDict
4
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def train_and_deploy(write_token, repo_name, license_text):
 
 
 
 
 
7
  # トークンを環境変数に設定
8
  os.environ['HF_WRITE_TOKEN'] = write_token
9
 
@@ -62,16 +81,50 @@ def train_and_deploy(write_token, repo_name, license_text):
62
  args=training_args,
63
  train_dataset=tokenized_datasets["train"],
64
  eval_dataset=tokenized_datasets["test"],
 
65
  )
66
 
67
  # トレーニング実行
 
68
  trainer.train()
 
 
 
 
 
69
 
70
  # モデルをHugging Face Hubにプッシュ
71
  trainer.push_to_hub()
72
-
73
  return f"モデルが'{repo_name}'リポジトリにデプロイされました!"
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  # Gradio UI
76
  with gr.Blocks() as demo:
77
  gr.Markdown("### pythia トレーニングとデプロイ")
@@ -79,8 +132,17 @@ with gr.Blocks() as demo:
79
  repo_input = gr.Textbox(label="リポジトリ名", placeholder="デプロイするリポジトリ名を入力してください...")
80
  license_input = gr.Textbox(label="ライセンス", placeholder="ライセンス情報を入力してください...")
81
  output = gr.Textbox(label="出力")
 
 
 
82
  train_button = gr.Button("デプロイ")
83
 
84
- train_button.click(fn=train_and_deploy, inputs=[token_input, repo_input, license_input], outputs=output)
 
 
 
 
 
 
85
 
86
  demo.launch()
 
2
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainerCallback, TrainingArguments
from datasets import load_dataset, Dataset, DatasetDict
import os
import time
6
+
7
# Module-level state describing the training run. The training code writes
# into it and the UI code reads from it:
#   status         - human-readable status line (starts as "待機中")
#   progress       - numeric progress value (starts at 0)
#   time_remaining - remaining-time estimate, or None when there is none
progress_info = dict(
    status="待機中",
    progress=0,
    time_remaining=None,
)
13
+
14
def update_progress(trainer, epoch, step, total_steps, time_remaining):
    """Record the current training position in the shared ``progress_info`` dict.

    Args:
        trainer: Object whose ``args.num_train_epochs`` is shown as the epoch total.
        epoch: Zero-based index of the current epoch.
        step: Zero-based index of the current step.
        total_steps: Total number of steps; may be 0 when nothing is scheduled.
        time_remaining: Remaining-time estimate, stored unchanged.
    """
    # The dict is mutated in place, never rebound, so no ``global`` statement
    # is needed here.
    progress_info["status"] = (
        f"エポック {epoch + 1} / {trainer.args.num_train_epochs}, "
        f"ステップ {step + 1} / {total_steps}"
    )
    # Guard against a run with zero scheduled steps, which previously raised
    # ZeroDivisionError.
    progress_info["progress"] = (step + 1) / total_steps if total_steps > 0 else 0
    progress_info["time_remaining"] = time_remaining
19
 
20
  def train_and_deploy(write_token, repo_name, license_text):
21
+ global progress_info
22
+ progress_info["status"] = "トレーニング開始"
23
+ progress_info["progress"] = 0
24
+ progress_info["time_remaining"] = None
25
+
26
  # トークンを環境変数に設定
27
  os.environ['HF_WRITE_TOKEN'] = write_token
28
 
 
81
  args=training_args,
82
  train_dataset=tokenized_datasets["train"],
83
  eval_dataset=tokenized_datasets["test"],
84
+ callbacks=[CustomCallback()]
85
  )
86
 
87
  # トレーニング実行
88
+ start_time = time.time()
89
  trainer.train()
90
+ end_time = time.time()
91
+ total_time = end_time - start_time
92
+ progress_info["status"] = f"トレーニング完了(所要時間: {total_time:.2f}秒)"
93
+ progress_info["progress"] = 1
94
+ progress_info["time_remaining"] = 0
95
 
96
  # モデルをHugging Face Hubにプッシュ
97
  trainer.push_to_hub()
98
+
99
  return f"モデルが'{repo_name}'リポジトリにデプロイされました!"
100
 
101
class CustomCallback(TrainerCallback):
    """Mirror Trainer progress into the module-level ``progress_info`` dict.

    ``on_train_begin`` resets the shared state and starts a wall-clock timer.
    ``on_step_begin`` and ``on_step_end`` refresh the status line and the
    progress fraction; ``on_step_end`` additionally estimates the remaining
    time from the average duration of the steps completed so far.
    """

    # Wall-clock time at which training started; set in on_train_begin.
    _start_time = None

    @staticmethod
    def _status_text(args, state, step_number):
        """Build the status line for the given 1-based step number."""
        # state.epoch is a fractional float and may be None before training
        # starts, so truncate it to get a whole epoch number for display.
        epoch_number = int(state.epoch or 0) + 1
        return (
            f"エポック {epoch_number} / {args.num_train_epochs}, "
            f"ステップ {step_number} / {state.max_steps}"
        )

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the shared progress state and start the timer."""
        # Timing is tracked here rather than read from state.log_history:
        # that list is empty until the first log event and nothing in this
        # file stores an "epoch_time" entry in it.
        self._start_time = time.time()
        progress_info["status"] = "トレーニング開始"
        progress_info["progress"] = 0
        progress_info["time_remaining"] = None

    def on_step_begin(self, args, state, control, **kwargs):
        """Show the step that is about to run; no time estimate yet."""
        total_steps = state.max_steps  # TrainerState has no num_train_steps
        completed = state.global_step  # update steps finished so far
        progress_info["status"] = self._status_text(args, state, completed + 1)
        progress_info["progress"] = completed / total_steps if total_steps > 0 else 0
        progress_info["time_remaining"] = None

    def on_step_end(self, args, state, control, **kwargs):
        """Update progress and estimate the remaining time after a step."""
        total_steps = state.max_steps
        # global_step already counts the step that just finished, so it is
        # used directly; adding 1 again would push progress past 1.0.
        completed = state.global_step
        progress_info["status"] = self._status_text(args, state, completed)

        if total_steps <= 0 or completed <= 0:
            # Nothing to extrapolate from.
            progress_info["progress"] = 0
            progress_info["time_remaining"] = None
            return

        if self._start_time is None:
            # on_train_begin was not observed; start timing from now.
            self._start_time = time.time()

        elapsed_time = time.time() - self._start_time
        time_per_step = elapsed_time / completed
        remaining_steps = max(total_steps - completed, 0)
        time_remaining = time_per_step * remaining_steps

        progress_info["progress"] = min(completed / total_steps, 1)
        progress_info["time_remaining"] = f"{time_remaining:.2f}秒"
127
+
128
  # Gradio UI
129
  with gr.Blocks() as demo:
130
  gr.Markdown("### pythia トレーニングとデプロイ")
 
132
  repo_input = gr.Textbox(label="リポジトリ名", placeholder="デプロイするリポジトリ名を入力してください...")
133
  license_input = gr.Textbox(label="ライセンス", placeholder="ライセンス情報を入力してください...")
134
  output = gr.Textbox(label="出力")
135
+ progress = gr.Progress(track_tqdm=True)
136
+ status = gr.Textbox(label="ステータス", value="待機中")
137
+ time_remaining = gr.Textbox(label="残り時間", value="待機中")
138
  train_button = gr.Button("デプロイ")
139
 
140
def _format_time_remaining(value):
    """Render a ``progress_info["time_remaining"]`` value for the textbox."""
    if value is None:
        return "待機中"
    if isinstance(value, str):
        # Already formatted with its unit; appending "秒" again would show
        # "…秒秒".
        return value
    # Numeric estimate, including a legitimate 0 once training has finished.
    return f"{value}秒"


def update_ui():
    """Return the current status text and remaining-time text.

    The two returned values are routed to the ``status`` and
    ``time_remaining`` textboxes through the ``outputs`` of the event below.
    Calling ``.update(...)`` on the components and discarding the result, as
    the previous version did, does not change what the browser shows.
    """
    return (
        progress_info["status"],
        _format_time_remaining(progress_info["time_remaining"]),
    )


# NOTE: this wiring must stay inside the ``with gr.Blocks() as demo:`` body.
# The chained step runs once, after train_and_deploy has returned, and pushes
# the final progress_info values into the two textboxes.
train_button.click(
    fn=train_and_deploy,
    inputs=[token_input, repo_input, license_input],
    outputs=output,
).then(fn=update_ui, outputs=[status, time_remaining])
147
 
148
  demo.launch()