# Spaces: Runtime error
import logging
import os
import threading
import time
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import gradio as gr
import transformers
from transformers.trainer import TRAINING_ARGS_NAME

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar


class Runner:
    """Launch training / evaluation jobs for the web UI and stream their progress.

    A job is executed by ``run_exp`` on a background thread. The public
    ``run_*`` / ``preview_*`` methods are generators that yield
    ``(text, progress-bar update)`` pairs for the UI to display.
    """

    # Seconds to wait between two polls of the worker thread.
    _POLL_INTERVAL = 2

    def __init__(self):
        self.aborted = False
        self.running = False
        # One handler is attached to both the root logger and the transformers
        # logger; its ``log`` attribute is what gets streamed while a job runs.
        self.logger_handler = LoggerHandler()
        self.logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(self.logger_handler)
        transformers.logging.add_handler(self.logger_handler)

    def set_abort(self):
        """Flag the current job as aborted and mark the runner as idle."""
        self.aborted = True
        # NOTE(review): ``running`` is cleared immediately, although the worker
        # thread may still be alive at this point -- confirm this is intended.
        self.running = False

    def _initialize(
        self, lang: str, model_name: str, dataset: List[str]
    ) -> str:
        """Validate the request and reset per-job state.

        Returns:
            The alert text for ``lang`` describing the first failed check, or
            an empty string when the job may proceed.
        """
        if self.running:
            return ALERTS["err_conflict"][lang]

        if not model_name:
            return ALERTS["err_no_model"][lang]

        if not get_model_path(model_name):
            return ALERTS["err_no_path"][lang]

        if not dataset:
            return ALERTS["err_no_dataset"][lang]

        self.aborted = False
        self.logger_handler.reset()
        # The callback is handed this runner instance.
        self.trainer_callback = LogCallback(self)
        return ""

    def _finalize(
        self, lang: str, finish_info: str
    ) -> str:
        """Mark the runner idle, release memory via ``torch_gc`` and pick the final message.

        Returns:
            The "aborted" alert for ``lang`` if the job was aborted, otherwise
            ``finish_info`` unchanged.
        """
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        return finish_info

    @staticmethod
    def _join_checkpoint_dirs(
        model_name: str, finetuning_type: str, checkpoints: List[str]
    ) -> Optional[str]:
        """Return the comma-joined save dirs of ``checkpoints``, or ``None`` if there are none."""
        if not checkpoints:
            return None
        return ",".join(
            get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints
        )

    @staticmethod
    def _to_quantization_bit(quantization_bit: str) -> Optional[int]:
        """Convert the UI value to an int; anything other than "8" / "4" maps to ``None``."""
        return int(quantization_bit) if quantization_bit in ("8", "4") else None

    def _parse_train_args(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str,
        training_stage: str,
        dataset_dir: str,
        dataset: List[str],
        cutoff_len: int,
        learning_rate: str,
        num_train_epochs: str,
        max_samples: str,
        compute_type: str,
        batch_size: int,
        gradient_accumulation_steps: int,
        lr_scheduler_type: str,
        max_grad_norm: str,
        val_size: float,
        logging_steps: int,
        save_steps: int,
        warmup_steps: int,
        flash_attn: bool,
        rope_scaling: bool,
        lora_rank: int,
        lora_dropout: float,
        lora_target: str,
        resume_lora_training: bool,
        dpo_beta: float,
        reward_model: str,
        output_dir: str
    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
        """Translate the training form values into the argument dict for ``run_exp``.

        Returns:
            ``(lang, model_name, dataset, output_dir, args)`` where
            ``output_dir`` has been resolved through ``get_save_dir``.

        Raises:
            ValueError: if one of the numeric text fields (``learning_rate``,
                ``num_train_epochs``, ``max_samples``, ``max_grad_norm``)
                cannot be converted.
            KeyError: if ``training_stage`` is not a key of ``TRAINING_STAGES``.
        """
        checkpoint_dir = self._join_checkpoint_dirs(model_name, finetuning_type, checkpoints)
        output_dir = get_save_dir(model_name, finetuning_type, output_dir)

        user_config = load_config()
        cache_dir = user_config.get("cache_dir", None)

        stage = TRAINING_STAGES[training_stage]

        # Key order is kept stable because the dict is also rendered by ``gen_cmd``.
        args = dict(
            stage=stage,
            model_name_or_path=get_model_path(model_name),
            do_train=True,
            overwrite_cache=False,
            cache_dir=cache_dir,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=self._to_quantization_bit(quantization_bit),
            template=template,
            system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            cutoff_len=cutoff_len,
            learning_rate=float(learning_rate),
            num_train_epochs=float(num_train_epochs),
            max_samples=int(max_samples),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            lr_scheduler_type=lr_scheduler_type,
            max_grad_norm=float(max_grad_norm),
            logging_steps=logging_steps,
            save_steps=save_steps,
            warmup_steps=warmup_steps,
            flash_attn=flash_attn,
            rope_scaling="linear" if rope_scaling else None,
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            # Fall back to the default module registered for the model family
            # (the part of the name before the first "-").
            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
            # Resuming LoRA training is always disabled for the rm / ppo / dpo stages.
            resume_lora_training=(
                False if stage in ["rm", "ppo", "dpo"] else resume_lora_training
            ),
            output_dir=output_dir
        )
        # ``compute_type`` is itself the name of the flag to switch on.
        args[compute_type] = True

        if stage == "ppo":
            args["reward_model"] = reward_model
            val_size = 0  # no validation split for ppo

        if stage == "dpo":
            args["dpo_beta"] = dpo_beta

        if val_size > 1e-6:
            args["val_size"] = val_size
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = save_steps
            args["load_best_model_at_end"] = True

        return lang, model_name, dataset, output_dir, args

    def _parse_eval_args(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str,
        dataset_dir: str,
        dataset: List[str],
        cutoff_len: int,
        max_samples: str,
        batch_size: int,
        predict: bool,
        max_new_tokens: int,
        top_p: float,
        temperature: float
    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
        """Translate the evaluation form values into the argument dict for ``run_exp``.

        The output directory is named ``eval_<ckpt1>_<ckpt2>...`` when
        checkpoints are selected and ``eval_base`` otherwise.

        Returns:
            ``(lang, model_name, dataset, output_dir, args)``.

        Raises:
            ValueError: if ``max_samples`` cannot be converted to ``int``.
        """
        checkpoint_dir = self._join_checkpoint_dirs(model_name, finetuning_type, checkpoints)
        if checkpoints:
            eval_name = "eval_" + "_".join(checkpoints)
        else:
            eval_name = "eval_base"
        output_dir = get_save_dir(model_name, finetuning_type, eval_name)

        user_config = load_config()
        cache_dir = user_config.get("cache_dir", None)

        # Key order is kept stable because the dict is also rendered by ``gen_cmd``.
        args = dict(
            stage="sft",
            model_name_or_path=get_model_path(model_name),
            do_eval=True,
            overwrite_cache=False,
            predict_with_generate=True,
            cache_dir=cache_dir,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=self._to_quantization_bit(quantization_bit),
            template=template,
            system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            cutoff_len=cutoff_len,
            max_samples=int(max_samples),
            per_device_eval_batch_size=batch_size,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            output_dir=output_dir
        )

        if predict:
            # Prediction replaces evaluation rather than running alongside it.
            args.pop("do_eval", None)
            args["do_predict"] = True

        return lang, model_name, dataset, output_dir, args

    def _preview(
        self, lang: str, model_name: str, dataset: List[str], run_args: Dict[str, Any]
    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Yield either the validation error or the command rendered by ``gen_cmd``."""
        error = self._initialize(lang, model_name, dataset)
        if error:
            yield error, gr.update(visible=False)
        else:
            yield gen_cmd(run_args), gr.update(visible=False)

    def _run(
        self,
        lang: str,
        model_name: str,
        dataset: List[str],
        run_args: Dict[str, Any],
        get_finish_info: Callable[[], str]
    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Validate the request, run ``run_exp`` on a thread and stream its progress.

        ``get_finish_info`` is called once, after the worker thread has exited,
        to produce the message shown when the job was not aborted.
        """
        error = self._initialize(lang, model_name, dataset)
        if error:
            yield error, gr.update(visible=False)
            return

        self.running = True
        run_kwargs = dict(args=run_args, callbacks=[self.trainer_callback])
        worker = threading.Thread(target=run_exp, kwargs=run_kwargs)
        worker.start()

        while worker.is_alive():
            time.sleep(self._POLL_INTERVAL)
            if self.aborted:
                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
            else:
                yield self.logger_handler.log, update_process_bar(self.trainer_callback)

        yield self._finalize(lang, get_finish_info()), gr.update(visible=False)

    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Yield the command a training run with these form values would execute."""
        lang, model_name, dataset, _, train_args = self._parse_train_args(*args)
        yield from self._preview(lang, model_name, dataset, train_args)

    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Yield the command an evaluation run with these form values would execute."""
        lang, model_name, dataset, _, eval_args = self._parse_eval_args(*args)
        yield from self._preview(lang, model_name, dataset, eval_args)

    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Run training and stream ``(log text, progress-bar update)`` pairs.

        The run is reported as finished when ``TRAINING_ARGS_NAME`` exists in
        the output directory once the worker thread has exited, and as failed
        otherwise.
        """
        lang, model_name, dataset, output_dir, train_args = self._parse_train_args(*args)

        def finish_info() -> str:
            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
                return ALERTS["info_finished"][lang]
            return ALERTS["err_failed"][lang]

        yield from self._run(lang, model_name, dataset, train_args, finish_info)

    def run_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        """Run evaluation and stream ``(log text, progress-bar update)`` pairs.

        The final message is built by ``get_eval_results`` from
        ``all_results.json`` in the output directory when that file exists;
        otherwise the failure alert is shown.
        """
        lang, model_name, dataset, output_dir, eval_args = self._parse_eval_args(*args)

        def finish_info() -> str:
            results_path = os.path.join(output_dir, "all_results.json")
            if os.path.exists(results_path):
                return get_eval_results(results_path)
            return ALERTS["err_failed"][lang]

        yield from self._run(lang, model_name, dataset, eval_args, finish_info)