import asyncio
import glob
import itertools
import logging
import os
import random
import shutil
import time
import traceback
import zipfile
from pathlib import Path
from typing import Callable, List, Optional, Set, Union

import accelerate
import einops
import huggingface_hub
import numpy as np
import pandas as pd
import torch
import transformers
import yaml
from h2o_wave import Q, data, ui
from sqlitedict import SqliteDict

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.hugging_face_utils import (
    get_chat_template,
    get_model_card,
    publish_model_to_hugging_face,
)
from llm_studio.app_utils.sections.chat import chat_tab, load_cfg_model_tokenizer
from llm_studio.app_utils.sections.common import clean_dashboard
from llm_studio.app_utils.utils import (
    GridCheckError,
    add_model_type,
    filter_grid_search_combination,
    flatten_dict,
    get_cfg_list_items,
    get_data_dir,
    get_download_link,
    get_experiment_status,
    get_experiments,
    get_grid_search,
    get_model_types,
    get_problem_categories,
    get_problem_types,
    get_ui_elements_for_cfg,
    get_unique_name,
    hf_repo_friendly_name,
    parse_ui_elements,
    remove_model_type,
    set_env,
    set_grid_to_cfg,
    start_experiment,
)
from llm_studio.app_utils.wave_utils import busy_dialog, ui_table_from_df, wave_theme
from llm_studio.python_configs.cfg_checks import check_config_for_errors
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.tooltips import tooltips
from llm_studio.src.utils.config_utils import (
    GENERATION_PROBLEM_TYPES,
    NON_GENERATION_PROBLEM_TYPES,
    load_config_py,
    load_config_yaml,
    save_config_yaml,
)
from llm_studio.src.utils.exceptions import LLMResourceException
from llm_studio.src.utils.export_utils import (
    check_available_space,
    get_adapter_model_path,
    get_artifact_path_path,
    get_logs_path,
    get_model_path,
    get_predictions_path,
    save_logs,
    save_prediction_outputs,
)
from llm_studio.src.utils.logging_utils import write_flag
from llm_studio.src.utils.modeling_utils import unwrap_model
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS
from llm_studio.src.utils.utils import add_file_to_zip, kill_child_processes_and_current

logger = logging.getLogger(__name__)


def trace_calls(func):
    """
    Trace calls to the function by logging the function name and the stack trace.
    """

    async def async_wrapper(*args, **kwargs):
        logger.debug(f"Async function {func.__name__} called from:")
        logger.debug("".join(traceback.format_stack(limit=2)))
        return await func(*args, **kwargs)

    def sync_wrapper(*args, **kwargs):
        logger.debug(f"Function {func.__name__} called from:")
        logger.debug("".join(traceback.format_stack(limit=2)))
        return func(*args, **kwargs)

    # Check if the function is asynchronous
    if asyncio.iscoroutinefunction(func):
        return async_wrapper
    else:
        return sync_wrapper
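

# A minimal usage sketch (the decorated names below are hypothetical, not part
# of this module): `trace_calls` wraps both sync and async callables; with the
# app logger set to DEBUG, each call logs the two innermost caller frames.
#
#     @trace_calls
#     async def refresh_view(q: Q) -> None:
#         ...
#
#     @trace_calls
#     def compute_mean(values: List[float]) -> float:
#         return sum(values) / len(values)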


async def experiment_start(q: Q) -> None:
    """Display experiment start cards."""
    await clean_dashboard(q, mode="experiment_start", exclude=["experiment/start"])

    show_update_warnings = True
    is_create_experiment = False
    # reset certain configs if new experiment start session
    if (
        q.args.__wave_submission_name__ == "experiment/start"
        or q.args.__wave_submission_name__ == "experiment/start/grid_search"
        or q.args.__wave_submission_name__ == "experiment/start_experiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment/from_current"
        or q.args.__wave_submission_name__ == "experiment/list/new"
        or q.args.__wave_submission_name__ == "experiment/list/new_gridsearch"
    ):
        q.client["experiment/start/cfg_experiment_prev"] = None
        q.client["experiment/start/cfg_file_prev"] = None
        q.client["experiment/start/prev_dataset"] = None
        q.client["experiment/start/cfg_sub"] = None
        show_update_warnings = False
        is_create_experiment = True

    if (
        q.args.__wave_submission_name__ == "experiment/start"
        or q.args.__wave_submission_name__ == "experiment/start_experiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment"
        or q.args.__wave_submission_name__ == "dataset/newexperiment/from_current"
        or q.args.__wave_submission_name__ == "experiment/list/new"
    ):
        q.client["experiment/start/grid_search"] = None
    elif (
        q.args.__wave_submission_name__ == "experiment/start/grid_search"
        or q.args.__wave_submission_name__ == "experiment/list/new_gridsearch"
    ):
        q.client["experiment/start/grid_search"] = True

    # set active navigation
    if q.client["experiment/start/grid_search"]:
        q.client["nav/active"] = "experiment/start/grid_search"
    else:
        q.client["nav/active"] = "experiment/start"

    # get all the datasets available
    df_datasets = q.client.app_db.get_datasets_df()
    # Hide inference only datasets
    df_datasets = df_datasets.loc[df_datasets["train_rows"].notna()]
    if (
        not q.client["experiment/start/dataset"]
        or q.client["experiment/start/dataset"] not in df_datasets.id.astype(str).values
    ):
        if len(df_datasets) >= 1:
            q.client["experiment/start/dataset"] = str(df_datasets["id"].iloc[-1])
        else:
            q.client["experiment/start/dataset"] = "1"

    warning_message = "Experiment settings might be updated after changing {}"

    items = [
        ui.separator(name="general_expander", label="General settings"),
        ui.dropdown(
            name="experiment/start/dataset",
            label="Dataset",
            required=True,
            value=q.client["experiment/start/dataset"],
            choices=[
                ui.choice(str(row["id"]), str(row["name"]))
                for _, row in df_datasets.iterrows()
            ],
            trigger=True,
            tooltip=tooltips["experiments_dataset"],
        ),
    ]

    if (
        show_update_warnings
        and q.client["experiment/start/dataset_prev"]
        != q.client["experiment/start/dataset"]
    ):
        items += [
            ui.message_bar(type="warning", text=warning_message.format("Dataset"))
        ]
        show_update_warnings = False

    if (
        q.client["experiment/start/cfg_file"] is None
        or q.client["experiment/start/dataset_prev"]
        != q.client["experiment/start/dataset"]
    ) and q.client["experiment/start/cfg_category"] != "experiment":
        dataset = q.client.app_db.get_dataset(q.client["experiment/start/dataset"])
        if dataset is not None:
            problem_type = dataset.config_file.replace(dataset.path + "/", "").replace(
                ".yaml", ""
            )
        else:
            problem_type = default_cfg.cfg_file
        q.client["experiment/start/cfg_file"] = problem_type
        q.client["experiment/start/cfg_category"] = problem_type.split("_")[0]

    if q.client["experiment/start/cfg_category"] == "experiment":
        q.client["experiment/start/cfg_file"] = "experiment"

    # get all experiments
    df_experiments = get_experiments(q)

    # get all problem category choices
    choices_problem_categories = [
        ui.choice(name, label) for name, label in get_problem_categories()
    ]
    if len(df_experiments["id"]) > 0:
        choices_problem_categories += [ui.choice("experiment", "From Experiment")]

    # set default value of problem type if no match to category
    if (
        q.client["experiment/start/cfg_category"]
        not in q.client["experiment/start/cfg_file"]
    ):
        if q.client["experiment/start/cfg_category"] != "experiment":
            q.client["experiment/start/cfg_file"] = get_problem_types(
                category=q.client["experiment/start/cfg_category"]
            )[0][0]

    # get all problem type choices
    choices_problem_types = [
        ui.choice(name, label)
        for name, label in get_problem_types(
            category=q.client["experiment/start/cfg_category"]
        )
    ]

    # remove model type if present in cfg file name here
    q.client["experiment/start/cfg_file"] = remove_model_type(
        q.client["experiment/start/cfg_file"]
    )

    if len(df_experiments["id"]) > 0:
        if q.client["experiment/start/cfg_experiment"] is None:
            q.client["experiment/start/cfg_experiment"] = str(
                df_experiments["id"].iloc[0]
            )
        # Default pretrained from the previous experiment to False
        if (
            q.client["experiment/start/cfg_experiment_pretrained"] is None
            or is_create_experiment
        ):
            q.client["experiment/start/cfg_experiment_pretrained"] = False

    if q.client["experiment/start/cfg_category"] != "experiment":
        items += [
            ui.dropdown(
                name="experiment/start/cfg_file",
                label="Problem Type",
                required=True,
                choices=choices_problem_types,
                value=q.client["experiment/start/cfg_file"],
                trigger=True,
                tooltip=tooltips["experiments_problem_type"],
            )
        ]

    model_types = get_model_types(q.client["experiment/start/cfg_file"])
    if len(model_types) > 0:
        choices = [ui.choice(name, label) for name, label in model_types]
        if q.client["experiment/start/cfg_sub"] in [None, ""]:
            q.client["experiment/start/cfg_sub"] = model_types[0][0]
        items += [
            ui.dropdown(
                name="experiment/start/cfg_sub",
                label="Model Type",
                required=True,
                choices=choices,
                value=q.client["experiment/start/cfg_sub"],
                trigger=True,
            )
        ]
    else:
        q.client["experiment/start/cfg_sub"] = ""

    # add model type to cfg file name here
    q.client["experiment/start/cfg_file"] = add_model_type(
        q.client["experiment/start/cfg_file"], q.client["experiment/start/cfg_sub"]
    )

    if (
        show_update_warnings
        and q.client["experiment/start/cfg_file_prev"]
        != q.client["experiment/start/cfg_file"]
        and q.client["experiment/start/cfg_category"] != "experiment"
    ):
        items += [
            ui.message_bar(type="warning", text=warning_message.format("Problem Type"))
        ]
        show_update_warnings = False

    if q.client["experiment/start/cfg_category"] == "experiment":
        if (
            show_update_warnings
            and q.client["experiment/start/cfg_experiment_prev"]
            != q.client["experiment/start/cfg_experiment"]
        ):
            items += [
                ui.message_bar(
                    type="warning", text=warning_message.format("previous Experiment")
                )
            ]
        # Show pretrained weights toggle only for successfully finished experiments
        if (
            df_experiments.loc[
                df_experiments.id == int(q.client["experiment/start/cfg_experiment"]),
                "status",
            ].values[0]
            == "finished"
        ):
            items += [
                ui.toggle(
                    name="experiment/start/cfg_experiment_pretrained",
                    label="Use previous experiment weights",
                    value=q.client["experiment/start/cfg_experiment_pretrained"],
                    trigger=True,
                )
            ]

    # only show YAML option when not starting from another experiment
    if q.client["experiment/start/cfg_category"] != "experiment":
        items += [
            ui.toggle(
                name="experiment/start/from_yaml",
                label="Import config from YAML",
                value=False,
                trigger=True,
                tooltip=tooltips["experiments_import_config_from_yaml"],
            )
        ]

    if q.args["experiment/start/from_yaml"]:
        items += [
            ui.file_upload(
                name="experiment/upload_yaml",
                label="Upload!",
                multiple=False,
                file_extensions=["yaml"],
            )
        ]

    if q.args["experiment/upload_yaml"] is not None:
        # reset previous, so the UI will be reloaded
        q.client["experiment/start/cfg_file_prev"] = None
        await config_import_uploaded_file(q)

    logger.info(
        f"PREV {q.client['experiment/start/cfg_file_prev']} "
        f"{q.client['experiment/start/cfg_file']} "
        f"{q.client['experiment/start/dataset_prev']} "
        f"{q.client['experiment/start/dataset']} "
        f"{q.client['experiment/start/cfg_experiment_prev']} "
        f"{q.client['experiment/start/cfg_experiment']} "
    )

    if q.client["experiment/start/cfg_category"] == "experiment":
        logger.info("Starting from experiment")

        # reset previous config file
        q.client["experiment/start/cfg_file_prev"] = None

        q.client["experiment/start/experiment"] = q.client.app_db.get_experiment(
            q.client["experiment/start/cfg_experiment"]
        )

        parent_path = os.path.dirname(q.client["experiment/start/experiment"].path)
        parent_config = load_config_yaml(f"{parent_path}/cfg.yaml")

        q.client["experiment/start/cfg"] = parent_config

        # set pretrained weights
        if q.client["experiment/start/cfg_experiment_pretrained"]:
            prev_weights = os.path.join(
                q.client["experiment/start/experiment"].path,
                "checkpoint.pth",
            )
            if os.path.exists(prev_weights):
                q.client["experiment/start/cfg"].architecture.pretrained_weights = (
                    prev_weights
                )
                q.client["experiment/start/cfg"].architecture._visibility[
                    "pretrained_weights"
                ] = -1

        experiments_df = q.client.app_db.get_experiments_df()
        output_dir = os.path.abspath(
            os.path.join(q.client["experiment/start/cfg"].output_directory, "..")
        )
        q.client["experiment/start/cfg"].experiment_name = get_unique_name(
            q.client["experiment/start/cfg"].experiment_name,
            experiments_df["name"].values,
            lambda x: os.path.exists(os.path.join(output_dir, x)),
        )

        # Configuration flags:
        # from_dataset -- take the values from the dataset config
        # from_cfg -- take the values from the configuration file
        # from_default -- take the values from the default settings
        # from_dataset_args -- take the values from the dataset's q.args
        # Otherwise -- take the values from the q.args (user input)

        # pick default values from config
        if (
            q.client["experiment/start/cfg_experiment_prev"]
            != q.client["experiment/start/cfg_experiment"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False

            q.client["experiment/start/dataset"] = str(
                q.client["experiment/start/experiment"].dataset
            )

            items[1].dropdown.value = q.client["experiment/start/dataset"]
        # pick default values from config or dataset
        elif (
            q.client["experiment/start/dataset_prev"]
            != q.client["experiment/start/dataset"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = True
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False
        # pick default values from args
        else:
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = False
            q.client["experiment/start/cfg_mode/from_dataset_args"] = True

        q.client["experiment/start/cfg_mode/from_default"] = False
        q.client["experiment/start/cfg_experiment_prev"] = q.client[
            "experiment/start/cfg_experiment"
        ]

    else:
        logger.info("Starting from CFG")

        # reset previous experiment
        q.client["experiment/start/cfg_experiment_prev"] = None

        # pick default values from dataset or config
        if (
            q.client["experiment/start/cfg_file_prev"]
            != q.client["experiment/start/cfg_file"]
        ) or (
            q.client["experiment/start/dataset_prev"]
            != q.client["experiment/start/dataset"]
        ):
            q.client["experiment/start/cfg_mode/from_dataset"] = True
            q.client["experiment/start/cfg_mode/from_cfg"] = True
            q.client["experiment/start/cfg_mode/from_default"] = True
            q.client["experiment/start/cfg_mode/from_dataset_args"] = False
        # pick default values from args
        else:
            q.client["experiment/start/cfg_mode/from_dataset"] = False
            q.client["experiment/start/cfg_mode/from_cfg"] = False
            q.client["experiment/start/cfg_mode/from_default"] = False
            q.client["experiment/start/cfg_mode/from_dataset_args"] = True

        q.client["experiment/start/cfg_file_prev"] = q.client[
            "experiment/start/cfg_file"
        ]

        config_path = (
            f"llm_studio/python_configs/{q.client['experiment/start/cfg_file']}"
        )
        q.client["experiment/start/cfg"] = load_config_py(
            config_path=config_path, config_name="ConfigProblemBase"
        )

    q.client["experiment/start/dataset_prev"] = q.client["experiment/start/dataset"]
    logger.info(f"From dataset {q.client['experiment/start/cfg_mode/from_dataset']}")
    logger.info(f"From cfg {q.client['experiment/start/cfg_mode/from_cfg']}")
    logger.info(f"From default {q.client['experiment/start/cfg_mode/from_default']}")
    logger.info(f"Config file: {q.client['experiment/start/cfg_file']}")

    option_items = get_ui_elements_for_cfg(cfg=q.client["experiment/start/cfg"], q=q)
    items.extend(option_items)

    if q.client["experiment/start/cfg_mode/from_cfg"]:
        q.page["experiment/start"] = ui.form_card(box="content", items=items)
    else:
        q.page["experiment/start"].items = items
    q.client.delete_cards.add("experiment/start")

    q.page["experiment/start/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(
                items=[
                    ui.button(
                        name="experiment/start/run",
                        label=(
                            "Run grid search"
                            if q.client["experiment/start/grid_search"]
                            else "Run experiment"
                        ),
                        primary=True,
                    )
                ],
                justify="start",
            )
        ],
    )
    q.client.delete_cards.add("experiment/start/footer")


def experiment_input_type_error(
    q: Q, pre: str = "experiment/start"
) -> Union[bool, GridCheckError]:
    """Check custom values entered into combo boxes (grid search) for errors.

    Returns:
        A GridCheckError if errors were found, otherwise False.
    """
    cfg = q.client[f"{pre}/cfg"]
    cfg = parse_ui_elements(cfg=cfg, q=q, pre=f"{pre}/cfg/")
    try:
        get_grid_search(cfg=cfg, q=q, pre=pre)
    except GridCheckError as e:
        return e
    return False


async def experiment_run(q: Q):
    """Start an experiment.

    Args:
        q: Q
    """
    # import here to avoid circular imports
    from llm_studio.app_utils.sections.project import list_current_experiments

    pre = "experiment/start"

    logger.info("Starting experiment")
    logger.info(f"{pre}/cfg_file")
    logger.info(f"CFG: {q.client[f'{pre}/cfg_file']}")

    if q.client[f"{pre}/cfg_category"] == "experiment":
        q.client[f"{pre}/cfg_file"] = q.client[f"{pre}/experiment"].config_file

    cfg = q.client[f"{pre}/cfg"]
    cfg = parse_ui_elements(cfg=cfg, q=q, pre=f"{pre}/cfg/")
    cfg.experiment_name = cfg.experiment_name.replace("/", "-")

    grid_search = get_grid_search(cfg=cfg, q=q, pre=pre)

    if len(grid_search) == 0:
        errors = check_config_for_errors(cfg)
        if errors["title"] and not q.args["experiment/start/error/proceed"]:
            title = (
                errors["title"][0]
                if len(errors["title"]) == 1
                else "The following configuration mismatches were found:"
            )
            error_text = [ui.text(message) for message in errors["message"]]
            q.page["meta"].dialog = ui.dialog(
                title=title,
                name="experiment/start/error/dialog",
                items=error_text
                + [
                    ui.buttons(
                        [
                            ui.button(
                                name="experiment/start/error/ok",
                                label="Ok",
                                primary=True,
                            ),
                            ui.button(
                                name="experiment/start/error/proceed",
                                label="I want to proceed anyhow",
                                primary=False,
                            ),
                        ]
                    )
                ],
                closable=True,
            )
            q.client["keep_meta"] = True
        else:
            start_experiment(cfg=cfg, q=q, pre=pre)
            await list_current_experiments(q)
    else:
        exp_name = cfg.experiment_name

        all_grid_hyperparams = sorted(grid_search)
        combinations = itertools.product(
            *(grid_search[name] for name in all_grid_hyperparams)
        )
        combinations = [dict(zip(all_grid_hyperparams, x)) for x in list(combinations)]
        random.shuffle(combinations)

        q.page["meta"].dialog = ui.dialog(
            title="Start grid search",
            name="experiment/start/gridsearch/dialog",
            items=[
                ui.text(
                    "Your selected grid of hyperparameters results in "
                    f"{len(combinations)} individual experiments. "
                    "Do you want to proceed?"
                )
            ]
            + [
                ui.buttons(
                    [
                        ui.button(
                            name="experiment/start/gridsearch/proceed",
                            label=(
                                f"Start grid search of {len(combinations)} experiments"
                            ),
                            primary=True,
                        ),
                        ui.button(
                            name="experiment/start/gridsearch/cancel",
                            label="Cancel",
                            primary=False,
                        ),
                    ]
                )
            ],
            closable=True,
        )
        q.client["keep_meta"] = True

        if q.args["experiment/start/gridsearch/proceed"]:
            all_grid_names = []
            for exp_idx, combo in enumerate(combinations):
                filtered_combo = filter_grid_search_combination(
                    grid=combo.copy(), cfg=cfg
                )
                grid_name = "_".join(
                    [
                        f"{hyp}_{filtered_combo[hyp]}"
                        for hyp in sorted(filtered_combo.keys())
                        if len(grid_search[hyp]) > 1
                    ]
                )
                if grid_name in all_grid_names:
                    continue
                else:
                    all_grid_names.append(grid_name)

                cfg = set_grid_to_cfg(cfg=cfg, grid=combo)
                if grid_name != "":
                    cfg.experiment_name = exp_name + f"_{grid_name}"
                    cfg.experiment_name = cfg.experiment_name.replace("/", "-")
                start_experiment(cfg=cfg, q=q, pre=pre)

            await list_current_experiments(q)
            # Remove the dialog
            q.client["keep_meta"] = False


def get_experiment_table(q, df_viz, height="calc(100vh - 245px)", actions=None):
    """Build the experiment overview table from an experiments dataframe."""
    col_remove = [
        "id",
        "path",
        "mode",
        "seed",
        "process_id",
        "gpu_list",
        "loss",
        "eta",
        "epoch",
    ]
    for col in col_remove:
        if col in df_viz:
            del df_viz[col]
    df_viz = df_viz.rename(
        columns={"config_file": "problem type"},
    )
    df_viz["problem type"] = df_viz["problem type"].str.replace("Text ", "")

    if actions == "experiment":
        actions_dict = {
            "experiment/list/new": "New experiment",
            "experiment/list/new_gridsearch": "New grid search",
            "experiment/list/rename": "Rename experiment",
            "experiment/list/stop/table": "Stop experiment",
            "experiment/list/delete/table/dialog": "Delete experiment",
        }
    else:
        actions_dict = {}

    min_widths = {
        "name": "350",
        "dataset": "150",
        "problem type": "195",
        "metric": "75",
        "val metric": "102",
        "progress": "85",
        "status": "90",
        "info": "115",
        "actions": "5",
    }

    return ui_table_from_df(
        q=q,
        df=df_viz,
        name="experiment/list/table",
        sortables=["val metric"],
        filterables=["name", "dataset", "problem type", "metric", "status"],
        searchables=["name", "dataset"],
        numerics=["val metric"],
        tags=["status"],
        progresses=["progress"],
        min_widths=min_widths,
        link_col="name",
        height=height,
        actions=actions_dict,
    )


async def experiment_list(
    q: Q,
    reset: bool = True,
    allowed_statuses: Optional[List[str]] = None,
    actions: bool = True,
) -> None:
    """List all experiments."""
    q.client["nav/active"] = "experiment/list"

    if reset:
        await clean_dashboard(q, mode="full")

        q.client["experiment/list/df_experiments"] = get_experiments(
            q,
            status=allowed_statuses,
        )
        df_viz = q.client["experiment/list/df_experiments"].copy()

        table = get_experiment_table(
            q,
            df_viz,
            actions="experiment" if actions else None,
        )

        message_bar = get_experiment_list_message_bar(q)

        items = [table, message_bar]

        q.page["experiment/list"] = ui.form_card(box="content", items=items)
        q.client.delete_cards.add("experiment/list")

    buttons = [
        ui.button(name="experiment/list/refresh", label="Refresh", primary=True),
        ui.button(
            name="experiment/list/compare",
            label="Compare experiments",
            primary=False,
        ),
        ui.button(name="experiment/list/stop", label="Stop experiments", primary=False),
        ui.button(
            name="experiment/list/delete", label="Delete experiments", primary=False
        ),
    ]
    q.page["dataset/display/footer"] = ui.form_card(
        box="footer", items=[ui.inline(items=buttons, justify="start")]
    )
    q.client.delete_cards.add("dataset/display/footer")


def get_table_and_message_item_indices(q):
    """Return the item indices of the table and the message bar in the list card."""
    table_item_idx, message_item_idx = 0, 1
    return table_item_idx, message_item_idx
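

# These indices mirror the card layout assembled in `experiment_list` above,
# where items == [table, message_bar]; `show_message` further below takes such
# an index to update the message bar of a card in place.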


async def experiment_compare(q: Q, selected_rows: list):
    """Compare the selected experiments side by side (charts or config view)."""
    if q.client["experiment/compare/tab"] is None:
        q.client["experiment/compare/tab"] = "experiment/compare/charts"
    if q.args["experiment/compare/charts"] is not None:
        q.client["experiment/compare/tab"] = "experiment/compare/charts"
    if q.args["experiment/compare/config"] is not None:
        q.client["experiment/compare/tab"] = "experiment/compare/config"

    experiment_ids = [
        q.client["experiment/list/df_experiments"]["id"].iloc[int(idx)]
        for idx in selected_rows
    ]

    await clean_dashboard(q, mode=q.client["experiment/compare/tab"])
    tabs = [
        ui.tab(name="experiment/compare/charts", label="Charts"),
        ui.tab(name="experiment/compare/config", label="Config"),
    ]
    q.page["experiment/compare/tab"] = ui.tab_card(
        box="nav2", link=True, items=tabs, value=q.client["experiment/compare/tab"]
    )
    q.client.delete_cards.add("experiment/compare/tab")

    if q.client["experiment/compare/tab"] == "experiment/compare/charts":
        charts = []
        experiment_names = []

        for experiment_id in experiment_ids:
            experiment = q.client.app_db.get_experiment(experiment_id)
            experiment_path = experiment.path
            charts.append(load_charts(experiment_path))
            current_name = f" {experiment.name}"
            experiment_names.append(current_name)

        await charts_tab(q, charts, experiment_names)

    elif q.client["experiment/compare/tab"] == "experiment/compare/config":
        if q.client["experiment/compare/diff_toggle"] is None:
            q.client["experiment/compare/diff_toggle"] = False

        settings = pd.DataFrame()
        for experiment_id in experiment_ids:
            experiment = q.client.app_db.get_experiment(experiment_id)
            experiment_path = experiment.path
            experiment_cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml"))
            items = get_cfg_list_items(experiment_cfg)
            act_df = pd.Series({item.label: item.value for item in items})
            settings[experiment.name] = act_df

        settings.index.name = "setting"

        if q.client["experiment/compare/diff_toggle"]:
            val_counts = settings.T.nunique()
            drop_idx = val_counts[val_counts == 1].index.values
            settings = settings.drop(drop_idx)

        items = [
            ui.toggle(
                name="experiment/compare/diff_toggle",
                label="Show differences only",
                value=q.client["experiment/compare/diff_toggle"],
                trigger=True,
            ),
            ui_table_from_df(
                q=q,
                df=settings.reset_index(),
                name="experiment/compare/summary/table",
                link_col="setting",
                height="calc(100vh - 315px)",
            ),
        ]
        q.page["experiment/compare/config"] = ui.form_card(box="first", items=items)
        q.client.delete_cards.add("experiment/compare/config")

    buttons = [
        ui.button(name="experiment/compare", label="Refresh", primary=True),
        ui.button(name="experiment/list/current", label="Back", primary=False),
    ]
    q.page["experiment/compare/footer"] = ui.form_card(
        box="footer", items=[ui.inline(items=buttons, justify="start")]
    )
    q.client.delete_cards.add("experiment/compare/footer")


async def experiment_rename_form(q: Q, error: str = "") -> None:
    """Display the form for renaming an experiment."""
    experiment = q.client.app_db.get_experiment(q.client["experiment/rename/id"])
    experiment_name = experiment.name

    items = [
        ui.textbox(
            name="experiment/rename/name",
            label=f"New name for {experiment_name}",
            value=experiment_name,
            required=True,
        )
    ]

    if error:
        items.append(ui.message_bar(type="error", text=error))

    q.page["experiment/list"].items = items

    buttons = [
        ui.button(name="experiment/rename/action", label="Rename", primary=True),
        ui.button(name="experiment/list/current", label="Abort", primary=False),
    ]
    q.page["dataset/display/footer"] = ui.form_card(
        box="footer", items=[ui.inline(items=buttons, justify="start")]
    )
    q.client.delete_cards.add("dataset/display/footer")


async def experiment_rename_ui_workflow(q: Q):
    """Resolve the experiment selected for renaming and show the rename form."""
    selected_row = q.args["experiment/list/rename"]
    rename_id = q.client["experiment/list/df_experiments"]["id"].iloc[int(selected_row)]
    q.client["experiment/rename/id"] = rename_id
    await experiment_rename_form(q)


async def experiment_rename_action(q, experiment, new_name):
    """Rename the given experiment to `new_name`, moving its files and DB entry."""
    old_name = experiment.name
    old_path = experiment.path
    new_path = old_path.replace(old_name, new_name)

    if old_path != new_path:
        old_exp_path = f"{old_path}"
        exp_path = f"{new_path}"
        logger.info(f"Renaming {old_exp_path} to {exp_path}")
        shutil.move(os.path.abspath(old_exp_path), os.path.abspath(exp_path))

        # update the chart paths stored in the charts DB to the new name
        with SqliteDict(os.path.join(new_path, "charts.db")) as charts:
            for k1 in PLOT_ENCODINGS:
                if k1 == "df":
                    # this is required to properly overwrite it
                    df = charts[k1].copy()
                    for k2, v2 in df.items():
                        logger.info(
                            f"Renaming charts {v2} to {v2.replace(old_name, new_name)}"
                        )
                        df[k2] = v2.replace(old_name, new_name)
                    charts[k1] = df
                    charts.commit()

        for config_file in ["cfg.yaml"]:
            config_path = os.path.join(exp_path, config_file)
            if os.path.exists(config_path):
                experiment_cfg = load_config_yaml(config_path)
                experiment_cfg.experiment_name = new_name
                experiment_cfg.output_directory = new_path
                save_config_yaml(config_path, experiment_cfg)

        rename_files = ["preds"]
        for file in rename_files:
            old_file = get_artifact_path_path(old_name, exp_path, file)
            new_file = get_artifact_path_path(new_name, exp_path, file)
            if os.path.exists(old_file):
                logger.info(f"Renaming {old_file} to {new_file}")
                shutil.move(os.path.abspath(old_file), os.path.abspath(new_file))

        delete_files = ["logs"]  # will be generated on demand with updates
        for file in delete_files:
            file = get_artifact_path_path(old_name, exp_path, file)
            if os.path.exists(file):
                logger.info(f"Deleting {file}")
                os.remove(file)

    q.client.app_db.rename_experiment(experiment.id, new_name, new_path)


async def experiment_delete(q: Q, experiment_ids: List[int]) -> None:
    """Delete selected experiments.

    Args:
        q: Q
        experiment_ids: list of experiment ids to delete
    """
    for experiment_id in experiment_ids:
        experiment = q.client.app_db.get_experiment(experiment_id)
        q.client.app_db.delete_experiment(experiment.id)
        try:
            shutil.rmtree(f"{experiment.path}")
        except FileNotFoundError:
            logger.warning(f"Experiment path {experiment.path} not found for deletion.")


async def experiment_stop(q: Q, experiment_ids: List[int]) -> None:
    """Stop selected experiments.

    Args:
        q: Q
        experiment_ids: list of experiment ids to stop
    """
    for experiment_id in experiment_ids:
        experiment = q.client.app_db.get_experiment(experiment_id)

        try:
            ret = kill_child_processes_and_current(int(experiment.process_id))
            if ret:
                flag_path = os.path.join(experiment.path, "flags.json")
                write_flag(flag_path, "status", "stopped")
        except Exception as e:
            logger.error(f"Error while stopping the experiment: {e}")


def load_charts(experiment_path):
    """Load the charts database of an experiment; return {} if not yet available."""
    try:
        with SqliteDict(os.path.join(experiment_path, "charts.db")) as charts:
            charts = dict(charts)
    except Exception:
        charts = {}
        logger.warning("Too early, wait for the charts to appear")

    return charts
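

# Shape of the returned dict as consumed by `charts_tab` and `subsample` below
# (keys shown are illustrative; "<metric>" depends on the configured metric):
#
#     {
#         "meta": {"lr": {"steps": [...], "values": [...]}},
#         "train": {"loss": {"steps": [...], "values": [...]}},
#         "validation": {"loss": {...}, "<metric>": {...}},
#     }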


async def experiment_display(q: Q) -> None:
    """Display a selected experiment."""
    experiment_id = q.client["experiment/list/df_experiments"]["id"].iloc[
        q.client["experiment/display/id"]
    ]
    q.client["experiment/display/experiment_id"] = experiment_id
    experiment = q.client.app_db.get_experiment(experiment_id)
    q.client["experiment/display/experiment"] = experiment

    q.client["experiment/display/experiment_path"] = experiment.path

    checkpoints_exists = os.path.exists(
        os.path.join(q.client["experiment/display/experiment_path"], "checkpoint.pth")
    )
    adapter_exists = os.path.exists(
        os.path.join(q.client["experiment/display/experiment_path"], "adapter_model")
    )

    status, _ = get_experiment_status(experiment.path)

    charts = load_charts(q.client["experiment/display/experiment_path"])
    q.client["experiment/display/charts"] = charts

    if experiment.mode == "train":
        if q.client["experiment/display/tab"] is None:
            q.client["experiment/display/tab"] = "experiment/display/charts"
    else:
        if q.client["experiment/display/tab"] is None:
            q.client["experiment/display/tab"] = "experiment/display/summary"

    if q.args["experiment/display/charts"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/charts"
    if q.args["experiment/display/summary"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/summary"
    if q.args["experiment/display/train_data_insights"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/train_data_insights"
    if q.args["experiment/display/validation_prediction_insights"] is not None:
        q.client["experiment/display/tab"] = (
            "experiment/display/validation_prediction_insights"
        )
    if q.args["experiment/display/config"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/config"
    if q.args["experiment/display/deployment"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/deployment"
    if q.args["experiment/display/logs"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/logs"
    if q.args["experiment/display/chat"] is not None:
        q.client["experiment/display/tab"] = "experiment/display/chat"

    await clean_dashboard(q, mode=q.client["experiment/display/tab"])

    tabs = [
        ui.tab(name="experiment/display/charts", label="Charts"),
        ui.tab(name="experiment/display/summary", label="Summary"),
    ]
    # html for legacy experiments
    has_train_data_insights = any(
        [
            charts.get(plot_encoding, dict()).get("train_data") is not None
            for plot_encoding in PLOT_ENCODINGS
        ]
    )
    if has_train_data_insights:
        tabs += [
            ui.tab(
                name="experiment/display/train_data_insights",
                label="Train Data Insights",
            )
        ]
    has_validation_prediction_insights = any(
        [
            charts.get(plot_encoding, dict()).get("validation_predictions") is not None
            for plot_encoding in PLOT_ENCODINGS
        ]
    )
    if has_validation_prediction_insights:
        tabs += [
            ui.tab(
                name="experiment/display/validation_prediction_insights",
                label="Validation Prediction Insights",
            )
        ]

    tabs += [
        ui.tab(name="experiment/display/logs", label="Logs"),
        ui.tab(name="experiment/display/config", label="Config"),
    ]

    if status == "finished" and checkpoints_exists:
        tabs += [ui.tab(name="experiment/display/chat", label="Chat")]

    q.page["experiment/display/tab"] = ui.tab_card(
        box="nav2", link=True, items=tabs, value=q.client["experiment/display/tab"]
    )
    q.client.delete_cards.add("experiment/display/tab")

    if q.client["experiment/display/tab"] == "experiment/display/charts":
        await charts_tab(q, [charts], [""])
    elif q.client["experiment/display/tab"] in [
        "experiment/display/train_data_insights",
        "experiment/display/validation_prediction_insights",
    ]:
        await insights_tab(charts, q)
    elif q.client["experiment/display/tab"] in ["experiment/display/summary"]:
        await summary_tab(experiment_id, q)
    elif q.client["experiment/display/tab"] in ["experiment/display/config"]:
        await configs_tab(q)
    elif q.client["experiment/display/tab"] in ["experiment/display/logs"]:
        await logs_tab(q)
    elif q.client["experiment/display/tab"] in ["experiment/display/chat"]:
        await chat_tab(q)

    await q.page.save()

    buttons = [
        ui.button(name="experiment/display/refresh", label="Refresh", primary=True)
    ]

    buttons += [
        ui.button(
            name="experiment/display/download_logs",
            label="Download logs/config",
            primary=False,
        )
    ]

    if status == "finished":
        buttons += [
            ui.button(
                name="experiment/display/download_predictions",
                label="Download predictions",
                primary=False,
                disabled=False,
                tooltip=None,
            )
        ]
        if checkpoints_exists:
            buttons += [
                ui.button(
                    name="experiment/display/download_model",
                    label="Download model",
                    primary=False,
                    disabled=False,
                    tooltip=None,
                ),
            ]
        if adapter_exists:
            buttons += [
                ui.button(
                    name="experiment/display/download_adapter",
                    label="Download adapter",
                    primary=False,
                    disabled=False,
                    tooltip=None,
                ),
            ]
        if checkpoints_exists:
            buttons += [
                ui.button(
                    name="experiment/display/push_to_huggingface",
                    label="Push checkpoint to huggingface",
                    primary=False,
                    disabled=False,
                    tooltip=None,
                ),
            ]

    buttons += [ui.button(name="experiment/list/current", label="Back", primary=False)]

    q.page["experiment/display/footer"] = ui.form_card(
        box="footer",
        items=[
            ui.inline(items=buttons, justify="start"),
        ],
    )
    q.client.delete_cards.add("experiment/display/footer")


async def insights_tab(charts, q):
    """Render train data or validation prediction insight cards."""
    if q.client["experiment/display/tab"] == "experiment/display/train_data_insights":
        key = "train_data"
    elif (
        q.client["experiment/display/tab"]
        == "experiment/display/validation_prediction_insights"
    ):
        key = "validation_predictions"

    for k1 in PLOT_ENCODINGS:
        if k1 not in charts:
            continue
        for k2, v2 in charts[k1].items():
            if k2 != key:
                continue
            if k1 == "html":
                q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.markup_card(
                    box="first", title="", content=v2
                )
                q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}")
                continue
            elif k1 == "image":
                q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.image_card(
                    box="first", title="", type="png", image=v2
                )
                q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}")
                continue
            elif k1 == "df":
                df = pd.read_parquet(v2)
                min_widths = {
                    col: "350" for col in df.columns if "text" in str(col).lower()
                }
                if key == "train_data":
                    min_widths["Content"] = "800"
                q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.form_card(
                    box="first",
                    items=[
                        ui_table_from_df(
                            q=q,
                            df=df,
                            name=f"experiment/display/charts/{k1}_{k2}",
                            sortables=[
                                col for col in df.columns if col.startswith("Metric")
                            ],
                            markdown_cells=[
                                col
                                for col in df.columns
                                if not col.startswith("Metric")
                            ],
                            searchables=list(df.columns),
                            downloadable=True,
                            resettable=True,
                            min_widths=min_widths,
                            height="calc(100vh - 267px)",
                            max_char_length=50_000,
                            cell_overflow="tooltip",
                        )
                    ],
                )
                q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}")
                continue


async def summary_tab(experiment_id, q):
    """Render the experiment summary cards (experiment, datasets, score, configs)."""
    experiment_df = get_experiments(q)
    input_dict = experiment_df[experiment_df.id == experiment_id].iloc[0].to_dict()
    cfg = load_config_yaml(
        os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml")
    )
    _ = get_tokenizer(cfg)

    # experiment card
    card_name = "experiment/display/summary/experiment"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Experiment"),
            ui.stats(
                [
                    ui.stat(
                        value=cfg.experiment_name,
                        label="Name",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=input_dict["config_file"],
                        label="Problem Type",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # datasets card
    card_name = "experiment/display/summary/datasets"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Datasets"),
            ui.stats(
                [
                    ui.stat(
                        value=Path(cfg.dataset.train_dataframe).stem,
                        label="Training Dataset",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=(
                            "-"
                            if cfg.dataset.validation_dataframe in ["", "None", None]
                            else Path(cfg.dataset.validation_dataframe).stem
                        ),
                        label="Validation Dataset",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # score card
    card_name = "experiment/display/summary/score"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="first"),
        items=[
            ui.separator("Score"),
            ui.stats(
                [
                    ui.stat(
                        value=input_dict["metric"],
                        label="Metric",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=(
                            "-"
                            if input_dict["val metric"] in ["", "None", None]
                            else str(input_dict["val metric"])
                        ),
                        label="Validation Score",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # main configs card
    card_name = "experiment/display/summary/main_configs"
    q.page[card_name] = ui.form_card(
        box=ui.box(zone="second"),
        items=[
            ui.separator("Main Configurations"),
            ui.stats(
                [
                    ui.stat(
                        value=cfg.llm_backbone,
                        label="LLM Backbone",
                    ),
                    ui.stat(
                        value=str(cfg.training.lora),
                        label="Lora",
                    ),
                    ui.stat(
                        value=str(cfg.training.epochs),
                        label="Epochs",
                    ),
                    ui.stat(
                        value=str(cfg.training.batch_size),
                        label="Batch Size",
                    ),
                ],
                justify="between",
                inset=True,
            ),
            ui.stats(
                [
                    ui.stat(
                        value=str(input_dict["loss"]),
                        label="Loss Function",
                    ),
                    ui.stat(
                        value=cfg.architecture.backbone_dtype,
                        label="Backbone Dtype",
                    ),
                    ui.stat(
                        value=str(cfg.architecture.gradient_checkpointing),
                        label="Gradient Checkpointing",
                    ),
                    ui.stat(
                        value=input_dict["gpu_list"],
                        label="GPU List",
                    ),
                ],
                justify="between",
                inset=True,
            ),
        ],
    )
    q.client.delete_cards.add(card_name)

    # code card
    card_name = "experiment/display/summary/code"
    content = get_experiment_summary_code_card(cfg=cfg)
    q.page[card_name] = ui.markdown_card(
        box=ui.box(zone="third"),
        title="",
        content=content,
    )
    q.client.delete_cards.add(card_name)


async def configs_tab(q):
    """Render the full experiment configuration as a stat list."""
    experiment_cfg = load_config_yaml(
        os.path.join(q.client["experiment/display/experiment_path"], "cfg.yaml")
    )
    items = get_cfg_list_items(experiment_cfg)
    q.page["experiment/display/config"] = ui.stat_list_card(
        box="first", items=items, title=""
    )
    q.client.delete_cards.add("experiment/display/config")


async def logs_tab(q):
    """Render the raw experiment log file as HTML."""
    logs_path = os.path.join(q.client["experiment/display/experiment_path"], "logs.log")

    text = ""
    in_pre = 0
    # Read log file only if it already exists
    if os.path.exists(logs_path):
        with open(logs_path, "r") as f:
            for line in f.readlines():
                if in_pre == 0:
                    text += "<div>"
                if "INFO: Lock" in line:
                    continue

                # maximum line length
                n = 250
                chunks = [line[i : i + n] for i in range(0, len(line), n)]
                text += "</div><div>".join(chunks)

                # Check for formatted HTML text
                if "<pre>" in line:
                    in_pre += 1
                if "</pre>" in line:
                    in_pre -= 1
                if in_pre == 0:
                    text += "</div>"

    items = [ui.text(text)]
    q.page["experiment/display/logs"] = ui.form_card(box="first", items=items, title="")
    q.client.delete_cards.add("experiment/display/logs")


def subsample(key1, key2, value, max_plot_points=1000):
    """Subsample a chart series to at most `max_plot_points` points via striding."""
    act_plot_points = len(value["steps"])
    if act_plot_points > max_plot_points:
        stride = int(np.ceil(act_plot_points / max_plot_points))
        value["steps"] = value["steps"][::stride]
        value["values"] = value["values"][::stride]
        logger.info(
            f"{key1} {key2} sampled from size {act_plot_points} to size "
            f"{len(value['steps'])} using stride {stride}."
        )
    return value
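

# Example (hypothetical numbers): a series with 2500 recorded steps and
# max_plot_points=1000 yields stride = ceil(2500 / 1000) = 3, keeping every
# third point (834 points) while preserving the shape of the curve.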


def unite_validation_metric_charts(charts_list):
    """Merge differing validation metric keys into a single shared "metric" key."""
    unique_metrics = []
    for chart in charts_list:
        unique_metrics.extend(list(chart.get("validation", {}).keys()))

    unique_metrics = set([key for key in unique_metrics if key != "loss"])
    if len(unique_metrics) > 1:
        for chart in charts_list:
            if "validation" in chart:
                for key in unique_metrics:
                    if key in chart["validation"]:
                        chart["validation"]["metric"] = chart["validation"][key]
                        del chart["validation"][key]
    return charts_list
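

# Example: when comparing one experiment logged with a "BLEU" validation
# metric and another logged with "Perplexity", both series are re-keyed to the
# shared name "metric" so they can be drawn in a single comparison chart.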


async def charts_tab(q, charts_list, legend_labels):
    """Render the charts dashboard for one or more experiments."""
    charts_list = unite_validation_metric_charts(charts_list)

    box = ["top_left", "top_right", "bottom_left", "bottom_right"]
    cnt = 0

    for k1 in ["meta", "train", "validation"]:
        if all([k1 not in charts for charts in charts_list]):
            continue

        all_second_keys: Set = set()
        for charts in charts_list:
            if k1 in charts:
                all_second_keys = all_second_keys.union(set(charts[k1].keys()))

        # Always plot loss in the lower left corner
        if "loss" in all_second_keys:
            all_second_keys.remove("loss")
            list_all_second_keys = ["loss"] + list(all_second_keys)
        else:
            list_all_second_keys = list(all_second_keys)

        for k2 in list_all_second_keys:
            logger.info(f"{k1} {k2}")

            items = []

            if k1 == "meta" and k2 == "lr":
                tooltip = "Current learning rate throughout the training process."
            elif k1 == "train" and k2 == "loss":
                tooltip = (
                    "Current training loss throughout the training process. "
                    "Loss is calculated as the average of the last ten batches."
                )
            elif k1 == "validation" and k2 == "loss":
                tooltip = (
                    "Current validation loss throughout the training process. "
                    "Loss is calculated as the average of all validation batches. "
                )
            elif k1 == "validation" and k2 != "loss":
                tooltip = (
                    "Current validation metric throughout the training process. "
                    "Metric is calculated on full validation set predictions."
                )
            else:
                continue

            title = f"{k1} {k2}".upper().replace("META LR", "LEARNING RATE")
            if k2 == "loss":
                title = title.replace("LOSS", "BATCH LOSS")

            items.append(ui.text(title, tooltip=tooltip))

            rows = []

            max_samples = q.client["chart_plot_max_points"]
            for charts, label in zip(charts_list, legend_labels):
                if k1 not in charts or k2 not in charts[k1]:
                    continue

                v2 = charts[k1][k2]
                v2 = subsample(k1, k2, v2, max_samples)

                if k2 == "lr" and "lr_diff" in charts["meta"]:
                    v3 = charts["meta"]["lr_diff"]
                    v3 = subsample("meta", "lr_diff", v3, max_samples)
                    rows.extend(
                        [
                            (v2["steps"][i], f"learning rate{label}", v2["values"][i])
                            for i in range(len(v2["values"]))
                        ]
                        + [
                            (
                                v3["steps"][i],
                                f"differential learning rate{label}",
                                v3["values"][i],
                            )
                            for i in range(len(v3["values"]))
                        ]
                    )
                    color = "=type"
                    fields = ["step", "type", "value"]
                elif len(charts_list) > 1:
                    rows.extend(
                        [
                            (v2["steps"][i], label.strip(), v2["values"][i])
                            for i in range(len(v2["values"]))
                        ]
                    )
                    color = "=type"
                    fields = ["step", "type", "value"]
                else:
                    rows.extend(
                        [
                            (v2["steps"][i], v2["values"][i])  # type: ignore
                            for i in range(len(v2["values"]))
                        ]
                    )
                    color = wave_theme.color
                    fields = ["step", "value"]

            d = data(fields=fields, rows=rows, pack=True)

            viz = ui.visualization(
                plot=ui.plot(
                    [
                        ui.mark(
                            type="line",
                            x_title="step",
                            x_scale="linear",
                            y_scale="linear",
                            x="=step",
                            y="=value",
                            color=color,
                            y_min=0 if k1 == "meta" and k2 == "lr" else None,
                            color_range=wave_theme.color_range,
                        )
                    ]
                ),
                data=d,  # type: ignore
                interactions=["brush"],
                height="max(calc((100vh - 275px)*0.41), 225px)",
                width="100%",
            )

            items.append(viz)

            if k1 == "validation" and k2 == "loss" and np.sum(v2["values"]) == 0:
                items.append(
                    ui.message_bar(
                        type="info",
                        text="Validation batch loss cannot be "
                        "calculated for this problem type.",
                    )
                )

            q.page[f"experiment/display/charts/{k1}_{k2}"] = ui.form_card(
                box=box[cnt], items=items
            )
            q.client.delete_cards.add(f"experiment/display/charts/{k1}_{k2}")
            cnt += 1


async def experiment_artifact_build_error_dialog(q: Q, error: str):
    """Show a dialog when building a downloadable artifact failed."""
    q.page["meta"].dialog = ui.dialog(
        "Failed to build artifact", items=[ui.text(error)], closable=True
    )
    q.client["keep_meta"] = True


async def experiment_download_artifact(
    q: Q,
    get_artifact_path_fn: Callable[[str, str], str],
    save_artifact_fn: Callable[[str, str], str],
    additional_log: Optional[str] = "",
    min_disk_space: Optional[float] = 0.0,
):
    """Download a specific artifact; if it does not exist, create it on demand.

    Args:
        q: Q
        get_artifact_path_fn: function that returns the path to the artifact
        save_artifact_fn: function that generates the artifact and returns its path
        additional_log: additional information to be logged
        min_disk_space: minimal available disk space needed to generate the artifact
    """
    experiment = q.client["experiment/display/experiment"]
    experiment_path = q.client["experiment/display/experiment_path"]

    zip_path = get_artifact_path_fn(experiment.name, experiment_path)

    if not os.path.exists(zip_path):
        try:
            check_available_space(experiment_path, min_disk_space)
        except LLMResourceException as e:
            error = f"Cannot create {os.path.basename(zip_path)}. {e}"
            await experiment_artifact_build_error_dialog(q, error)
            return

        logger.info(f"Creating {zip_path} on demand")
        zip_path = save_artifact_fn(experiment.name, experiment_path)

    if additional_log:
        logger.info(f"{additional_log}: {zip_path}")

    q.page["meta"].script = ui.inline_script(
        f'window.open("{get_download_link(q, zip_path)}", "_blank");'
    )
    await q.page.save()


async def experiment_download_predictions(q: Q):
    """Download experiment predictions."""
    await experiment_download_artifact(
        q, get_predictions_path, save_prediction_outputs, "Predictions path", None
    )


async def experiment_download_logs(q: Q):
    """Download experiment logs."""
    experiment = q.client["experiment/display/experiment"]
    experiment_path = q.client["experiment/display/experiment_path"]
    zip_path = get_logs_path(experiment.name, experiment_path)

    if not os.path.exists(zip_path):
        logs = q.client["experiment/display/charts"]
        logger.info(f"Creating {zip_path} on demand")
        zip_path = save_logs(experiment.name, experiment_path, logs)

    download_url = get_download_link(q, zip_path)
    logger.info(f"Logs URL: {download_url}")
    q.page["meta"].script = ui.inline_script(
        f'window.open("{download_url}", "_blank");'
    )
    await q.page.save()


async def config_import_uploaded_file(q: Q):
    """Import a config file uploaded via drag and drop to the filesystem."""
    file_url = q.args["experiment/upload_yaml"][0]
    file_name = file_url.split("/")[-1]
    path = f"{get_data_dir(q)}/{file_name}"

    local_path = await q.site.download(file_url, path)

    await q.site.unload(q.args["experiment/upload_yaml"][0])

    with open(local_path, "r") as f:
        yaml_data = yaml.safe_load(f)

    yaml_data = flatten_dict(yaml_data)

    q.client["experiment/yaml_data"] = yaml_data


async def show_message(q, msg_key, page, idx, msg_type):
    """Display a pending client message in the message bar of the given card item."""
    info = q.client[msg_key]
    if info:
        q.page[page].items[idx].message_bar.text = info
        q.page[page].items[idx].message_bar.type = msg_type
        q.client[msg_key] = ""


def get_experiment_list_message_bar(q):
    """Build the experiment list message bar, consuming any pending message."""
    if q.client["experiment_halt_reason"]:
        msg_bar = ui.message_bar(type="error", text=q.client["experiment_halt_reason"])
        del q.client["experiment_halt_reason"]
    elif q.client["force_disable_pipelines"]:
        msg_bar = ui.message_bar(type="info", text=q.client["force_disable_pipelines"])
        del q.client["force_disable_pipelines"]
    else:
        msg_bar = ui.message_bar(type="info", text="")

    return msg_bar


async def experiment_download_model(q: Q):
    """Zip the trained model weights and artifacts and serve them for download."""
    experiment = q.client["experiment/display/experiment"]
    experiment_path = q.client["experiment/display/experiment_path"]
    zip_path = get_model_path(experiment.name, experiment_path)

    if not os.path.exists(zip_path):
        logger.info(f"Creating {zip_path} on demand")
        cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml"))

        experiments = get_experiments(q)
        num_running_queued = len(
            experiments[experiments["status"].isin(["queued", "running"])]
        )
        if num_running_queued > 0 or (
            cfg.training.lora and cfg.architecture.backbone_dtype in ("int4", "int8")
        ):
            logger.info("Preparing model on CPU. This might slow down the process.")
            device = "cpu"
        else:
            device = q.client["gpu_used_for_download"]
            logger.info(
                f"Preparing model on {device}. In case of issues or OOM consider "
                "changing the default device for downloading in settings."
            )

        with set_env(HF_TOKEN=q.client["default_huggingface_api_token"]):
            cfg, model, tokenizer = load_cfg_model_tokenizer(
                experiment_path, merge=True, device=device
            )

        model = unwrap_model(model)
        checkpoint_path = cfg.output_directory

        model_save_time = time.time()
        model.backbone.save_pretrained(checkpoint_path)

        # See PreTrainedTokenizerBase.save_pretrained for documentation
        # Safeguard against None return if tokenizer class is
        # not inherited from PreTrainedTokenizerBase
        if cfg.problem_type in GENERATION_PROBLEM_TYPES:
            tokenizer.chat_template = get_chat_template(cfg)
        tokenizer_files = list(tokenizer.save_pretrained(checkpoint_path) or [])

        card = get_model_card(cfg, model, repo_id="<path_to_local_folder>")
        card.save(os.path.join(experiment_path, "model_card.md"))

        logger.info(f"Creating Zip File at {zip_path}")
        zf = zipfile.ZipFile(zip_path, "w")

        FILES_TO_PUSH = [
            "vocab.json",
            "sentencepiece.bpe.model",
            "bpe_encoder.bin",
            "tokenizer_config.json",
            "tokenizer.json",
            "special_tokens_map.json",
            "merges.txt",
            "generation_config.json",
            "config.json",
            "added_tokens.json",
            "model_card.md",
            "classification_head.pth",
            "regression_head.pth",
        ]
        FILES_TO_PUSH = set(
            FILES_TO_PUSH
            + [os.path.split(tokenizer_file)[-1] for tokenizer_file in tokenizer_files]
        )

        # Add tokenizer and config.json files, as well as potential classification head
        paths_added = []
        for file in FILES_TO_PUSH:
            path = os.path.join(experiment_path, file)
            if os.path.isfile(path):
                paths_added.append(path)
                add_file_to_zip(zf=zf, path=path)

        # Add model weight files. save_pretrained() does not return the saved files
        weight_paths = glob.glob(os.path.join(checkpoint_path, "pytorch_model*.*"))
        for path in weight_paths:
            paths_added.append(path)
            add_file_to_zip(zf=zf, path=path)

        # Add all files that were created after the model was saved.
        # This is useful for potential changes/different
        # naming conventions across different backbones.
        # Also adds newly generated safetensor files.
        for file in os.listdir(checkpoint_path):
            file_path = os.path.join(checkpoint_path, file)
            if (
                os.path.getmtime(file_path) > model_save_time
                and file_path not in paths_added
                and file_path != zip_path
            ):
                add_file_to_zip(zf=zf, path=file_path)
                paths_added.append(file_path)
                logger.info(
                    f"Added {file_path} to zip file as it "
                    "was created when saving the model state."
                )

        # Add all files from subdirectories, which include the intermediate checkpoints
        subdirectories = [
            d
            for d in os.listdir(checkpoint_path)
            if os.path.isdir(os.path.join(checkpoint_path, d))
        ]
        for subdirectory in subdirectories:
            for file in os.listdir(os.path.join(checkpoint_path, subdirectory)):
                file_path = os.path.join(checkpoint_path, subdirectory, file)
                add_file_to_zip(zf=zf, path=file_path, folder=subdirectory)
                paths_added.append(file_path)
                logger.info(f"Added {file_path} to zip file.")
        zf.close()

    download_url = get_download_link(q, zip_path)
    logger.info(f"Model download URL: {download_url}")
    q.page["meta"].script = ui.inline_script(
        f'window.open("{download_url}", "_blank");'
    )
    await q.page.save()
async def experiment_download_adapter(q: Q): | |
experiment = q.client["experiment/display/experiment"] | |
experiment_path = q.client["experiment/display/experiment_path"] | |
zip_path = get_adapter_model_path(experiment.name, experiment_path) | |
if not os.path.exists(zip_path): | |
logger.info(f"Creating {zip_path} on demand") | |
logger.info(f"Creating Zip File at {zip_path}") | |
zf = zipfile.ZipFile(zip_path, "w") | |
FILES_TO_PUSH = [ | |
"adapter_model/adapter_config.json", | |
"adapter_model/adapter_model.safetensors", | |
"adapter_model/README.md", | |
] | |
paths_added = [] | |
for file in FILES_TO_PUSH: | |
path = os.path.join(experiment_path, file) | |
if os.path.isfile(path): | |
paths_added.append(path) | |
add_file_to_zip(zf=zf, path=path) | |
zf.close() | |
download_url = get_download_link(q, zip_path) | |
logger.info(f"Logs URL: {download_url}") | |
q.page["meta"].script = ui.inline_script( | |
f'window.open("{download_url}", "_blank");' | |
) | |
await q.page.save() | |
async def experiment_push_to_huggingface_dialog(q: Q, error: str = ""): | |
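"""Show the push-to-Hugging-Face dialog, or run the export on submit.""" | |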
if q.args["experiment/display/push_to_huggingface"] or error: | |
devices = ["cpu", "cpu_shard"] + [ | |
f"cuda:{idx}" for idx in range(torch.cuda.device_count()) | |
] | |
default_device = "cuda:0" | |
experiments = get_experiments(q) | |
num_running_queued = len( | |
experiments[experiments["status"].isin(["queued", "running"])] | |
) | |
experiment_path = q.client["experiment/display/experiment_path"] | |
cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) | |
if num_running_queued > 0 or cfg.environment.use_deepspeed: | |
default_device = "cpu" | |
try: | |
huggingface_hub.login(q.client["default_huggingface_api_token"]) | |
user_id = huggingface_hub.whoami()["name"] | |
except Exception: | |
user_id = "" | |
dialog_items = [ | |
ui.message_bar("error", error, visible=True if error else False), | |
ui.textbox( | |
name="experiment/display/push_to_huggingface/account_name", | |
label="Account Name", | |
value=user_id, | |
width="500px", | |
required=False, | |
tooltip=( | |
"The account name on HF to push the model to. " | |
"Leaving it empty will push it to the default user account." | |
), | |
), | |
ui.textbox( | |
name="experiment/display/push_to_huggingface/model_name", | |
label="Model Name", | |
value=hf_repo_friendly_name( | |
q.client["experiment/display/experiment"].name | |
), | |
width="500px", | |
required=True, | |
tooltip="The name of the model as shown on HF.", | |
), | |
ui.dropdown( | |
name="experiment/display/push_to_huggingface/device", | |
label="Device for preparing the model", | |
required=True, | |
value=default_device, | |
width="500px", | |
choices=[ui.choice(str(d), str(d)) for d in devices], | |
tooltip=( | |
"The local device to prepare the model before pushing it to HF. " | |
"CPU will never load the weights to the GPU, which can be useful " | |
"for large models, but will be significantly slower. " | |
"Cpu_shard will first load on CPU and then shard on all GPUs " | |
"before pushing to HF." | |
), | |
), | |
ui.textbox( | |
name="experiment/display/push_to_huggingface/api_key", | |
label="Hugging Face API Key", | |
value=q.client["default_huggingface_api_token"], | |
width="500px", | |
password=True, | |
required=True, | |
tooltip="HF API key, needs write access.", | |
), | |
ui.toggle( | |
name="default_safe_serialization", | |
label="Use Hugging Face safetensors for safe serialization", | |
value=q.client["default_safe_serialization"], | |
), | |
ui.buttons( | |
[ | |
ui.button( | |
name="experiment/display/push_to_huggingface_submit", | |
label="Export", | |
primary=True, | |
), | |
ui.button(name="cancel", label="Cancel", primary=False), | |
] | |
), | |
] | |
elif q.args["experiment/display/push_to_huggingface_submit"]: | |
await busy_dialog( | |
q=q, | |
title="Exporting to HuggingFace", | |
text="Model size can affect the export time significantly.", | |
) | |
experiment_path = q.client["experiment/display/experiment_path"] | |
device = q.client["experiment/display/push_to_huggingface/device"] | |
api_key = q.client["experiment/display/push_to_huggingface/api_key"] | |
user_id = q.client["experiment/display/push_to_huggingface/account_name"] | |
safe_serialization = q.client["default_safe_serialization"] | |
model_name = q.client[ | |
"experiment/display/push_to_huggingface/model_name" | |
].replace(".", "-") | |
publish_model_to_hugging_face( | |
path_to_experiment=experiment_path, | |
device=device, | |
api_key=api_key, | |
user_id=user_id, | |
model_name=model_name, | |
safe_serialization=safe_serialization, | |
hf_transfer=q.client["default_hf_hub_enable_hf_transfer"], | |
) | |
dialog_items = [ | |
ui.message_bar("success", "Success"), | |
ui.buttons( | |
[ | |
ui.button(name="ok", label="OK", primary=True), | |
] | |
), | |
] | |
dialog = ui.dialog( | |
title="Push to HuggingFace Hub", | |
items=dialog_items, | |
closable=True, | |
name="push_to_huggingface_dialog", | |
) | |
q.page["meta"].dialog = dialog | |
q.client["keep_meta"] = True | |
def get_experiment_summary_code_card(cfg) -> str: | |
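"""Fill the summary code card template with experiment cfg values.""" | |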
repo_id: Optional[str] = None | |
hf_yaml_path = f"{cfg.output_directory}/hf.yaml" | |
with open( | |
os.path.join("model_cards", cfg.environment._summary_card_template), "r" | |
) as f: | |
text = f.read() | |
if os.path.exists(hf_yaml_path): | |
with open(hf_yaml_path, "r") as fp: | |
repo_id = yaml.load(fp, Loader=yaml.FullLoader)["repo_id"] | |
if repo_id is None: | |
repo_id = "account/model" | |
# Model repo | |
text = text.replace("{{repo_id}}", repo_id) | |
# Versions | |
text = text.replace("{{transformers_version}}", transformers.__version__) | |
text = text.replace("{{einops_version}}", einops.__version__) | |
text = text.replace("{{accelerate_version}}", accelerate.__version__) | |
text = text.replace("{{torch_version}}", torch.__version__) | |
# Configs | |
text = text.replace("{{text_prompt_start}}", str(cfg.dataset.text_prompt_start)) | |
text = text.replace( | |
"{{text_answer_separator}}", str(cfg.dataset.text_answer_separator) | |
) | |
text = text.replace( | |
"{{end_of_sentence}}", | |
( | |
str(cfg.tokenizer._tokenizer_eos_token) | |
if cfg.dataset.add_eos_token_to_prompt | |
else "" | |
), | |
) | |
text = text.replace("{{trust_remote_code}}", str(cfg.environment.trust_remote_code)) | |
if cfg.problem_type not in NON_GENERATION_PROBLEM_TYPES: | |
text = text.replace( | |
"{{min_new_tokens}}", str(cfg.prediction.min_length_inference) | |
) | |
text = text.replace( | |
"{{max_new_tokens}}", str(cfg.prediction.max_length_inference) | |
) | |
text = text.replace("{{do_sample}}", str(cfg.prediction.do_sample)) | |
text = text.replace("{{num_beams}}", str(cfg.prediction.num_beams)) | |
text = text.replace("{{temperature}}", str(cfg.prediction.temperature)) | |
text = text.replace( | |
"{{repetition_penalty}}", str(cfg.prediction.repetition_penalty) | |
) | |
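# The string "None" is the config sentinel for "no system column". | |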
if cfg.dataset.system_column != "None": | |
text = text.replace( | |
"{{sample_messages}}", | |
"""[ | |
{ | |
"role": "system", | |
"content": "You are a friendly and polite chatbot.", | |
}, | |
{"role": "user", "content": "Hi, how are you?"}, | |
{"role": "assistant", "content": "I'm doing great, how about you?"}, | |
{"role": "user", "content": "Why is drinking water so healthy?"}, | |
]""", | |
) | |
else: | |
text = text.replace( | |
"{{sample_messages}}", | |
"""[ | |
{"role": "user", "content": "Hi, how are you?"}, | |
{"role": "assistant", "content": "I'm doing great, how about you?"}, | |
{"role": "user", "content": "Why is drinking water so healthy?"}, | |
]""", | |
) | |
return text | |
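# Example usage (illustrative sketch only; this is not called from the app, | |
# and the experiment path and output filename below are assumptions): | |
# | |
#   experiment_path = "output/user/my-experiment"  # hypothetical | |
#   cfg = load_config_yaml(os.path.join(experiment_path, "cfg.yaml")) | |
#   card_text = get_experiment_summary_code_card(cfg) | |
#   with open(os.path.join(cfg.output_directory, "summary_card.md"), "w") as f: | |
#       f.write(card_text) | |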