zetavg committed
Commit 6947876 • Parent(s): b606ad0

finetune loss chart: use steps as the x axis if possible
llama_lora/globals.py
CHANGED
@@ -5,7 +5,7 @@ import psutil
 import math
 
 from typing import Any, Dict, List, Optional, Tuple, Union
-
+from transformers import TrainingArguments
 from numba import cuda
 import nvidia_smi
 
@@ -47,6 +47,7 @@ class Global:
     training_status_text: str = ""
     training_eta_predictor = ETAPredictor()
     training_eta: Union[int, None] = None
+    training_args: Union[TrainingArguments, None] = None
     train_output: Union[None, Any] = None
     train_output_str: Union[None, str] = None
     training_params_info_text: str = ""
llama_lora/ui/finetune/training.py
CHANGED
@@ -12,11 +12,13 @@ import pandas as pd
 import gradio as gr
 
 from huggingface_hub import try_to_load_from_cache, snapshot_download
+from transformers import TrainingArguments
 
 from ...config import Config
 from ...globals import Global
 from ...models import clear_cache, unload_models
 from ...utils.prompter import Prompter
+from ...utils.sample_evenly import sample_evenly
 from ..trainer_callback import (
     UiTrainerCallback, reset_training_status,
     update_training_states, set_train_output
@@ -202,26 +204,31 @@ def do_train(
     train_data = prompter.get_train_data_from_dataset(data)
 
     if Config.ui_dev_mode:
+        Global.training_args = TrainingArguments(
+            logging_steps=logging_steps, output_dir=""
+        )
+
         message = "Currently in UI dev mode, not doing the actual training."
         message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
         message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
 
         print(message)
 
-
+        total_epochs = epochs
+        total_steps = len(train_data) * epochs
         log_history = []
        initial_loss = 2
         loss_decay_rate = 0.8
-        for i in range(
+        for i in range(total_steps):
             if (Global.should_stop_training):
                 break
 
             current_step = i + 1
-
-            current_epoch = i / 100
+            current_epoch = i / (total_steps / total_epochs)
 
-            if (
-                loss = initial_loss *
+            if (current_step % logging_steps == 0):
+                loss = initial_loss * \
+                    math.exp(-loss_decay_rate * current_epoch)
                 log_history.append({
                     'loss': loss,
                     'learning_rate': 0.0001,
@@ -424,7 +431,10 @@ def render_loss_plot():
     if len(Global.training_log_history) <= 2:
         return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
 
-
+    max_elements = 5000
+    training_log_history = sample_evenly(
+        Global.training_log_history, max_elements=max_elements)
+    logging_steps = Global.training_args and Global.training_args.logging_steps
 
     loss_data = [
         {
@@ -436,6 +446,12 @@
         and 'epoch' in item
     ]
 
+    use_steps = False
+    if len(Global.training_log_history) <= max_elements and logging_steps:
+        for index, item in enumerate(loss_data):
+            item["step"] = index * logging_steps
+        use_steps = True
+
     source = pd.DataFrame(loss_data)
 
     highlight = alt.selection(
@@ -443,12 +459,20 @@
         on='mouseover', fields=['type'], nearest=True
     )
 
-
-
-
-
-
-
+    if use_steps:
+        base = alt.Chart(source).encode(  # type: ignore
+            x='step:Q',
+            y='loss:Q',
+            color='type:N',
+            tooltip=['type:N', 'loss:Q', 'step:Q', 'epoch:Q']
+        )
+    else:
+        base = alt.Chart(source).encode(  # type: ignore
+            x='epoch:Q',
+            y='loss:Q',
+            color='type:N',
+            tooltip=['type:N', 'loss:Q', 'epoch:Q']
+        )
 
     points = base.mark_circle().encode(
         opacity=alt.value(0)
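In the dev-mode loop above, each logged point is the i-th logging event, so its global step can be recovered as index * logging_steps; render_loss_plot only does this when the history was not downsampled and logging_steps is known, otherwise it keeps plotting against the epoch. A minimal, self-contained sketch of that axis selection, using synthetic data and illustrative names rather than code from this repo:

import math

import altair as alt
import pandas as pd

# Synthetic log history: one loss point per logging event, mimicking what the
# dev-mode loop above produces with logging_steps=10 over 300 total steps.
logging_steps = 10
loss_data = [
    {'type': 'loss',
     'loss': 2.0 * math.exp(-0.8 * (i * logging_steps) / 300),
     'epoch': (i * logging_steps) / 300}
    for i in range(300 // logging_steps)
]

# Prefer steps on the x axis when every log entry is present and the logging
# interval is known; otherwise fall back to epochs.
use_steps = bool(logging_steps)
if use_steps:
    for index, item in enumerate(loss_data):
        item['step'] = index * logging_steps

source = pd.DataFrame(loss_data)
x_field = 'step:Q' if use_steps else 'epoch:Q'
chart = alt.Chart(source).mark_line().encode(x=x_field, y='loss:Q', color='type:N')
chart.save('loss_chart.html')  # writes a standalone HTML page with the chart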
llama_lora/ui/trainer_callback.py
CHANGED
@@ -22,6 +22,7 @@ def reset_training_status():
     Global.training_status_text = ""
     Global.training_eta_predictor = ETAPredictor()
     Global.training_eta = None
+    Global.training_args = None
     Global.train_output = None
     Global.train_output_str = None
     Global.training_params_info_text = ""
@@ -102,6 +103,7 @@ class UiTrainerCallback(TrainerCallback):
         traceback.print_exc()
 
     def on_epoch_begin(self, args, state, control, **kwargs):
+        Global.training_args = args
         self._on_progress(args, state, control)
 
     def on_step_end(self, args, state, control, **kwargs):
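This is the piece that makes Global.training_args available during a real run: transformers invokes on_epoch_begin with the live TrainingArguments, and render_loss_plot later reads logging_steps from it. A standalone illustration of the same capture pattern, as a simplified stand-in rather than the repo's full UiTrainerCallback:

from transformers import TrainerCallback


class CaptureArgsCallback(TrainerCallback):
    """Keeps a reference to the live TrainingArguments so other code
    (e.g. a UI thread) can read settings such as logging_steps mid-run."""

    def __init__(self):
        self.training_args = None

    def on_epoch_begin(self, args, state, control, **kwargs):
        # `args` is the TrainingArguments the Trainer was constructed with.
        self.training_args = args


# Usage sketch: pass it to the Trainer alongside the usual arguments, e.g.
# Trainer(model=model, args=training_args, callbacks=[CaptureArgsCallback()])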
llama_lora/utils/sample_evenly.py
ADDED
@@ -0,0 +1,15 @@
+import numpy as np
+from typing import List, Any, Iterator
+
+
+def sample_evenly_it(input_list: List[Any], max_elements: int = 1000) -> Iterator[Any]:
+    if len(input_list) <= max_elements:
+        yield from input_list
+    else:
+        step = len(input_list) / max_elements
+        indices = np.arange(0, len(input_list), step).astype(int)
+        yield from (input_list[i] for i in indices)
+
+
+def sample_evenly(input_list: List[Any], max_elements: int = 1000) -> List[Any]:
+    return list(sample_evenly_it(input_list, max_elements))