# llm_studio/app_utils/initializers.py
import logging
import os
import shutil
from tempfile import NamedTemporaryFile
from bokeh.resources import Resources as BokehResources
from h2o_wave import Q, ui
from llm_studio.app_utils.config import default_cfg
from llm_studio.app_utils.db import Database, Dataset
from llm_studio.app_utils.default_datasets import (
prepare_default_dataset_causal_language_modeling,
prepare_default_dataset_classification_modeling,
prepare_default_dataset_dpo_modeling,
prepare_default_dataset_regression_modeling,
)
from llm_studio.app_utils.sections.common import interface
from llm_studio.app_utils.setting_utils import load_user_settings_and_secrets
from llm_studio.app_utils.utils import (
get_data_dir,
get_database_dir,
get_download_dir,
get_output_dir,
get_user_db_path,
get_user_name,
)
from llm_studio.src.utils.config_utils import load_config_py, save_config_yaml
logger = logging.getLogger(__name__)
async def import_default_data(q: Q):
    """Import the default demo datasets into the user's app database.

    Runs only once: skipped when a dataset with id 1 already exists.
    Shows a blocking progress dialog while the datasets are downloaded and
    registered. Failures are deliberately non-fatal (best effort): the DB
    session is rolled back and a warning is logged, keeping the app usable.

    Args:
        q: Wave query/session context holding the per-client app database.
    """
    try:
        if q.client.app_db.get_dataset(1) is None:
            logger.info("Downloading default dataset...")
            q.page["meta"].dialog = ui.dialog(
                title="Creating default datasets",
                blocking=True,
                items=[ui.progress(label="Please be patient...")],
            )
            await q.page.save()
            # Each prepare_* helper materializes one demo dataset on disk
            # and returns a Dataset row to register in the app database.
            for prepare in (
                prepare_oasst,
                prepare_dpo,
                prepare_imdb,
                prepare_helpsteer,
            ):
                q.client.app_db.add_dataset(prepare(q))
    except Exception as e:
        # Best effort: roll back any partial inserts and continue startup.
        q.client.app_db._session.rollback()
        logger.warning(f"Could not download default dataset: {e}")
def prepare_oasst(q: Q) -> Dataset:
    """Create the default OASST causal-LM demo dataset and return its DB record."""
    path = f"{get_data_dir(q)}/oasst"

    # Rebuild from a clean directory so repeated runs don't mix stale files.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    df = prepare_default_dataset_causal_language_modeling(path)

    # Load the default problem config and point it at the downloaded data.
    cfg = load_config_py(
        config_path=os.path.join("llm_studio/python_configs", default_cfg.cfg_file),
        config_name="ConfigProblemBase",
    )
    cfg.dataset.train_dataframe = os.path.join(path, "train_full.pq")
    cfg.dataset.prompt_column = ("instruction",)
    cfg.dataset.answer_column = "output"
    cfg.dataset.parent_id_column = "None"

    cfg_path = os.path.join(path, f"{default_cfg.cfg_file}.yaml")
    save_config_yaml(cfg_path, cfg)

    return Dataset(
        id=1,
        name="oasst",
        path=path,
        config_file=cfg_path,
        train_rows=df.shape[0],
    )
