Sakalti commited on
Commit
1bf4dfb
verified
1 Parent(s): 1e482da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
  from datasets import load_dataset, Dataset, DatasetDict
4
  import os
5
  import time
@@ -107,7 +107,7 @@ class CustomCallback(TrainerCallback):
107
 
108
  def on_step_begin(self, args, state, control, **kwargs):
109
  global progress_info
110
- total_steps = state.num_train_steps
111
  current_step = state.global_step
112
  progress_info["status"] = f"銈ㄣ儩銉冦偗 {state.epoch + 1} / {args.num_train_epochs}, 銈广儐銉冦儣 {current_step + 1} / {total_steps}"
113
  progress_info["progress"] = (current_step + 1) / total_steps
@@ -115,7 +115,7 @@ class CustomCallback(TrainerCallback):
115
 
116
  def on_step_end(self, args, state, control, **kwargs):
117
  global progress_info
118
- total_steps = state.num_train_steps
119
  current_step = state.global_step
120
  elapsed_time = time.time() - state.log_history[0]["epoch_time"]
121
  time_per_step = elapsed_time / (current_step + 1)
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
3
  from datasets import load_dataset, Dataset, DatasetDict
4
  import os
5
  import time
 
107
 
108
  def on_step_begin(self, args, state, control, **kwargs):
109
  global progress_info
110
+ total_steps = state.max_steps
111
  current_step = state.global_step
112
  progress_info["status"] = f"銈ㄣ儩銉冦偗 {state.epoch + 1} / {args.num_train_epochs}, 銈广儐銉冦儣 {current_step + 1} / {total_steps}"
113
  progress_info["progress"] = (current_step + 1) / total_steps
 
115
 
116
  def on_step_end(self, args, state, control, **kwargs):
117
  global progress_info
118
+ total_steps = state.max_steps
119
  current_step = state.global_step
120
  elapsed_time = time.time() - state.log_history[0]["epoch_time"]
121
  time_per_step = elapsed_time / (current_step + 1)