zetavg
show loss/epoch chart on finetune ui
3daa16f unverified
raw
history blame
36.8 kB
import os
import json
from datetime import datetime
import gradio as gr
from random_word import RandomWords
from ...config import Config
from ...globals import Global
from ...utils.data import (
get_available_template_names,
get_available_dataset_names,
get_available_lora_model_names
)
from ...utils.relative_read_file import relative_read_file
from ..css_styles import register_css_style
from .values import (
default_dataset_plain_text_input_variables_separator,
default_dataset_plain_text_input_and_output_separator,
default_dataset_plain_text_data_separator,
sample_plain_text_value,
sample_jsonl_text_value,
sample_json_text_value,
)
from .previewing import (
refresh_preview,
refresh_dataset_items_count,
)
from .training import (
do_train,
render_training_status,
render_loss_plot
)
# Register the stylesheet that lives next to this module (style.css) under
# the "finetune" key.
register_css_style(
    'finetune',
    relative_read_file(__file__, "style.css"),
)
def random_hyphenated_word():
    """Return two random words joined by a single hyphen (e.g. ``"foo-bar"``).

    Both words come from ``RandomWords().get_random_word()``, drawn from one
    generator instance.
    """
    generator = RandomWords()
    first = generator.get_random_word()
    second = generator.get_random_word()
    return '-'.join((first, second))
