import os
import json
import time
import math
import datetime
import pytz
import socket
import threading
import traceback
import altair as alt
import pandas as pd
import gradio as gr
from huggingface_hub import try_to_load_from_cache, snapshot_download
from transformers import TrainingArguments
from ...config import Config
from ...globals import Global
from ...models import clear_cache, unload_models
from ...utils.prompter import Prompter
from ...utils.sample_evenly import sample_evenly
from ..trainer_callback import (
UiTrainerCallback, reset_training_status,
update_training_states, set_train_output
)
from .data_processing import get_data_from_input
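# Training tab logic: do_train() validates inputs and launches fine-tuning on
# a background thread, while the render_* helpers turn the shared training
# state in Global into Gradio component updates for the UI.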
def status_message_callback(message):
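    """Show a trainer status message in the UI; returns True when training should stop."""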
if Global.should_stop_training:
return True
Global.training_status_text = message
def params_info_callback(all_params, trainable_params):
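    """Show the trainable vs. total parameter counts in the UI."""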
Global.training_params_info_text = f"Params: {trainable_params}/{all_params} ({100 * trainable_params / all_params:.4f}% trainable)"
def do_train(
# Dataset
template,
load_dataset_from,
dataset_from_data_dir,
dataset_text,
dataset_text_format,
dataset_plain_text_input_variables_separator,
dataset_plain_text_input_and_output_separator,
dataset_plain_text_data_separator,
# Training Options
max_seq_length,
evaluate_data_count,
micro_batch_size,
gradient_accumulation_steps,
epochs,
learning_rate,
train_on_inputs,
lora_r,
lora_alpha,
lora_dropout,
lora_target_modules,
lora_modules_to_save,
load_in_8bit,
fp16,
bf16,
gradient_checkpointing,
save_steps,
save_total_limit,
logging_steps,
additional_training_arguments,
additional_lora_config,
model_name,
continue_from_model,
continue_from_checkpoint,
progress=gr.Progress(track_tqdm=False),
):
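    """Validate the request and start fine-tuning on a background thread.

    Returns the initial (training status, loss plot) component updates so the
    UI can begin polling while training runs.
    """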
if Global.is_training or Global.is_train_starting:
return render_training_status() + render_loss_plot()
reset_training_status()
Global.is_train_starting = True
try:
base_model_name = Global.base_model_name
tokenizer_name = Global.tokenizer_name or Global.base_model_name
resume_from_checkpoint_param = None
        if continue_from_model in ("-", "None"):
            continue_from_model = None
        if continue_from_checkpoint in ("-", "None"):
            continue_from_checkpoint = None
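        # Resolve the adapter weights to resume from: prefer a model (or one of
        # its checkpoints) stored in the local data directory, falling back to
        # the Hugging Face cache (downloading a snapshot if necessary).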
if continue_from_model:
resume_from_model_path = os.path.join(
Config.data_dir, "lora_models", continue_from_model)
resume_from_checkpoint_param = resume_from_model_path
if continue_from_checkpoint:
resume_from_checkpoint_param = os.path.join(
resume_from_checkpoint_param, continue_from_checkpoint)
will_be_resume_from_checkpoint_file = os.path.join(
resume_from_checkpoint_param, "pytorch_model.bin")
if not os.path.exists(will_be_resume_from_checkpoint_file):
raise ValueError(
f"Unable to resume from checkpoint {continue_from_model}/{continue_from_checkpoint}. Resuming is only possible from checkpoints stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
else:
will_be_resume_from_checkpoint_file = os.path.join(
resume_from_checkpoint_param, "adapter_model.bin")
if not os.path.exists(will_be_resume_from_checkpoint_file):
                    # Try to get the model from the Hugging Face cache
resume_from_checkpoint_param = None
possible_hf_model_name = None
possible_model_info_file = os.path.join(
resume_from_model_path, "info.json")
if "/" in continue_from_model:
possible_hf_model_name = continue_from_model
elif os.path.exists(possible_model_info_file):
with open(possible_model_info_file, "r") as file:
model_info = json.load(file)
possible_hf_model_name = model_info.get(
"hf_model_name")
if possible_hf_model_name:
possible_hf_model_cached_path = try_to_load_from_cache(
possible_hf_model_name, 'adapter_model.bin')
if not possible_hf_model_cached_path:
snapshot_download(possible_hf_model_name)
possible_hf_model_cached_path = try_to_load_from_cache(
possible_hf_model_name, 'adapter_model.bin')
if possible_hf_model_cached_path:
resume_from_checkpoint_param = os.path.dirname(
possible_hf_model_cached_path)
if not resume_from_checkpoint_param:
raise ValueError(
f"Unable to continue from model {continue_from_model}. Continuation is only possible from models stored locally in the data directory. Please ensure that the file '{will_be_resume_from_checkpoint_file}' exists.")
output_dir = os.path.join(Config.data_dir, "lora_models", model_name)
if os.path.exists(output_dir):
if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
raise ValueError(
f"The output directory already exists and is not empty. ({output_dir})")
wandb_group = template
wandb_tags = [f"template:{template}"]
if load_dataset_from == "Data Dir" and dataset_from_data_dir:
wandb_group += f"/{dataset_from_data_dir}"
wandb_tags.append(f"dataset:{dataset_from_data_dir}")
finetune_args = {
'base_model': base_model_name,
'tokenizer': tokenizer_name,
'output_dir': output_dir,
'micro_batch_size': micro_batch_size,
'gradient_accumulation_steps': gradient_accumulation_steps,
'num_train_epochs': epochs,
'learning_rate': learning_rate,
'cutoff_len': max_seq_length,
'val_set_size': evaluate_data_count,
'lora_r': lora_r,
'lora_alpha': lora_alpha,
'lora_dropout': lora_dropout,
'lora_target_modules': lora_target_modules,
'lora_modules_to_save': lora_modules_to_save,
'train_on_inputs': train_on_inputs,
'load_in_8bit': load_in_8bit,
'fp16': fp16,
'bf16': bf16,
'gradient_checkpointing': gradient_checkpointing,
'group_by_length': False,
'resume_from_checkpoint': resume_from_checkpoint_param,
'save_steps': save_steps,
'save_total_limit': save_total_limit,
'logging_steps': logging_steps,
'additional_training_arguments': additional_training_arguments,
'additional_lora_config': additional_lora_config,
'wandb_api_key': Config.wandb_api_key,
'wandb_project': Config.default_wandb_project if Config.enable_wandb else None,
'wandb_group': wandb_group,
'wandb_run_name': model_name,
'wandb_tags': wandb_tags
}
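        # Build the prompter and load training data from the selected source.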
prompter = Prompter(template)
data = get_data_from_input(
load_dataset_from=load_dataset_from,
dataset_text=dataset_text,
dataset_text_format=dataset_text_format,
dataset_plain_text_input_variables_separator=dataset_plain_text_input_variables_separator,
dataset_plain_text_input_and_output_separator=dataset_plain_text_input_and_output_separator,
dataset_plain_text_data_separator=dataset_plain_text_data_separator,
dataset_from_data_dir=dataset_from_data_dir,
prompter=prompter
)
def training():
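            """Run the training job on a background thread, updating Global state as it goes."""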
Global.is_training = True
try:
                # Unload any loaded models and clear caches to free memory for training
unload_models()
Global.new_base_model_that_is_ready_to_be_used = None
Global.name_of_new_base_model_that_is_ready_to_be_used = None
clear_cache()
train_data = prompter.get_train_data_from_dataset(data)
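                # In UI dev mode, simulate a run with a synthetic, exponentially
                # decaying loss so the status and loss-plot components can be
                # exercised without doing any actual training.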
if Config.ui_dev_mode:
Global.training_args = TrainingArguments(
logging_steps=logging_steps, output_dir=""
)
message = "Currently in UI dev mode, not doing the actual training."
message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
print(message)
total_epochs = epochs
total_steps = len(train_data) * epochs
log_history = []
initial_loss = 2
loss_decay_rate = 0.8
for i in range(total_steps):
                        if Global.should_stop_training:
break
current_step = i + 1
current_epoch = i / (total_steps / total_epochs)
                        if current_step % logging_steps == 0:
                            loss = initial_loss * math.exp(
                                -loss_decay_rate * current_epoch)
log_history.append({
'loss': loss,
'learning_rate': 0.0001,
'epoch': current_epoch
})
update_training_states(
total_steps=total_steps,
current_step=current_step,
total_epochs=total_epochs,
current_epoch=current_epoch,
log_history=log_history
)
time.sleep(0.1)
result_message = set_train_output(message)
print(result_message)
time.sleep(1)
Global.is_training = False
return
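                # Real training: record run metadata next to the output, then
                # delegate to the fine-tuning function with the UI callbacks.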
training_callbacks = [UiTrainerCallback]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
dataset_name = "N/A (from text input)"
if load_dataset_from == "Data Dir":
dataset_name = dataset_from_data_dir
info = {
'base_model': base_model_name,
'prompt_template': template,
'dataset_name': dataset_name,
'dataset_rows': len(train_data),
'trained_on_machine': socket.gethostname(),
'timestamp': time.time(),
}
if continue_from_model:
info['continued_from_model'] = continue_from_model
if continue_from_checkpoint:
info['continued_from_checkpoint'] = continue_from_checkpoint
if Global.version:
info['tuner_version'] = Global.version
json.dump(info, info_json_file, indent=2)
train_output = Global.finetune_train_fn(
train_data=train_data,
callbacks=training_callbacks,
status_message_callback=status_message_callback,
params_info_callback=params_info_callback,
additional_wandb_config=info,
**finetune_args,
)
result_message = set_train_output(train_output)
print(result_message + "\n" + str(train_output))
clear_cache()
Global.is_training = False
except Exception as e:
traceback.print_exc()
Global.training_error_message = str(e)
finally:
Global.is_training = False
training_thread = threading.Thread(target=training)
training_thread.daemon = True
training_thread.start()
except Exception as e:
Global.is_training = False
traceback.print_exc()
Global.training_error_message = str(e)
finally:
Global.is_train_starting = False
return render_training_status() + render_loss_plot()
def render_training_status():
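    """Return a pair of gr.HTML updates reflecting the current training state."""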
if not Global.is_training:
if Global.is_train_starting:
html_content = """
<div class="progress-block">
<div class="progress-level">
<div class="progress-level-inner">
Starting...
</div>
</div>
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
if Global.training_error_message:
html_content = f"""
<div class="progress-block is_error">
<div class="progress-level">
<div class="error">
<div class="title">
⚠ Something went wrong
</div>
<div class="error-message">{Global.training_error_message}</div>
</div>
</div>
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
if Global.train_output_str:
end_message = "✅ Training completed"
if Global.should_stop_training:
end_message = "🛑 Train aborted"
params_info_html = ""
if Global.training_params_info_text:
params_info_html = f"""
<div class="params-info">
{Global.training_params_info_text}
</div>
"""
html_content = f"""
<div class="progress-block">
<div class="progress-level">
<div class="output">
<div class="title">
{end_message}
</div>
<div class="message">{Global.train_output_str}</div>
</div>
</div>
{params_info_html}
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
if Global.training_status_text:
html_content = f"""
<div class="progress-block">
<div class="status">{Global.training_status_text}</div>
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
html_content = """
<div class="progress-block">
<div class="empty-text">
Training status will be shown here
</div>
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
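    # Training is in progress: show step counts, elapsed/remaining time, the
    # ETA and a progress bar.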
meta_info = []
meta_info.append(
f"{Global.training_current_step}/{Global.training_total_steps} steps")
current_time = time.time()
time_elapsed = current_time - Global.train_started_at
time_remaining = -1
if Global.training_eta:
time_remaining = Global.training_eta - current_time
if time_remaining >= 0:
meta_info.append(
f"{format_time(time_elapsed)}<{format_time(time_remaining)}")
meta_info.append(f"ETA: {format_timestamp(Global.training_eta)}")
else:
meta_info.append(format_time(time_elapsed))
params_info_html = ""
if Global.training_params_info_text:
params_info_html = f"""
<div class="params-info">
{Global.training_params_info_text}
</div>
"""
html_content = f"""
<div class="progress-block is_training">
<div class="meta-text">{' | '.join(meta_info)}</div>
<div class="progress-level">
<div class="progress-level-inner">
{Global.training_status_text} - {Global.training_progress * 100:.2f}%
</div>
<div class="progress-bar-wrap">
<div class="progress-bar" style="width: {Global.training_progress * 100:.2f}%;">
</div>
</div>
</div>
{params_info_html}
</div>
"""
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
def render_loss_plot():
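    """Plot train/eval loss with Altair, using steps on the x axis when they can be derived."""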
if len(Global.training_log_history) <= 2:
return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
max_elements = 5000
training_log_history = sample_evenly(
Global.training_log_history, max_elements=max_elements)
logging_steps = Global.training_args and Global.training_args.logging_steps
loss_data = [
{
'type': 'train_loss' if 'loss' in item else 'eval_loss',
'loss': item.get('loss') or item.get('eval_loss'),
'epoch': item.get('epoch')
} for item in training_log_history
if ('loss' in item or 'eval_loss' in item)
and 'epoch' in item
]
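    # Step numbers are reconstructed from logging_steps, which is only reliable
    # when the full (non-downsampled) log history is being plotted.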
use_steps = False
if len(Global.training_log_history) <= max_elements and logging_steps:
for index, item in enumerate(loss_data):
item["step"] = index * logging_steps
use_steps = True
source = pd.DataFrame(loss_data)
highlight = alt.selection(
type='single', # type: ignore
on='mouseover', fields=['type'], nearest=True
)
if use_steps:
base = alt.Chart(source).encode( # type: ignore
x='step:Q',
y='loss:Q',
color='type:N',
tooltip=['type:N', 'loss:Q', 'step:Q', 'epoch:Q']
)
else:
base = alt.Chart(source).encode( # type: ignore
x='epoch:Q',
y='loss:Q',
color='type:N',
tooltip=['type:N', 'loss:Q', 'epoch:Q']
)
points = base.mark_circle().encode(
opacity=alt.value(0)
).add_selection(
highlight
).properties(
width=640
)
lines = base.mark_line().encode(
size=alt.condition(~highlight, alt.value(1), alt.value(3))
)
return (gr.Column.update(visible=True), gr.Plot.update(points + lines, visible=True))
def format_time(seconds):
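    """Format a duration in seconds as MM:SS, or HH:MM:SS once it reaches an hour.

    For example, format_time(75) returns "01:15" and format_time(3700)
    returns "01:01:40".
    """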
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
if hours == 0:
return "{:02d}:{:02d}".format(int(minutes), int(seconds))
else:
return "{:02d}:{:02d}:{:02d}".format(int(hours), int(minutes), int(seconds))
def format_timestamp(timestamp):
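    """Format a Unix timestamp in the configured timezone, e.g. 'tomorrow at 05:30 pm'."""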
dt_naive = datetime.datetime.utcfromtimestamp(timestamp)
utc = pytz.UTC
timezone = Config.timezone
dt_aware = utc.localize(dt_naive).astimezone(timezone)
now = datetime.datetime.now(timezone)
delta = dt_aware.date() - now.date()
if delta.days == 0:
time_str = ""
elif delta.days == 1:
time_str = "tomorrow at "
elif delta.days == -1:
time_str = "yesterday at "
else:
time_str = dt_aware.strftime('%A, %B %d at ')
time_str += dt_aware.strftime('%I:%M %p').lower()
return time_str