def prepare_dpo(q: Q) -> Dataset:
    """Create the default DPO demo dataset and return its DB record."""
    path = f"{get_data_dir(q)}/dpo"

    # Rebuild from a clean directory so repeated runs don't mix stale files.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_dpo_modeling()
    train_frame_path = os.path.join(path, "train.pq")
    train_df.to_parquet(train_frame_path, index=False)

    # Imported locally so the DPO config module is only loaded when needed.
    from llm_studio.python_configs.text_dpo_modeling_config import ConfigDPODataset
    from llm_studio.python_configs.text_dpo_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseDPO,
    )

    cfg: ConfigProblemBaseDPO = ConfigProblemBaseDPO(
        dataset=ConfigDPODataset(
            train_dataframe=train_frame_path,
            system_column="system",
            prompt_column=("question",),
            answer_column="chosen",
            rejected_answer_column="rejected",
        ),
    )

    cfg_path = os.path.join(path, "text_dpo_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    return Dataset(
        id=2,
        name="dpo",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
def prepare_imdb(q: Q) -> Dataset:
    """Create the default IMDB classification demo dataset and return its DB record."""
    path = f"{get_data_dir(q)}/imdb"

    # Rebuild from a clean directory so repeated runs don't mix stale files.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_classification_modeling()
    train_frame_path = os.path.join(path, "train.pq")
    train_df.to_parquet(train_frame_path, index=False)

    # Imported locally so the classification config module is only loaded when needed.
    from llm_studio.python_configs.text_causal_classification_modeling_config import (
        ConfigNLPCausalClassificationDataset,
    )
    from llm_studio.python_configs.text_causal_classification_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseClassification,
    )

    cfg: ConfigProblemBaseClassification = ConfigProblemBaseClassification(
        dataset=ConfigNLPCausalClassificationDataset(
            train_dataframe=train_frame_path,
            prompt_column=("text",),
            answer_column=("label",),
        ),
    )

    cfg_path = os.path.join(path, "text_causal_classification_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    return Dataset(
        id=3,
        name="imdb",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
def prepare_helpsteer(q: Q) -> Dataset:
    """Create the default HelpSteer regression demo dataset and return its DB record."""
    path = f"{get_data_dir(q)}/helpsteer"

    # Rebuild from a clean directory so repeated runs don't mix stale files.
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)

    train_df = prepare_default_dataset_regression_modeling()
    train_frame_path = os.path.join(path, "train.pq")
    train_df.to_parquet(train_frame_path, index=False)

    # Imported locally so the regression config module is only loaded when needed.
    from llm_studio.python_configs.text_causal_regression_modeling_config import (
        ConfigNLPCausalRegressionDataset,
    )
    from llm_studio.python_configs.text_causal_regression_modeling_config import (
        ConfigProblemBase as ConfigProblemBaseRegression,
    )

    # Multi-target regression: one prompt built from two columns, five
    # numeric answer columns.
    cfg: ConfigProblemBaseRegression = ConfigProblemBaseRegression(
        dataset=ConfigNLPCausalRegressionDataset(
            train_dataframe=train_frame_path,
            prompt_column=("prompt", "response"),
            answer_column=(
                "helpfulness",
                "correctness",
                "coherence",
                "complexity",
                "verbosity",
            ),
        ),
    )

    cfg_path = os.path.join(path, "text_causal_regression_modeling_config.yaml")
    save_config_yaml(cfg_path, cfg)

    return Dataset(
        id=4,
        name="helpsteer",
        path=path,
        config_file=cfg_path,
        train_rows=train_df.shape[0],
    )
async def initialize_client(q: Q) -> None:
    """Initialize the client."""
    if q.client.client_initialized:
        return

    logger.info("Initializing client ...")

    q.client.delete_cards = set()
    q.client.delete_cards.add("init_app")

    # Ensure all per-user working directories exist before anything uses them.
    for directory in (
        get_data_dir(q),
        get_database_dir(q),
        get_output_dir(q),
        get_download_dir(q),
    ):
        os.makedirs(directory, exist_ok=True)

    q.client.app_db = Database(get_user_db_path(q))

    logger.info(f"User name: {get_user_name(q)}")

    q.client.client_initialized = True
    q.client["mode_curr"] = "full"

    load_user_settings_and_secrets(q)

    await interface(q)
    await import_default_data(q)

    q.args.__wave_submission_name__ = default_cfg.start_page
    logger.info("Initializing client ... done")
    return
async def initialize_app(q: Q) -> None:
    """
    Initialize the app.

    This function is called once when the app is started and stores values in q.app.
    Uploads the app icon and a single concatenated Bokeh script bundle to the
    Wave site, then records version/name/heap-mode from the default config.
    """
    if not q.app.initialized:
        logger.info("Initializing app ...")
        icons_pth = "llm_studio/app_utils/static/"
        (q.app["icon_path"],) = await q.site.upload([f"{icons_pth}/icon_300.svg"])

        script_sources = []

        with NamedTemporaryFile(mode="w", suffix=".min.js") as f:
            # write all Bokeh scripts to one file to make sure
            # they are loaded sequentially
            for js_raw in BokehResources(mode="inline").js_raw:
                f.write(js_raw)
                f.write("\n")
            # Flush Python's write buffer so the uploader (which reads the
            # file by name) sees the complete content, not a truncated file.
            f.flush()
            (url,) = await q.site.upload([f.name])
            script_sources.append(url)

        q.app["script_sources"] = script_sources
        q.app["initialized"] = True
        q.app.version = default_cfg.version
        q.app.name = default_cfg.name
        q.app.heap_mode = default_cfg.heap_mode
        logger.info("Initializing app ... done")