def random_name():
    """Build a default name for a new LoRA model.

    Format: ``<word>-<word>-YYYY-MM-DD-HH-MM-SS``. The timestamp is the local
    time (naive ``datetime.now()``) at the moment of the call.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    words = random_hyphenated_word()
    return f"{words}-{timestamp}"
def reload_selections(current_template, current_dataset):
    """Re-read the available templates, datasets and LoRA models.

    A current selection is kept only if it is non-empty and still available.
    Otherwise the first available option is selected. The template list
    always ends with a "None" entry, so a template value is always chosen.
    The dataset value becomes None when no datasets exist.

    Returns:
        ``gr.Dropdown.update`` objects for the template dropdown, the dataset
        dropdown and the continue-from-model dropdown. The last one only gets
        new choices, prefixed by "-".
    """
    template_choices = get_available_template_names() + ["None"]
    if current_template not in template_choices or not current_template:
        current_template = template_choices[0]

    dataset_choices = get_available_dataset_names()
    if current_dataset not in dataset_choices or not current_dataset:
        current_dataset = next(iter(dataset_choices), None)

    model_choices = ["-"] + get_available_lora_model_names()

    template_update = gr.Dropdown.update(
        choices=template_choices, value=current_template)
    dataset_update = gr.Dropdown.update(
        choices=dataset_choices, value=current_dataset)
    model_update = gr.Dropdown.update(choices=model_choices)
    return template_update, dataset_update, model_update
def handle_switch_dataset_source(source):
    """Show exactly one of the two dataset input groups.

    Returns visibility updates for (text-input group, data-dir group). The
    text-input group is visible only when ``source`` is "Text Input"; for any
    other value the data-dir group is visible instead.
    """
    use_text_input = source == "Text Input"
    return (
        gr.Column.update(visible=use_text_input),
        gr.Column.update(visible=not use_text_input),
    )
def handle_switch_dataset_text_format(format):
    """Show the plain-text separator options only for the "Plain Text" format.

    ``format`` is the selected dataset text format. The name shadows the
    builtin but is kept for interface compatibility.
    """
    is_plain_text = format == "Plain Text"
    return gr.Column.update(visible=is_plain_text)
def load_sample_dataset_to_text_input(format):
    """Return a code-editor update holding the sample dataset for ``format``.

    "JSON" and "JSON Lines" get their matching samples. Any other value falls
    back to the plain-text sample.
    """
    if format == "JSON":
        sample = sample_json_text_value
    elif format == "JSON Lines":
        sample = sample_jsonl_text_value
    else:
        sample = sample_plain_text_value
    return gr.Code.update(value=sample)
def _checkpoint_sort_key(name):
    """Sort key for ``checkpoint-<step>`` names: ascending step, non-numeric last."""
    step = name[len("checkpoint-"):]
    # isdecimal (not isdigit) so that int() is guaranteed to accept it.
    if step.isdecimal():
        return (0, int(step), name)
    return (1, 0, name)


def handle_continue_from_model_change(model_name):
    """Refresh the checkpoint dropdown after a "Continue from Model" change.

    Lists the ``checkpoint-*`` entries of
    ``<Config.data_dir>/lora_models/<model_name>``, sorted by step number. The
    "Load training parameters" button is shown only if that directory
    contains ``finetune_params.json`` or ``finetune_args.json``.

    If the directory cannot be listed, the dropdown is reset to just "-" and
    the button is hidden. This covers "-", None, a typed-in unknown name, or
    an unreadable directory. The load-params message is always cleared and
    hidden.

    Returns:
        (checkpoint dropdown update, load-params button update,
         load-params message update)
    """
    try:
        lora_model_directory_path = os.path.join(
            Config.data_dir, "lora_models", model_name)
        all_files = os.listdir(lora_model_directory_path)
    except (OSError, TypeError, ValueError):
        # Deliberate best-effort: any of these just means "nothing to
        # continue from". OSError covers missing, not-a-dir and permission
        # cases. TypeError covers a None name. ValueError covers a name with
        # an embedded NUL.
        return (gr.Dropdown.update(choices=["-"], value="-"),
                gr.Button.update(visible=False),
                gr.Markdown.update(value="", visible=False))

    # os.listdir() returns entries in arbitrary order; list checkpoints by
    # ascending step number instead.
    checkpoints = sorted(
        (entry for entry in all_files if entry.startswith("checkpoint-")),
        key=_checkpoint_sort_key)
    can_load_params = (
        "finetune_params.json" in all_files
        or "finetune_args.json" in all_files)
    return (gr.Dropdown.update(choices=["-"] + checkpoints, value="-"),
            gr.Button.update(visible=can_load_params),
            gr.Markdown.update(value="", visible=False))
def handle_load_params_from_model(
    model_name,
    template, load_dataset_from, dataset_from_data_dir,
    max_seq_length,
    evaluate_data_count,
    micro_batch_size,
    gradient_accumulation_steps,
    epochs,
    learning_rate,
    train_on_inputs,
    lora_r,
    lora_alpha,
    lora_dropout,
    lora_target_modules,
    lora_modules_to_save,
    load_in_8bit,
    fp16,
    bf16,
    gradient_checkpointing,
    save_steps,
    save_total_limit,
    logging_steps,
    additional_training_arguments,
    additional_lora_config,
    lora_target_module_choices,
    lora_modules_to_save_choices,
):
    """Restore fine-tuning form values from a previously trained LoRA model.

    Reads from ``<Config.data_dir>/lora_models/<model_name>``:

    * ``info.json`` (optional): ``prompt_template`` and ``dataset_name``.
    * The first existing file of ``finetune_params.json`` /
      ``finetune_args.json``: saved training parameters.

    Recognised values replace the current form values passed in. Legacy
    names ``cutoff_len``, ``val_set_size`` and ``num_train_epochs`` map to
    ``max_seq_length``, ``evaluate_data_count`` and ``epochs``. Unrecognised
    keys are listed in a notice. Any exception while reading is reported in
    the same message instead of being raised; values restored before the
    error are still returned.

    Every parameter except ``model_name`` is the current value of the
    matching UI component. It is returned unchanged unless a saved value
    replaces it. The two ``*_choices`` lists are extended with any saved
    module names they lack.

    Returns:
        A tuple matching the event outputs, in this order:
        a ``gr.Markdown`` update for the message; template, dataset source
        and dataset; the training args, with the two module selections
        wrapped in ``gr.CheckboxGroup`` updates; and the two updated choice
        lists.
    """
    error_message = ""
    notice_message = ""
    unknown_keys = []

    # Work on copies so the gr.State lists passed in are never mutated in
    # place; the updated lists are returned as outputs instead.
    lora_target_module_choices = list(lora_target_module_choices or [])
    lora_modules_to_save_choices = list(lora_modules_to_save_choices or [])

    try:
        lora_model_directory_path = os.path.join(
            Config.data_dir, "lora_models", model_name)

        # info.json is optional metadata about how the model was trained.
        try:
            with open(os.path.join(lora_model_directory_path, "info.json"),
                      "r", encoding="utf-8") as f:
                info = json.load(f)
            if isinstance(info, dict):
                model_prompt_template = info.get("prompt_template")
                if model_prompt_template:
                    template = model_prompt_template
                model_dataset_name = info.get("dataset_name")
                # Names starting with "N/A" are skipped (presumably a
                # placeholder for datasets not loaded from the data dir).
                if (isinstance(model_dataset_name, str)
                        and model_dataset_name
                        and not model_dataset_name.startswith("N/A")):
                    load_dataset_from = "Data Dir"
                    dataset_from_data_dir = model_dataset_name
        except FileNotFoundError:
            pass

        # Saved training parameters. Prefer finetune_params.json and fall
        # back to finetune_args.json; only the first file found is used, so
        # the fallback can no longer override the preferred file.
        data = {}
        for file_name in ("finetune_params.json", "finetune_args.json"):
            try:
                with open(os.path.join(lora_model_directory_path, file_name),
                          "r", encoding="utf-8") as f:
                    data = json.load(f)
            except FileNotFoundError:
                continue
            break

        for key, value in data.items():
            # One elif chain: every key lands in exactly one branch. A stray
            # `if` here used to report "max_seq_length" as unknown.
            if key in ("max_seq_length", "cutoff_len"):
                max_seq_length = value
            elif key in ("evaluate_data_count", "val_set_size"):
                evaluate_data_count = value
            elif key == "micro_batch_size":
                micro_batch_size = value
            elif key == "gradient_accumulation_steps":
                gradient_accumulation_steps = value
            elif key in ("epochs", "num_train_epochs"):
                epochs = value
            elif key == "learning_rate":
                learning_rate = value
            elif key == "train_on_inputs":
                train_on_inputs = value
            elif key == "lora_r":
                lora_r = value
            elif key == "lora_alpha":
                lora_alpha = value
            elif key == "lora_dropout":
                lora_dropout = value
            elif key == "lora_target_modules":
                lora_target_modules = value
                # Make sure every saved module is selectable in the UI.
                for element in value or []:
                    if element not in lora_target_module_choices:
                        lora_target_module_choices.append(element)
            elif key == "lora_modules_to_save":
                lora_modules_to_save = value
                for element in value or []:
                    if element not in lora_modules_to_save_choices:
                        lora_modules_to_save_choices.append(element)
            elif key == "load_in_8bit":
                load_in_8bit = value
            elif key == "fp16":
                fp16 = value
            elif key == "bf16":
                bf16 = value
            elif key == "gradient_checkpointing":
                gradient_checkpointing = value
            elif key == "save_steps":
                save_steps = value
            elif key == "save_total_limit":
                save_total_limit = value
            elif key == "logging_steps":
                logging_steps = value
            elif key == "additional_training_arguments":
                # The UI edits these as JSON text.
                additional_training_arguments = (
                    json.dumps(value, indent=2) if value else "")
            elif key == "additional_lora_config":
                additional_lora_config = (
                    json.dumps(value, indent=2) if value else "")
            elif key in ("group_by_length", "resume_from_checkpoint"):
                # Known keys that this handler has no input for; ignored
                # without a notice.
                pass
            else:
                unknown_keys.append(key)
    except Exception as e:
        # Best-effort restore: report the problem in the UI instead of
        # raising.
        error_message = str(e)

    if unknown_keys:
        notice_message = f"Note: cannot restore unknown arg: {', '.join([f'`{x}`' for x in unknown_keys])}"

    message = ". ".join([x for x in [error_message, notice_message] if x])
    has_message = bool(message)
    if has_message:
        message += "."

    return (
        gr.Markdown.update(value=message, visible=has_message),
        template, load_dataset_from, dataset_from_data_dir,
        max_seq_length,
        evaluate_data_count,
        micro_batch_size,
        gradient_accumulation_steps,
        epochs,
        learning_rate,
        train_on_inputs,
        lora_r,
        lora_alpha,
        lora_dropout,
        gr.CheckboxGroup.update(value=lora_target_modules,
                                choices=lora_target_module_choices),
        gr.CheckboxGroup.update(
            value=lora_modules_to_save, choices=lora_modules_to_save_choices),
        load_in_8bit,
        fp16,
        bf16,
        gradient_checkpointing,
        save_steps,
        save_total_limit,
        logging_steps,
        additional_training_arguments,
        additional_lora_config,
        lora_target_module_choices,
        lora_modules_to_save_choices
    )
# Initial choices for the "LoRA Target Modules" checkbox group. They also
# seed a gr.State; more entries can be added from the UI at runtime.
default_lora_target_module_choices = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
]
# Initial choices for the "LoRA Modules To Save" checkbox group (also used to
# seed a gr.State).
default_lora_modules_to_save_choices = [
    "model.embed_tokens",
    "lm_head",
]
def _add_module_choice(choices, new_module, selected_modules):
    """Return new (choices, selected) lists with ``new_module`` added to both.

    The input lists are not mutated. Surrounding whitespace is stripped from
    ``new_module``. An empty name leaves both lists unchanged. A name that is
    already present is not duplicated, but is still selected.
    """
    choices = list(choices or [])
    selected_modules = list(selected_modules or [])
    new_module = (new_module or "").strip()
    if new_module:
        if new_module not in choices:
            choices.append(new_module)
        if new_module not in selected_modules:
            selected_modules.append(new_module)
    return choices, selected_modules


def handle_lora_target_modules_add(choices, new_module, selected_modules):
    """Add a custom module name to the "LoRA Target Modules" checkbox group.

    Returns:
        (updated choice list for the gr.State, "" to clear the textbox,
         CheckboxGroup update with the new choices and selection)
    """
    choices, selected_modules = _add_module_choice(
        choices, new_module, selected_modules)
    return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))


def handle_lora_modules_to_save_add(choices, new_module, selected_modules):
    """Add a custom module name to the "LoRA Modules To Save" checkbox group.

    Returns:
        (updated choice list for the gr.State, "" to clear the textbox,
         CheckboxGroup update with the new choices and selection)
    """
    choices, selected_modules = _add_module_choice(
        choices, new_module, selected_modules)
    return (choices, "", gr.CheckboxGroup.update(value=selected_modules, choices=choices))
def do_abort_training():
    """Request that the running training job stop.

    Sets ``Global.should_stop_training`` and changes the status text to
    "Aborting...". NOTE(review): the flag is presumably polled by the
    training code (not visible here). The write order is kept as-is because
    another thread may read these fields.
    """
    Global.should_stop_training = True
    Global.training_status_text = "Aborting..."
def finetune_ui():
    """Build the fine-tuning UI and wire up its event handlers.

    Creates a nested ``gr.Blocks`` containing:

    * the dataset "Prepare" and "Preview" tabs;
    * the training-parameter controls;
    * the Train / Abort buttons.

    It then registers event listeners for dataset preview refresh, loading
    parameters from a saved model, training start/abort, and periodic
    status and loss-plot polling. Returns nothing; the UI is built for its
    side effects, presumably inside the app's top-level Blocks context (the
    caller is not visible here).

    NOTE(review): the copy this was reviewed from had lost all indentation,
    so the ``with`` nesting below is a reconstruction. Check it against
    upstream, since nesting determines the rendered layout.
    """
    # Event listeners that might hang. The "stop not-responding elements"
    # button created at the end cancels all of them.
    things_that_might_timeout = []

    with gr.Blocks() as finetune_ui_blocks:
        with gr.Column(elem_id="finetune_ui_content"):
            with gr.Tab("Prepare"):
                # Template + where the dataset comes from.
                with gr.Box(elem_id="finetune_ui_select_dataset_source"):
                    with gr.Row():
                        template = gr.Dropdown(
                            label="Template",
                            elem_id="finetune_template",
                        )
                        load_dataset_from = gr.Radio(
                            ["Text Input", "Data Dir"],
                            label="Load Dataset From",
                            value="Text Input",
                            elem_id="finetune_load_dataset_from")
                        reload_selections_button = gr.Button(
                            "↻",
                            elem_id="finetune_reload_selections_button"
                        )
                        reload_selections_button.style(
                            full_width=False,
                            size="sm")
                    # Shown only when "Data Dir" is selected (initially hidden).
                    with gr.Column(
                        elem_id="finetune_dataset_from_data_dir_group",
                        visible=False
                    ) as dataset_from_data_dir_group:
                        dataset_from_data_dir = gr.Dropdown(
                            label="Dataset",
                            elem_id="finetune_dataset_from_data_dir",
                        )
                        dataset_from_data_dir_message = gr.Markdown(
                            "",
                            visible=False,
                            elem_id="finetune_dataset_from_data_dir_message")
                # Shown when "Text Input" is selected (the initial value).
                with gr.Box(elem_id="finetune_dataset_text_input_group") as dataset_text_input_group:
                    # A label-only textbox used as a heading for the code editor.
                    gr.Textbox(
                        label="Training Data", elem_classes="textbox_that_is_only_used_to_display_a_label")
                    dataset_text = gr.Code(
                        show_label=False,
                        language="json",
                        value=sample_plain_text_value,
                        # max_lines=40,
                        elem_id="finetune_dataset_text_input_textbox")
                    dataset_from_text_message = gr.Markdown(
                        "",
                        visible=False,
                        elem_id="finetune_dataset_from_text_message")
                    gr.Markdown(
                        "The data you entered here will not be saved. Do not make edits here directly. Instead, edit the data elsewhere then paste it here.")
                    with gr.Row():
                        with gr.Column():
                            dataset_text_format = gr.Radio(
                                ["Plain Text", "JSON Lines", "JSON"],
                                label="Format", value="Plain Text", elem_id="finetune_dataset_text_format")
                            dataset_text_load_sample_button = gr.Button(
                                "Load Sample", elem_id="finetune_dataset_text_load_sample_button")
                            dataset_text_load_sample_button.style(
                                full_width=False,
                                size="sm")
                        # Separator settings; only relevant to the "Plain Text" format.
                        with gr.Column(elem_id="finetune_dataset_plain_text_separators_group") as dataset_plain_text_separators_group:
                            dataset_plain_text_input_variables_separator = gr.Textbox(
                                label="Input Variables Separator",
                                elem_id="dataset_plain_text_input_variables_separator",
                                placeholder=default_dataset_plain_text_input_variables_separator,
                                value=default_dataset_plain_text_input_variables_separator)
                            dataset_plain_text_input_and_output_separator = gr.Textbox(
                                label="Input and Output Separator",
                                elem_id="dataset_plain_text_input_and_output_separator",
                                placeholder=default_dataset_plain_text_input_and_output_separator,
                                value=default_dataset_plain_text_input_and_output_separator)
                            dataset_plain_text_data_separator = gr.Textbox(
                                label="Data Separator",
                                elem_id="dataset_plain_text_data_separator",
                                placeholder=default_dataset_plain_text_data_separator,
                                value=default_dataset_plain_text_data_separator)
                    # Toggle the separators group when the text format changes.
                    things_that_might_timeout.append(
                        dataset_text_format.change(
                            fn=handle_switch_dataset_text_format,
                            inputs=[dataset_text_format],
                            outputs=[
                                dataset_plain_text_separators_group # type: ignore
                            ]
                        ))
                    # Replace the editor contents with the sample for the current format.
                    things_that_might_timeout.append(
                        dataset_text_load_sample_button.click(fn=load_sample_dataset_to_text_input, inputs=[
                            dataset_text_format], outputs=[dataset_text]))
                gr.Markdown(
                    "💡 Switch to the \"Preview\" tab to verify that your inputs are correct.")
            with gr.Tab("Preview"):
                with gr.Row():
                    finetune_dataset_preview_info_message = gr.Markdown(
                        "Set the dataset in the \"Prepare\" tab, then preview it here.",
                        elem_id="finetune_dataset_preview_info_message"
                    )
                    finetune_dataset_preview_count = gr.Number(
                        label="Preview items count",
                        value=10,
                        # minimum=1,
                        # maximum=100,
                        precision=0,
                        elem_id="finetune_dataset_preview_count"
                    )
                finetune_dataset_preview = gr.Dataframe(
                    wrap=True, elem_id="finetune_dataset_preview")
        # Show either the text-input group or the data-dir group.
        things_that_might_timeout.append(
            load_dataset_from.change(
                fn=handle_switch_dataset_source,
                inputs=[load_dataset_from],
                outputs=[
                    dataset_text_input_group,
                    dataset_from_data_dir_group
                ] # type: ignore
            ))

        # Components that together describe the dataset; passed to both the
        # preview handlers and do_train.
        dataset_inputs = [
            template,
            load_dataset_from,
            dataset_from_data_dir,
            dataset_text,
            dataset_text_format,
            dataset_plain_text_input_variables_separator,
            dataset_plain_text_input_and_output_separator,
            dataset_plain_text_data_separator,
        ]
        dataset_preview_inputs = dataset_inputs + \
            [finetune_dataset_preview_count]

        with gr.Row():
            max_seq_length = gr.Slider(
                minimum=1, maximum=4096, value=512,
                label="Max Sequence Length",
                info="The maximum length of each sample text sequence. Sequences longer than this will be truncated.",
                elem_id="finetune_max_seq_length"
            )
            train_on_inputs = gr.Checkbox(
                label="Train on Inputs",
                value=True,
                info="If not enabled, inputs will be masked out in loss.",
                elem_id="finetune_train_on_inputs"
            )

        with gr.Row():
            # https://huggingface.co/docs/transformers/main/main_classes/trainer
            # Choose a larger default micro batch size when Global reports
            # GPU memory and core counts and the memory per core is high
            # enough. NOTE(review): units of gpu_total_memory and the
            # 4000000 threshold (marked "?") are unverified.
            micro_batch_size_default_value = 1
            if Global.gpu_total_cores is not None and Global.gpu_total_memory is not None:
                memory_per_core = Global.gpu_total_memory / Global.gpu_total_cores
                if memory_per_core >= 6291456:
                    micro_batch_size_default_value = 8
                elif memory_per_core >= 4000000: # ?
                    micro_batch_size_default_value = 4

            # Left column: core training hyperparameters.
            with gr.Column():
                micro_batch_size = gr.Slider(
                    minimum=1, maximum=100, step=1, value=micro_batch_size_default_value,
                    label="Micro Batch Size",
                    info="The number of examples in each mini-batch for gradient computation. A smaller micro_batch_size reduces memory usage but may increase training time."
                )
                gradient_accumulation_steps = gr.Slider(
                    minimum=1, maximum=10, step=1, value=1,
                    label="Gradient Accumulation Steps",
                    info="The number of steps to accumulate gradients before updating model parameters. This can be used to simulate a larger effective batch size without increasing memory usage."
                )
                epochs = gr.Slider(
                    minimum=1, maximum=100, step=1, value=10,
                    label="Epochs",
                    info="The number of times to iterate over the entire training dataset. A larger number of epochs may improve model performance but also increase the risk of overfitting.")
                learning_rate = gr.Slider(
                    minimum=0.00001, maximum=0.01, value=3e-4,
                    label="Learning Rate",
                    info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
                )
                with gr.Column(elem_id="finetune_eval_data_group"):
                    # maximum=1 is only the initial value. This slider is an
                    # output of refresh_dataset_items_count (wired below),
                    # which presumably adjusts it to the dataset size.
                    evaluate_data_count = gr.Slider(
                        minimum=0, maximum=1, step=1, value=0,
                        label="Evaluation Data Count",
                        info="The number of data to be used for evaluation. This specific amount of data will be randomly chosen from the training dataset for evaluating the model's performance during the process, without contributing to the actual training.",
                        elem_id="finetune_evaluate_data_count"
                    )
                # Layout spacer (styled via the CSS class).
                gr.HTML(elem_classes="flex_vertical_grow_area")

                with gr.Accordion("Advanced Options", open=False, elem_id="finetune_advance_options_accordion"):
                    with gr.Row(elem_id="finetune_advanced_options_checkboxes"):
                        load_in_8bit = gr.Checkbox(
                            label="8bit", value=Config.load_8bit)
                        fp16 = gr.Checkbox(label="FP16", value=True)
                        bf16 = gr.Checkbox(label="BF16", value=False)
                        gradient_checkpointing = gr.Checkbox(
                            label="gradient_checkpointing", value=False)
                    with gr.Column(variant="panel", elem_id="finetune_additional_training_arguments_box"):
                        # Label-only textbox used as a heading for the JSON editor.
                        gr.Textbox(
                            label="Additional Training Arguments",
                            info="Additional training arguments to be passed to the Trainer. Note that this can override ALL other arguments set elsewhere. See https://bit.ly/hf20-transformers-training-arguments for more details.",
                            elem_id="finetune_additional_training_arguments_textbox_for_label_display"
                        )
                        additional_training_arguments = gr.Code(
                            label="JSON",
                            language="json",
                            value="",
                            lines=2,
                            elem_id="finetune_additional_training_arguments")

                with gr.Box(elem_id="finetune_continue_from_model_box"):
                    with gr.Row():
                        # allow_custom_value lets the user type a model name
                        # that is not in the choices.
                        continue_from_model = gr.Dropdown(
                            value="-",
                            label="Continue from Model",
                            choices=["-"],
                            allow_custom_value=True,
                            elem_id="finetune_continue_from_model"
                        )
                        continue_from_checkpoint = gr.Dropdown(
                            value="-",
                            label="Resume from Checkpoint",
                            choices=["-"],
                            elem_id="finetune_continue_from_checkpoint")
                    with gr.Column():
                        # Made visible by handle_continue_from_model_change
                        # when the selected model has saved params.
                        load_params_from_model_btn = gr.Button(
                            "Load training parameters from selected model", visible=False)
                        load_params_from_model_btn.style(
                            full_width=False,
                            size="sm")
                        load_params_from_model_message = gr.Markdown(
                            "", visible=False)
                    # Refresh checkpoints / button / message when the model changes.
                    things_that_might_timeout.append(
                        continue_from_model.change(
                            fn=handle_continue_from_model_change,
                            inputs=[continue_from_model],
                            outputs=[
                                continue_from_checkpoint,
                                load_params_from_model_btn,
                                load_params_from_model_message
                            ]
                        )
                    )

            # Right column: LoRA settings, logging/saving, model name, buttons.
            with gr.Column():
                lora_r = gr.Slider(
                    minimum=1, maximum=16, step=1, value=8,
                    label="LoRA R",
                    info="The rank parameter for LoRA, which controls the dimensionality of the rank decomposition matrices. A larger lora_r increases the expressiveness and flexibility of LoRA but also increases the number of trainable parameters and memory usage."
                )
                lora_alpha = gr.Slider(
                    minimum=1, maximum=128, step=1, value=16,
                    label="LoRA Alpha",
                    info="The scaling parameter for LoRA, which controls how much LoRA affects the original pre-trained model weights. A larger lora_alpha amplifies the impact of LoRA but may also distort or override the pre-trained knowledge."
                )
                lora_dropout = gr.Slider(
                    minimum=0, maximum=1, value=0.05,
                    label="LoRA Dropout",
                    info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
                )

                with gr.Column(elem_id="finetune_lora_target_modules_box"):
                    lora_target_modules = gr.CheckboxGroup(
                        label="LoRA Target Modules",
                        choices=default_lora_target_module_choices,
                        value=["q_proj", "v_proj"],
                        info="Modules to replace with LoRA.",
                        elem_id="finetune_lora_target_modules"
                    )
                    # Holds the current choice list so "Add" can extend it.
                    lora_target_module_choices = gr.State(
                        value=default_lora_target_module_choices) # type: ignore
                    with gr.Box(elem_id="finetune_lora_target_modules_add_box"):
                        with gr.Row():
                            lora_target_modules_add = gr.Textbox(
                                lines=1, max_lines=1, show_label=False,
                                elem_id="finetune_lora_target_modules_add"
                            )
                            lora_target_modules_add_btn = gr.Button(
                                "Add",
                                elem_id="finetune_lora_target_modules_add_btn"
                            )
                            lora_target_modules_add_btn.style(
                                full_width=False, size="sm")
                    things_that_might_timeout.append(lora_target_modules_add_btn.click(
                        handle_lora_target_modules_add,
                        inputs=[lora_target_module_choices,
                                lora_target_modules_add, lora_target_modules],
                        outputs=[lora_target_module_choices,
                                 lora_target_modules_add, lora_target_modules],
                    ))

                with gr.Accordion("Advanced LoRA Options", open=False, elem_id="finetune_advance_lora_options_accordion"):
                    with gr.Column(elem_id="finetune_lora_modules_to_save_box"):
                        lora_modules_to_save = gr.CheckboxGroup(
                            label="LoRA Modules To Save",
                            choices=default_lora_modules_to_save_choices,
                            value=[],
                            # info="",
                            elem_id="finetune_lora_modules_to_save"
                        )
                        # Holds the current choice list so "Add" can extend it.
                        lora_modules_to_save_choices = gr.State(
                            value=default_lora_modules_to_save_choices) # type: ignore
                        with gr.Box(elem_id="finetune_lora_modules_to_save_add_box"):
                            with gr.Row():
                                lora_modules_to_save_add = gr.Textbox(
                                    lines=1, max_lines=1, show_label=False,
                                    elem_id="finetune_lora_modules_to_save_add"
                                )
                                lora_modules_to_save_add_btn = gr.Button(
                                    "Add",
                                    elem_id="finetune_lora_modules_to_save_add_btn"
                                )
                                lora_modules_to_save_add_btn.style(
                                    full_width=False, size="sm")
                        things_that_might_timeout.append(lora_modules_to_save_add_btn.click(
                            handle_lora_modules_to_save_add,
                            inputs=[lora_modules_to_save_choices,
                                    lora_modules_to_save_add, lora_modules_to_save],
                            outputs=[lora_modules_to_save_choices,
                                     lora_modules_to_save_add, lora_modules_to_save],
                        ))

                    with gr.Column(variant="panel", elem_id="finetune_additional_lora_config_box"):
                        # Label-only textbox used as a heading for the JSON editor.
                        gr.Textbox(
                            label="Additional LoRA Config",
                            info="Additional LoraConfig. Note that this can override ALL other arguments set elsewhere.",
                            elem_id="finetune_additional_lora_config_textbox_for_label_display"
                        )
                        additional_lora_config = gr.Code(
                            label="JSON",
                            language="json",
                            value="",
                            lines=2,
                            elem_id="finetune_additional_lora_config")

                # Layout spacer (styled via the CSS classes).
                gr.HTML(elem_classes="flex_vertical_grow_area no_limit")

                with gr.Column(elem_id="finetune_log_and_save_options_group_container"):
                    with gr.Row(elem_id="finetune_log_and_save_options_group"):
                        logging_steps = gr.Number(
                            label="Logging Steps",
                            precision=0,
                            value=10,
                            elem_id="finetune_logging_steps"
                        )
                        save_steps = gr.Number(
                            label="Steps Per Save",
                            precision=0,
                            value=500,
                            elem_id="finetune_save_steps"
                        )
                        save_total_limit = gr.Number(
                            label="Saved Checkpoints Limit",
                            precision=0,
                            value=5,
                            elem_id="finetune_save_total_limit"
                        )

                with gr.Column(elem_id="finetune_model_name_group"):
                    # value=random_name is a callable, so Gradio computes a
                    # fresh default name rather than a fixed string.
                    model_name = gr.Textbox(
                        lines=1, label="LoRA Model Name", value=random_name,
                        max_lines=1,
                        info="The name of the new LoRA model.",
                        elem_id="finetune_model_name",
                    )

                with gr.Row():
                    # Empty column acts as a spacer for the button column.
                    with gr.Column():
                        pass
                    with gr.Column():
                        with gr.Row():
                            train_btn = gr.Button(
                                "Train", variant="primary", label="Train",
                                elem_id="finetune_start_btn"
                            )
                            abort_button = gr.Button(
                                "Abort", label="Abort",
                                elem_id="finetune_stop_btn"
                            )
                            confirm_abort_button = gr.Button(
                                "Confirm Abort", label="Confirm Abort", variant="stop",
                                elem_id="finetune_confirm_stop_btn"
                            )

        things_that_might_timeout.append(reload_selections_button.click(
            reload_selections,
            inputs=[template, dataset_from_data_dir],
            outputs=[template, dataset_from_data_dir, continue_from_model],
        ))

        # Any change to a dataset/preview input re-renders the preview, then
        # refreshes the info messages and the evaluate_data_count slider.
        for i in dataset_preview_inputs:
            things_that_might_timeout.append(
                i.change(
                    fn=refresh_preview,
                    inputs=dataset_preview_inputs,
                    outputs=[
                        finetune_dataset_preview,
                        finetune_dataset_preview_info_message,
                        dataset_from_text_message,
                        dataset_from_data_dir_message
                    ]
                ).then(
                    fn=refresh_dataset_items_count,
                    inputs=dataset_preview_inputs,
                    outputs=[
                        finetune_dataset_preview_info_message,
                        dataset_from_text_message,
                        dataset_from_data_dir_message,
                        evaluate_data_count,
                    ]
                ))

        # Training parameters, in the order do_train and
        # handle_load_params_from_model receive them.
        finetune_args = [
            max_seq_length,
            evaluate_data_count,
            micro_batch_size,
            gradient_accumulation_steps,
            epochs,
            learning_rate,
            train_on_inputs,
            lora_r,
            lora_alpha,
            lora_dropout,
            lora_target_modules,
            lora_modules_to_save,
            load_in_8bit,
            fp16,
            bf16,
            gradient_checkpointing,
            save_steps,
            save_total_limit,
            logging_steps,
            additional_training_arguments,
            additional_lora_config,
        ]

        # Inputs and outputs mirror each other so the handler can return the
        # current values untouched for anything it cannot restore.
        things_that_might_timeout.append(
            load_params_from_model_btn.click(
                fn=handle_load_params_from_model,
                inputs=(
                    [continue_from_model] +
                    [template, load_dataset_from, dataset_from_data_dir] +
                    finetune_args +
                    [lora_target_module_choices, lora_modules_to_save_choices]
                ), # type: ignore
                outputs=(
                    [load_params_from_model_message] +
                    [template, load_dataset_from, dataset_from_data_dir] +
                    finetune_args +
                    [lora_target_module_choices, lora_modules_to_save_choices]
                ) # type: ignore
            )
        )

        train_status = gr.HTML(
            "",
            label="Train Output",
            elem_id="finetune_training_status")
        with gr.Column(visible=False, elem_id="finetune_loss_plot_container") as loss_plot_container:
            loss_plot = gr.Plot(
                visible=False, show_label=False,
                elem_id="finetune_loss_plot")
        training_indicator = gr.HTML(
            "training_indicator", visible=False, elem_id="finetune_training_indicator")

        # Start training with the dataset inputs, all finetune_args, the new
        # model name and the continue-from selections.
        train_start = train_btn.click(
            fn=do_train,
            inputs=(dataset_inputs + finetune_args + [
                model_name,
                continue_from_model,
                continue_from_checkpoint,
            ]),
            outputs=[train_status, training_indicator,
                     loss_plot_container, loss_plot]
        )

        # controlled by JS, shows the confirm_abort_button
        abort_button.click(None, None, None, None)
        # Set the stop flag and also cancel the pending train_start event.
        confirm_abort_button.click(
            fn=do_abort_training,
            inputs=None, outputs=None,
            cancels=[train_start])

        # Poll training status every 0.2 s and the loss plot every 10 s. The
        # returned dependencies are currently unused (one is referenced only
        # by the commented-out line below).
        training_status_updates = finetune_ui_blocks.load(
            fn=render_training_status,
            inputs=None,
            outputs=[train_status, training_indicator],
            every=0.2
        )
        loss_plot_updates = finetune_ui_blocks.load(
            fn=render_loss_plot,
            inputs=None,
            outputs=[loss_plot_container, loss_plot],
            every=10
        )

    # Client-side behaviour for this UI lives in script.js next to this module.
    finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))

    # things_that_might_timeout.append(training_status_updates)
    stop_timeoutable_btn = gr.Button(
        "stop not-responding elements",
        elem_id="inference_stop_timeoutable_btn",
        elem_classes="foot_stop_timeoutable_btn")
    stop_timeoutable_btn.click(
        fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)