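"""App and client initialization for the H2O LLM Studio Wave app, including
creation of the default datasets (oasst, dpo, imdb, helpsteer)."""
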
import logging
import os
import shutil
from tempfile import NamedTemporaryFile

from bokeh.resources import Resources as BokehResources
from h2o_wave import Q, ui

from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Database, Dataset
from llm_studio.app_utils.default_datasets import (
    prepare_default_dataset_causal_language_modeling,
    prepare_default_dataset_classification_modeling,
    prepare_default_dataset_dpo_modeling,
    prepare_default_dataset_regression_modeling,
)
from llm_studio.app_utils.sections.common import interface
from llm_studio.app_utils.setting_utils import load_user_settings_and_secrets
from llm_studio.app_utils.utils import (
    get_data_dir,
    get_database_dir,
    get_download_dir,
    get_output_dir,
    get_user_db_path,
    get_user_name,
)
from llm_studio.src.utils.config_utils import load_config_py, save_config_yaml

logger = logging.getLogger(__name__)


async def import_default_data(q: Q):
    """Imports the default datasets on first start."""

    try:
        # only create the default datasets once: dataset id 1 is unset on first start
        if q.client.app_db.get_dataset(1) is None:
            logger.info("Downloading default dataset...")
            q.page["meta"].dialog = ui.dialog(
                title="Creating default datasets",
                blocking=True,
                items=[ui.progress(label="Please be patient...")],
            )
            await q.page.save()

            dataset = prepare_oasst(q)
            q.client.app_db.add_dataset(dataset)
            dataset = prepare_dpo(q)
            q.client.app_db.add_dataset(dataset)
            dataset = prepare_imdb(q)
            q.client.app_db.add_dataset(dataset)
            dataset = prepare_helpsteer(q)
            q.client.app_db.add_dataset(dataset)
    except Exception as e:
        # roll back any partially added datasets so the DB session stays usable
        q.client.app_db._session.rollback()
        logger.warning(f"Could not download default dataset: {e}")


def prepare_oasst(q: Q) -> Dataset:
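    """Downloads the default OASST causal language modeling dataset, writes a
    default experiment config next to it, and returns the Dataset DB record."""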
    path = f"{get_data_dir(q)}/oasst"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    df = prepare_default_dataset_causal_language_modeling(path)

    cfg = load_config_py(
        config_path=os.path.join("llm_studio/python_configs", default_cfg.cfg_file),
        config_name="ConfigProblemBase",
    )
    cfg.dataset.train_dataframe = os.path.join(path, "train_full.pq")
    cfg.dataset.prompt_column = ("instruction",)
    cfg.dataset.answer_column = "output"
    cfg.dataset.parent_id_column = "None"

    cfg_path = os.path.join(path, f"{default_cfg.cfg_file}.yaml")
    save_config_yaml(cfg_path, cfg)

    dataset = Dataset(
        id=1,
        name="oasst",
        path=path,
        config_file=cfg_path,
        train_rows=df.shape[0],
    )
    return dataset


def prepare_dpo(q: Q) -> Dataset:
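    """Prepares the default DPO modeling dataset with chosen/rejected answer
    columns and returns the Dataset DB record."""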
    path = f"{get_data_dir(q)}/dpo"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_dpo_modeling()
    train_df.to_parquet(os.path.join(path, "train.pq"), index=False)

    from llm_studio.python_configs.text_dpo_modeling_config import ConfigDPODataset
    from llm_studio.python_configs.text_dpo_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseDPO,
    )

    cfg: ConfigProblemBaseDPO = ConfigProblemBaseDPO(
        dataset=ConfigDPODataset(
            train_dataframe=os.path.join(path, "train.pq"),
            system_column="system",
            prompt_column=("question",),
            answer_column="chosen",
            rejected_answer_column="rejected",
        ),
    )
    cfg_path = os.path.join(path, "text_dpo_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    dataset = Dataset(
        id=2,
        name="dpo",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
    return dataset


def prepare_imdb(q: Q) -> Dataset:
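    """Prepares the default IMDB causal classification dataset (text -> label)
    and returns the Dataset DB record."""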
    path = f"{get_data_dir(q)}/imdb"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_classification_modeling()
    train_df.to_parquet(os.path.join(path, "train.pq"), index=False)

    from llm_studio.python_configs.text_causal_classification_modeling_config import (
        ConfigNLPCausalClassificationDataset,
    )
    from llm_studio.python_configs.text_causal_classification_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseClassification,
    )

    cfg: ConfigProblemBaseClassification = ConfigProblemBaseClassification(
        dataset=ConfigNLPCausalClassificationDataset(
            train_dataframe=os.path.join(path, "train.pq"),
            prompt_column=("text",),
            answer_column=("label",),
        ),
    )
    cfg_path = os.path.join(path, "text_causal_classification_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    dataset = Dataset(
        id=3,
        name="imdb",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
    return dataset


def prepare_helpsteer(q: Q) -> Dataset:
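    """Prepares the default HelpSteer regression dataset with five numeric
    target columns and returns the Dataset DB record."""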
    path = f"{get_data_dir(q)}/helpsteer"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_regression_modeling()
    train_df.to_parquet(os.path.join(path, "train.pq"), index=False)

    from llm_studio.python_configs.text_causal_regression_modeling_config import (
        ConfigNLPCausalRegressionDataset,
    )
    from llm_studio.python_configs.text_causal_regression_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseRegression,
    )

    cfg: ConfigProblemBaseRegression = ConfigProblemBaseRegression(
        dataset=ConfigNLPCausalRegressionDataset(
            train_dataframe=os.path.join(path, "train.pq"),
            prompt_column=("prompt", "response"),
            answer_column=(
                "helpfulness",
                "correctness",
                "coherence",
                "complexity",
                "verbosity",
            ),
        ),
    )
    cfg_path = os.path.join(path, "text_causal_regression_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    dataset = Dataset(
        id=4,
        name="helpsteer",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
    return dataset


async def initialize_client(q: Q) -> None:
    """Initialize the client."""

    if not q.client.client_initialized:
        logger.info("Initializing client ...")

        q.client.delete_cards = set()
        q.client.delete_cards.add("init_app")

        os.makedirs(get_data_dir(q), exist_ok=True)
        os.makedirs(get_database_dir(q), exist_ok=True)
        os.makedirs(get_output_dir(q), exist_ok=True)
        os.makedirs(get_download_dir(q), exist_ok=True)

        db_path = get_user_db_path(q)
        q.client.app_db = Database(db_path)

        logger.info(f"User name: {get_user_name(q)}")

        q.client.client_initialized = True
        q.client["mode_curr"] = "full"

        load_user_settings_and_secrets(q)
        await interface(q)
        await import_default_data(q)

        q.args.__wave_submission_name__ = default_cfg.start_page
        logger.info("Initializing client ... done")


async def initialize_app(q: Q) -> None:
    """
    Initialize the app.

    This function is called once when the app is started and stores values in q.app.
    """

    if not q.app.initialized:
        logger.info("Initializing app ...")

        icons_pth = "llm_studio/app_utils/static"
        (q.app["icon_path"],) = await q.site.upload([f"{icons_pth}/icon_300.svg"])

        script_sources = []
        with NamedTemporaryFile(mode="w", suffix=".min.js") as f:
            # write all Bokeh scripts to one file to make sure
            # they are loaded sequentially
            for js_raw in BokehResources(mode="inline").js_raw:
                f.write(js_raw)
                f.write("\n")
            # flush buffered writes so the uploaded file is complete
            f.flush()
            (url,) = await q.site.upload([f.name])
            script_sources.append(url)
        q.app["script_sources"] = script_sources

        q.app["initialized"] = True
        q.app.version = default_cfg.version
        q.app.name = default_cfg.name
        q.app.heap_mode = default_cfg.heap_mode

        logger.info("Initializing app ... done")