adds interface.py improvements to the flow and tab

interface.py  (CHANGED: +538 -26)
@@ -598,6 +598,8 @@ class PipelineInputs:
     scheduler_override: Optional[str]
     min_lr: Optional[float]
     min_lr_rate: Optional[float]
+    # Optional override config path generated from Advanced tab
+    override_config_path: Optional[str] = None


 def make_defaults(model_family: str) -> Tuple[str, str]:
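Note on the hunk above: appending the new field last, with a None default, keeps PipelineInputs constructible exactly as before. If PipelineInputs is a dataclass (its bare annotated fields suggest so, though the decorator is outside this diff), fields without defaults may not follow fields with defaults, so a trailing defaulted field is the safe way to extend it. A minimal illustration, not taken from interface.py:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Example:
        required: str
        override_config_path: Optional[str] = None  # new field appended last, with a default

    print(Example("x"))  # Example(required='x', override_config_path=None)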
@@ -641,17 +643,28 @@ def run_pipeline(params: PipelineInputs) -> Generator[str, None, None]:
         yield ("✅ " if ok else "⚠️ ") + msg
         dataset_repo = rid

-    # Resolve config file and model name
+    # Resolve config file and model name (allow override from Advanced tab)
     conf_map = get_config_map(params.model_family)
-    if params.config_choice not in conf_map:
-        yield f"❌ Unknown config choice: {params.config_choice}"
-        return
-    config_file = PROJECT_ROOT / conf_map[params.config_choice]["config_file"]
-    base_model_fallback = conf_map[params.config_choice]["default_model"]
-    if not config_file.exists():
-        yield f"❌ Config file not found: {config_file}"
-        return
-    cfg_obj = import_config_object(config_file)
+    if params.override_config_path:
+        config_file = Path(params.override_config_path)
+        if not config_file.exists():
+            yield f"❌ Generated config file not found: {config_file}"
+            return
+        # Best-effort to infer base model from generated config
+        cfg_obj = import_config_object(config_file)
+        base_model_fallback = getattr(cfg_obj, "model_name", None) or (
+            conf_map.get(params.config_choice, {}).get("default_model", "")
+        )
+    else:
+        if params.config_choice not in conf_map:
+            yield f"❌ Unknown config choice: {params.config_choice}"
+            return
+        config_file = PROJECT_ROOT / conf_map[params.config_choice]["config_file"]
+        base_model_fallback = conf_map[params.config_choice]["default_model"]
+        if not config_file.exists():
+            yield f"❌ Config file not found: {config_file}"
+            return
+        cfg_obj = import_config_object(config_file)
     base_model = getattr(cfg_obj, "model_name", base_model_fallback) if cfg_obj else base_model_fallback
     dataset_name = getattr(cfg_obj, "dataset_name", None) if cfg_obj else None
     batch_size = getattr(cfg_obj, "batch_size", None) if cfg_obj else None
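import_config_object is relied on here but defined outside this hunk. A plausible minimal sketch of its contract (load the config file as a module and return an object whose attributes getattr(...) can probe); this is an assumption, not the real helper:

    import importlib.util
    from pathlib import Path
    from typing import Any, Optional

    def import_config_object(config_file: Path) -> Optional[Any]:
        # Load the file as a throwaway module and hand back its "config" attribute
        # (or the module itself), so callers can getattr(...) fields off it.
        spec = importlib.util.spec_from_file_location("generated_config", config_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except Exception:
            return None
        return getattr(module, "config", module)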
@@ -681,6 +694,26 @@ def run_pipeline(params: PipelineInputs) -> Generator[str, None, None]:
         for line in run_command_stream(args, env, cwd=PROJECT_ROOT / "scripts/trackio_tonic"):
             yield line

+    # Dataset setup and Trackio configuration (mirror launch.sh) when monitoring is enabled
+    if params.monitoring_mode != "none":
+        # Ensure HF Dataset structure
+        yield f"\n=== Setting up HF Dataset: {dataset_repo} ==="
+        ds_args = [
+            str(PROJECT_ROOT / "scripts/dataset_tonic/setup_hf_dataset.py"),
+            write_token,
+        ]
+        for line in run_command_stream(ds_args, env, cwd=PROJECT_ROOT / "scripts/dataset_tonic"):
+            yield line
+        # Configure Trackio Space
+        yield f"\n=== Configuring Trackio Space ({params.trackio_space_name or 'N/A'}) ==="
+        conf_args = [str(PROJECT_ROOT / "scripts/trackio_tonic/configure_trackio.py")]
+        # Use space deploy token (READ for dataset-only; WRITE otherwise)
+        conf_env = env.copy()
+        conf_env["HF_TOKEN"] = space_deploy_token
+        conf_env["HUGGING_FACE_HUB_TOKEN"] = space_deploy_token
+        for line in run_command_stream(conf_args, conf_env, cwd=PROJECT_ROOT / "scripts/trackio_tonic"):
+            yield line
+
     # Training output directory
     out_dir = PROJECT_ROOT / "outputs" / f"{params.experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
     out_dir.mkdir(parents=True, exist_ok=True)
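run_command_stream is likewise defined outside this diff. Assumed behavior, judging by the call sites: spawn the script with the given argv/env/cwd and yield stdout lines as they arrive. A rough sketch under that assumption, not the actual helper:

    import subprocess
    import sys
    from pathlib import Path
    from typing import Dict, Generator, List

    def run_command_stream(args: List[str], env: Dict[str, str], cwd: Path) -> Generator[str, None, None]:
        # Merge stderr into stdout so errors surface in the same live log.
        proc = subprocess.Popen(
            [sys.executable, *args],
            cwd=str(cwd),
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        assert proc.stdout is not None
        for line in proc.stdout:
            yield line.rstrip("\n")
        proc.wait()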
@@ -856,16 +889,27 @@ def on_family_change(family: str):
         gr.update(visible=True),  # show step 2 (trainer)
         gr.update(visible=False),  # hide step 3 until trainer selected
         gr.update(visible=False),  # hide step 4 until monitoring selected
-        gr.update(visible=
+        gr.update(visible=False),  # GPT-OSS advanced group hidden until enabled
+        gr.update(visible=False),  # SmolLM3 advanced group hidden until enabled
     )


 def on_config_change(family: str, config_choice: str):
-    """When a prebuilt configuration is selected, update dataset info and helpful details.
+    """When a prebuilt configuration is selected, update dataset info and helpful details.
+
+    Also auto-fill advanced fields with defaults from the selected config.
+    """
     if not config_choice:
         return (
             "",
             gr.update(choices=[], value=None),
+            # Advanced fields (GPT-OSS)
+            "", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
+            None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
+            2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
+            # Advanced fields (SmolLM3)
+            "HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
+            4096, 2, 8, 5e-6, 500, 100, 10,
         )

     conf_map = get_config_map(family)
@@ -896,7 +940,124 @@ def on_config_change(family: str, config_choice: str):
     # dataset selection (allow custom but prefill with the config's dataset if any)
     ds_choices = [dataset_name] if dataset_name else []

-
+    # Defaults for Advanced (GPT-OSS)
+    adv_dataset_name = dataset_name or ("HuggingFaceH4/Multilingual-Thinking" if family == "GPT-OSS" else (dataset_name or ""))
+    adv_dataset_split = getattr(cfg_obj, "dataset_split", "train") if cfg_obj else "train"
+    # Infer dataset_format heuristically
+    if family == "GPT-OSS":
+        adv_dataset_format = getattr(cfg_obj, "dataset_format", None) or (
+            "messages" if getattr(cfg_obj, "input_field", "") == "messages" else "openhermes_fr"
+        )
+        adv_input_field = getattr(cfg_obj, "input_field", "prompt")
+        adv_target_field = getattr(cfg_obj, "target_field", "accepted_completion") or ""
+        adv_num_train_epochs = float(getattr(cfg_obj, "num_train_epochs", 1.0)) if cfg_obj and hasattr(cfg_obj, "num_train_epochs") else 1.0
+        adv_batch_size = int(getattr(cfg_obj, "batch_size", 4) or 4)
+        adv_gas = int(getattr(cfg_obj, "gradient_accumulation_steps", 4) or 4)
+        adv_lr = float(getattr(cfg_obj, "learning_rate", 2e-4) or 2e-4)
+        adv_min_lr = float(getattr(cfg_obj, "min_lr", 2e-5) or 2e-5)
+        adv_wd = float(getattr(cfg_obj, "weight_decay", 0.01) or 0.01)
+        adv_warmup = float(getattr(cfg_obj, "warmup_ratio", 0.03) or 0.03)
+        adv_msl = int(getattr(cfg_obj, "max_seq_length", 2048) or 2048)
+        lora_cfg = getattr(cfg_obj, "lora_config", {}) or {}
+        adv_lora_r = int(lora_cfg.get("r", 16))
+        adv_lora_alpha = int(lora_cfg.get("lora_alpha", 32))
+        adv_lora_dropout = float(lora_cfg.get("lora_dropout", 0.05))
+        adv_mixed_precision = "bf16" if getattr(cfg_obj, "bf16", True) else ("fp16" if getattr(cfg_obj, "fp16", False) else "fp32")
+        adv_num_workers = int(getattr(cfg_obj, "dataloader_num_workers", 4) or 4)
+        qcfg = getattr(cfg_obj, "quantization_config", {}) or {}
+        if qcfg.get("load_in_4bit", False):
+            adv_quantization_type = "bnb4"
+        elif qcfg.get("dequantize", False):
+            adv_quantization_type = "mxfp4"
+        else:
+            adv_quantization_type = "none"
+        adv_mgn = float(getattr(cfg_obj, "max_grad_norm", 1.0) or 1.0)
+        adv_log = int(getattr(cfg_obj, "logging_steps", 10) or 10)
+        adv_eval = int(getattr(cfg_obj, "eval_steps", 100) or 100)
+        adv_save = int(getattr(cfg_obj, "save_steps", 500) or 500)
+    else:
+        # SmolLM3 defaults for Advanced
+        adv_dataset_format = "openhermes_fr"
+        adv_input_field = getattr(cfg_obj, "input_field", "prompt") if cfg_obj else "prompt"
+        adv_target_field = getattr(cfg_obj, "target_field", "completion") if cfg_obj else "completion"
+        adv_num_train_epochs = 1.0
+        adv_batch_size = int(getattr(cfg_obj, "batch_size", 2) or 2)
+        adv_gas = int(getattr(cfg_obj, "gradient_accumulation_steps", 8) or 8)
+        adv_lr = float(getattr(cfg_obj, "learning_rate", 5e-6) or 5e-6)
+        adv_min_lr = float(getattr(cfg_obj, "min_lr", 1e-6) or 1e-6)
+        adv_wd = float(getattr(cfg_obj, "weight_decay", 0.01) or 0.01)
+        adv_warmup = float(getattr(cfg_obj, "warmup_steps", 100) or 100)  # Smol uses steps
+        adv_msl = int(getattr(cfg_obj, "max_seq_length", 4096) or 4096)
+        adv_lora_r = 16
+        adv_lora_alpha = 32
+        adv_lora_dropout = 0.05
+        adv_mixed_precision = "fp16" if getattr(cfg_obj, "fp16", True) else ("bf16" if getattr(cfg_obj, "bf16", False) else "fp32")
+        adv_num_workers = int(getattr(cfg_obj, "dataloader_num_workers", 4) or 4)
+        adv_quantization_type = "none"
+        adv_mgn = float(getattr(cfg_obj, "max_grad_norm", 1.0) or 1.0)
+        adv_log = int(getattr(cfg_obj, "logging_steps", 10) or 10)
+        adv_eval = int(getattr(cfg_obj, "eval_steps", 100) or 100)
+        adv_save = int(getattr(cfg_obj, "save_steps", 500) or 500)
+
+    # SmolLM3 advanced model/dataset
+    adv_sm_model_name = getattr(cfg_obj, "model_name", "HuggingFaceTB/SmolLM3-3B") if cfg_obj else "HuggingFaceTB/SmolLM3-3B"
+    adv_sm_dataset_name = dataset_name if family == "SmolLM3" else None
+    adv_sm_input_field = adv_input_field
+    adv_sm_target_field = adv_target_field
+    adv_sm_filter_bad = bool(getattr(cfg_obj, "filter_bad_entries", False)) if cfg_obj else False
+    adv_sm_sample_size = getattr(cfg_obj, "sample_size", None)
+    adv_sm_sample_seed = getattr(cfg_obj, "sample_seed", 42)
+
+    return (
+        training_md,
+        gr.update(choices=ds_choices, value=(dataset_name or None)),
+        # Advanced (GPT-OSS)
+        adv_dataset_name,
+        adv_dataset_split,
+        adv_dataset_format,
+        adv_input_field,
+        adv_target_field,
+        getattr(cfg_obj, "system_message", None) if cfg_obj else "",
+        getattr(cfg_obj, "developer_message", None) if cfg_obj else "",
+        getattr(cfg_obj, "chat_template_kwargs", {}).get("model_identity") if cfg_obj and getattr(cfg_obj, "chat_template_kwargs", None) else "",
+        getattr(cfg_obj, "max_samples", None) if cfg_obj else None,
+        int(getattr(cfg_obj, "min_length", 10) or 10) if cfg_obj else 10,
+        getattr(cfg_obj, "max_length", None) if cfg_obj else None,
+        adv_num_train_epochs,
+        adv_batch_size,
+        adv_gas,
+        adv_lr,
+        adv_min_lr,
+        adv_wd,
+        adv_warmup,
+        adv_msl,
+        adv_lora_r,
+        adv_lora_alpha,
+        adv_lora_dropout,
+        adv_mixed_precision,
+        adv_num_workers,
+        adv_quantization_type,
+        adv_mgn,
+        adv_log,
+        adv_eval,
+        adv_save,
+        # Advanced (SmolLM3)
+        adv_sm_model_name,
+        adv_sm_dataset_name,
+        adv_sm_input_field,
+        adv_sm_target_field,
+        adv_sm_filter_bad,
+        adv_sm_sample_size,
+        adv_sm_sample_seed,
+        # SmolLM3 training overrides
+        int(getattr(cfg_obj, "max_seq_length", 4096) or 4096) if family == "SmolLM3" else 4096,
+        int(getattr(cfg_obj, "batch_size", 2) or 2) if family == "SmolLM3" else 2,
+        int(getattr(cfg_obj, "gradient_accumulation_steps", 8) or 8) if family == "SmolLM3" else 8,
+        float(getattr(cfg_obj, "learning_rate", 5e-6) or 5e-6) if family == "SmolLM3" else 5e-6,
+        int(getattr(cfg_obj, "save_steps", 500) or 500) if family == "SmolLM3" else 500,
+        int(getattr(cfg_obj, "eval_steps", 100) or 100) if family == "SmolLM3" else 100,
+        int(getattr(cfg_obj, "logging_steps", 10) or 10) if family == "SmolLM3" else 10,
+    )


 def on_trainer_selected(_: str):
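on_config_change now returns a 45-value tuple (2 shared values, 29 GPT-OSS fields, 14 SmolLM3 fields), and Gradio maps it positionally onto the outputs list wired up on config_choice.change below; the tuple and the list must stay in lockstep or every field after a mismatch shifts by one. A toy illustration of that positional contract (not from interface.py):

    import gradio as gr

    with gr.Blocks() as toy:
        name = gr.Textbox(label="name")
        split = gr.Textbox(label="split")
        pick = gr.Dropdown(choices=["a", "b"], label="pick")

        def fill(choice):
            # Position 1 -> name, position 2 -> split: order must match outputs below.
            return f"dataset-{choice}", "train"

        pick.change(fill, inputs=[pick], outputs=[name, split])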
@@ -1070,17 +1231,108 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:

     with gr.Tab("Advanced"):
         # GPT-OSS specific scheduler overrides
-
-
-
-
-
-
-
+        advanced_enabled = gr.Checkbox(value=False, label="Use advanced overrides (generate config)")
+
+        # Family-specific advanced groups
+        gpt_oss_advanced_group = gr.Group(visible=False)
+        with gpt_oss_advanced_group:
+            gr.Markdown("Advanced configuration for GPT-OSS")
+            with gr.Accordion("Dataset", open=True):
+                adv_dataset_name = gr.Textbox(value="", label="Dataset name")
+                with gr.Row():
+                    adv_dataset_split = gr.Textbox(value="train", label="Dataset split")
+                    adv_dataset_format = gr.Dropdown(
+                        choices=["openhermes_fr", "messages", "text"],
+                        value="openhermes_fr",
+                        label="Dataset format",
+                    )
+                with gr.Row():
+                    adv_input_field = gr.Textbox(value="prompt", label="Input field")
+                    adv_target_field = gr.Textbox(value="accepted_completion", label="Target field (optional)")
+                with gr.Row():
+                    adv_system_message = gr.Textbox(value="", label="System message (optional)")
+                    adv_developer_message = gr.Textbox(value="", label="Developer message (optional)")
+                    adv_model_identity = gr.Textbox(value="", label="Model identity (optional)")
+                with gr.Row():
+                    adv_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
+                    adv_min_length = gr.Number(value=10, precision=0, label="Min length")
+                    adv_max_length = gr.Number(value=None, precision=0, label="Max length (optional)")
+
+            with gr.Accordion("Training", open=True):
+                with gr.Row():
+                    adv_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
+                    adv_batch_size = gr.Number(value=4, precision=0, label="Batch size")
+                    adv_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
+                with gr.Row():
+                    adv_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
+                    adv_min_lr_num = gr.Number(value=2e-5, precision=6, label="Min LR")
+                    adv_weight_decay = gr.Number(value=0.01, precision=6, label="Weight decay")
+                    adv_warmup_ratio = gr.Number(value=0.03, precision=3, label="Warmup ratio")
+                    adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
+
+            with gr.Accordion("LoRA & Quantization", open=False):
+                with gr.Row():
+                    adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
+                    adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
+                    adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
+                with gr.Row():
+                    adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
+                    adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
+                    adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
+                    adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
+
+            with gr.Accordion("Eval & Logging", open=False):
+                with gr.Row():
+                    adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
+                    adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
+                    adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
+
+            with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
+                scheduler_override = gr.Dropdown(
+                    choices=[c for c in SCHEDULER_CHOICES if c is not None],
+                    value=None,
+                    allow_custom_value=True,
+                    label="Scheduler override",
+                )
+                with gr.Row():
+                    min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
+                    min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
+
+        smollm3_advanced_group = gr.Group(visible=False)
+        with smollm3_advanced_group:
+            gr.Markdown("Advanced configuration for SmolLM3")
+            with gr.Accordion("Dataset", open=True):
+                adv_sm_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
+                adv_sm_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
+                with gr.Row():
+                    adv_sm_input_field = gr.Textbox(value="prompt", label="Input field")
+                    adv_sm_target_field = gr.Textbox(value="completion", label="Target field")
+                with gr.Row():
+                    adv_sm_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
+                    adv_sm_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
+                    adv_sm_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
+            with gr.Accordion("Training", open=True):
+                with gr.Row():
+                    adv_sm_max_seq_length = gr.Number(value=4096, precision=0, label="Max seq length")
+                    adv_sm_batch_size = gr.Number(value=2, precision=0, label="Batch size")
+                    adv_sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
+                    adv_sm_learning_rate = gr.Number(value=5e-6, precision=6, label="Learning rate")
+                with gr.Row():
+                    adv_sm_save_steps = gr.Number(value=500, precision=0, label="Save steps")
+                    adv_sm_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
+                    adv_sm_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
+
+        def _toggle_advanced(enable: bool, family_val: str):
+            return (
+                gr.update(visible=enable and family_val == "GPT-OSS"),
+                gr.update(visible=enable and family_val == "SmolLM3"),
             )
-
-
-
+
+        advanced_enabled.change(
+            _toggle_advanced,
+            inputs=[advanced_enabled, model_family],
+            outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
+        )

     # Final action & logs
     start_btn = gr.Button("Run Pipeline", variant="primary")
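SCHEDULER_CHOICES is referenced in the scheduler accordion but not defined in this diff; presumably a module-level list of scheduler names, along these lines (the exact values are a guess informed by transformers scheduler types, so treat this as an assumption):

    # Assumed shape only; the real list lives elsewhere in interface.py.
    SCHEDULER_CHOICES = [None, "linear", "cosine", "cosine_with_min_lr", "constant_with_warmup"]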
@@ -1101,7 +1353,8 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
             step2_group,
             step3_group,
             step4_group,
-
+            gpt_oss_advanced_group,  # show advanced for GPT-OSS
+            smollm3_advanced_group,  # show advanced for SmolLM3
         ],
     )

@@ -1116,11 +1369,224 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
     config_choice.change(
         on_config_change,
         inputs=[model_family, config_choice],
-        outputs=[
+        outputs=[
+            training_info,
+            dataset_choice,
+            # Advanced (GPT-OSS) outputs
+            adv_dataset_name,
+            adv_dataset_split,
+            adv_dataset_format,
+            adv_input_field,
+            adv_target_field,
+            adv_system_message,
+            adv_developer_message,
+            adv_model_identity,
+            adv_max_samples,
+            adv_min_length,
+            adv_max_length,
+            adv_num_train_epochs,
+            adv_batch_size,
+            adv_gradient_accumulation_steps,
+            adv_learning_rate,
+            adv_min_lr_num,
+            adv_weight_decay,
+            adv_warmup_ratio,
+            adv_max_seq_length,
+            adv_lora_r,
+            adv_lora_alpha,
+            adv_lora_dropout,
+            adv_mixed_precision,
+            adv_num_workers,
+            adv_quantization_type,
+            adv_max_grad_norm,
+            adv_logging_steps,
+            adv_eval_steps,
+            adv_save_steps,
+            # Advanced (SmolLM3)
+            adv_sm_model_name,
+            adv_sm_dataset_name,
+            adv_sm_input_field,
+            adv_sm_target_field,
+            adv_sm_filter_bad_entries,
+            adv_sm_sample_size,
+            adv_sm_sample_seed,
+            adv_sm_max_seq_length,
+            adv_sm_batch_size,
+            adv_sm_gas,
+            adv_sm_learning_rate,
+            adv_sm_save_steps,
+            adv_sm_eval_steps,
+            adv_sm_logging_steps,
+        ],
     )

+    # Keep Advanced dataset fields in sync when user selects a different dataset
+    def _sync_dataset_fields(ds_value: Optional[str]):
+        ds_text = ds_value or ""
+        return ds_text, ds_text
+
+    dataset_choice.change(
+        _sync_dataset_fields,
+        inputs=[dataset_choice],
+        outputs=[adv_dataset_name, adv_sm_dataset_name],
+    )
+
+    def _start_with_overrides(
+        model_family_v,
+        config_choice_v,
+        trainer_type_v,
+        monitoring_mode_v,
+        experiment_name_v,
+        repo_short_v,
+        author_name_v,
+        model_description_v,
+        trackio_space_name_v,
+        deploy_trackio_space_v,
+        create_dataset_repo_v,
+        push_to_hub_v,
+        switch_to_read_after_v,
+        scheduler_override_v,
+        min_lr_v,
+        min_lr_rate_v,
+        advanced_enabled_v,
+        # GPT-OSS advanced
+        adv_dataset_name_v,
+        adv_dataset_split_v,
+        adv_dataset_format_v,
+        adv_input_field_v,
+        adv_target_field_v,
+        adv_system_message_v,
+        adv_developer_message_v,
+        adv_model_identity_v,
+        adv_max_samples_v,
+        adv_min_length_v,
+        adv_max_length_v,
+        adv_num_train_epochs_v,
+        adv_batch_size_v,
+        adv_gas_v,
+        adv_lr_v,
+        adv_min_lr_num_v,
+        adv_wd_v,
+        adv_warmup_ratio_v,
+        adv_max_seq_length_v,
+        adv_lora_r_v,
+        adv_lora_alpha_v,
+        adv_lora_dropout_v,
+        adv_mixed_precision_v,
+        adv_num_workers_v,
+        adv_quantization_type_v,
+        adv_max_grad_norm_v,
+        adv_logging_steps_v,
+        adv_eval_steps_v,
+        adv_save_steps_v,
+        # SmolLM3 advanced
+        adv_sm_model_name_v,
+        adv_sm_dataset_name_v,
+        adv_sm_input_field_v,
+        adv_sm_target_field_v,
+        adv_sm_filter_bad_entries_v,
+        adv_sm_sample_size_v,
+        adv_sm_sample_seed_v,
+        adv_sm_max_seq_length_v,
+        adv_sm_batch_size_v,
+        adv_sm_gas_v,
+        adv_sm_learning_rate_v,
+        adv_sm_save_steps_v,
+        adv_sm_eval_steps_v,
+        adv_sm_logging_steps_v,
+    ):
+        # If advanced overrides enabled, generate a config file and pass its path
+        override_path: Optional[str] = None
+        if advanced_enabled_v:
+            try:
+                if model_family_v == "GPT-OSS":
+                    cfg_path = generate_gpt_oss_custom_config_file(
+                        dataset_name=str(adv_dataset_name_v or ""),
+                        dataset_split=str(adv_dataset_split_v or "train"),
+                        dataset_format=str(adv_dataset_format_v or "openhermes_fr"),
+                        input_field=str(adv_input_field_v or "prompt"),
+                        target_field=(str(adv_target_field_v) if adv_target_field_v else None),
+                        system_message=(str(adv_system_message_v) if adv_system_message_v else None),
+                        developer_message=(str(adv_developer_message_v) if adv_developer_message_v else None),
+                        model_identity=(str(adv_model_identity_v) if adv_model_identity_v else None),
+                        max_samples=(int(adv_max_samples_v) if adv_max_samples_v else None),
+                        min_length=int(adv_min_length_v or 10),
+                        max_length=(int(adv_max_length_v) if adv_max_length_v else None),
+                        num_train_epochs=float(adv_num_train_epochs_v or 1.0),
+                        batch_size=int(adv_batch_size_v or 4),
+                        gradient_accumulation_steps=int(adv_gas_v or 4),
+                        learning_rate=float(adv_lr_v or 2e-4),
+                        min_lr=float(adv_min_lr_num_v or 2e-5),
+                        weight_decay=float(adv_wd_v or 0.01),
+                        warmup_ratio=float(adv_warmup_ratio_v or 0.03),
+                        max_seq_length=int(adv_max_seq_length_v or 2048),
+                        lora_r=int(adv_lora_r_v or 16),
+                        lora_alpha=int(adv_lora_alpha_v or 32),
+                        lora_dropout=float(adv_lora_dropout_v or 0.05),
+                        mixed_precision=str(adv_mixed_precision_v or "bf16"),
+                        num_workers=int(adv_num_workers_v or 4),
+                        quantization_type=str(adv_quantization_type_v or "mxfp4"),
+                        max_grad_norm=float(adv_max_grad_norm_v or 1.0),
+                        logging_steps=int(adv_logging_steps_v or 10),
+                        eval_steps=int(adv_eval_steps_v or 100),
+                        save_steps=int(adv_save_steps_v or 500),
+                    )
+                else:
+                    cfg_path = generate_smollm3_custom_config_file(
+                        model_name=str(adv_sm_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
+                        dataset_name=(str(adv_sm_dataset_name_v) if adv_sm_dataset_name_v else None),
+                        max_seq_length=int(adv_sm_max_seq_length_v or 4096),
+                        batch_size=int(adv_sm_batch_size_v or 2),
+                        gradient_accumulation_steps=int(adv_sm_gas_v or 8),
+                        learning_rate=float(adv_sm_learning_rate_v or 5e-6),
+                        save_steps=int(adv_sm_save_steps_v or 500),
+                        eval_steps=int(adv_sm_eval_steps_v or 100),
+                        logging_steps=int(adv_sm_logging_steps_v or 10),
+                        filter_bad_entries=bool(adv_sm_filter_bad_entries_v),
+                        input_field=str(adv_sm_input_field_v or "prompt"),
+                        target_field=str(adv_sm_target_field_v or "completion"),
+                        sample_size=(int(adv_sm_sample_size_v) if adv_sm_sample_size_v else None),
+                        sample_seed=int(adv_sm_sample_seed_v or 42),
+                        trainer_type=str(trainer_type_v).lower(),
+                    )
+                override_path = str(cfg_path)
+            except Exception as e:
+                # Surface error in logs via generator
+                def _err_gen():
+                    yield f"❌ Failed to generate advanced config: {e}"
+                return _err_gen()
+
+        def _gen():
+            params = PipelineInputs(
+                model_family=model_family_v,
+                config_choice=config_choice_v,
+                trainer_type=trainer_type_v,
+                monitoring_mode=monitoring_mode_v,
+                experiment_name=experiment_name_v,
+                repo_short=repo_short_v,
+                author_name=author_name_v,
+                model_description=model_description_v,
+                trackio_space_name=trackio_space_name_v or None,
+                deploy_trackio_space=bool(deploy_trackio_space_v),
+                create_dataset_repo=bool(create_dataset_repo_v),
+                push_to_hub=bool(push_to_hub_v),
+                switch_to_read_after=bool(switch_to_read_after_v),
+                scheduler_override=(scheduler_override_v or None),
+                min_lr=min_lr_v,
+                min_lr_rate=min_lr_rate_v,
+                override_config_path=override_path,
+            )
+            write_token = os.environ.get("HF_WRITE_TOKEN") or os.environ.get("HF_TOKEN")
+            read_token = os.environ.get("HF_READ_TOKEN")
+            yield f"HF_WRITE_TOKEN: {mask_token(write_token)}"
+            yield f"HF_READ_TOKEN: {mask_token(read_token)}"
+            for line in run_pipeline(params):
+                yield line
+                time.sleep(0.01)
+        return _gen()
+
     start_btn.click(
-
+        _start_with_overrides,
         inputs=[
             model_family,
             config_choice,
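generate_gpt_oss_custom_config_file and generate_smollm3_custom_config_file are called above but live outside this hunk. The assumed contract: keyword overrides in, path to an importable Python config file out, shaped so that import_config_object can getattr(...) fields off it. A hypothetical sketch of the SmolLM3 variant only (names and layout are guesses, not the actual implementation):

    import tempfile
    from pathlib import Path

    def generate_smollm3_custom_config_file(**overrides) -> Path:
        # Render the overrides into a tiny Python module exposing a "config" object,
        # so import_config_object() can read fields off it later.
        lines = ["# Auto-generated SmolLM3 config (sketch)", "class config:", "    pass"]
        lines += [f"    {key} = {value!r}" for key, value in overrides.items()]
        out = Path(tempfile.gettempdir()) / "smollm3_custom_config.py"
        out.write_text("\n".join(lines) + "\n")
        return out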
@@ -1138,6 +1604,52 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
             scheduler_override,
             min_lr,
             min_lr_rate,
+            advanced_enabled,
+            # GPT-OSS advanced
+            adv_dataset_name,
+            adv_dataset_split,
+            adv_dataset_format,
+            adv_input_field,
+            adv_target_field,
+            adv_system_message,
+            adv_developer_message,
+            adv_model_identity,
+            adv_max_samples,
+            adv_min_length,
+            adv_max_length,
+            adv_num_train_epochs,
+            adv_batch_size,
+            adv_gradient_accumulation_steps,
+            adv_learning_rate,
+            adv_min_lr_num,
+            adv_weight_decay,
+            adv_warmup_ratio,
+            adv_max_seq_length,
+            adv_lora_r,
+            adv_lora_alpha,
+            adv_lora_dropout,
+            adv_mixed_precision,
+            adv_num_workers,
+            adv_quantization_type,
+            adv_max_grad_norm,
+            adv_logging_steps,
+            adv_eval_steps,
+            adv_save_steps,
+            # SmolLM3 advanced
+            adv_sm_model_name,
+            adv_sm_dataset_name,
+            adv_sm_input_field,
+            adv_sm_target_field,
+            adv_sm_filter_bad_entries,
+            adv_sm_sample_size,
+            adv_sm_sample_seed,
+            adv_sm_max_seq_length,
+            adv_sm_batch_size,
+            adv_sm_gas,
+            adv_sm_learning_rate,
+            adv_sm_save_steps,
+            adv_sm_eval_steps,
+            adv_sm_logging_steps,
         ],
         outputs=[logs],
     )
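Because _start_with_overrides returns a generator, Gradio iterates it and streams each yielded line into the logs textbox instead of waiting for the pipeline to finish. A toy version of the same streaming pattern (illustrative only, not from interface.py):

    import time
    import gradio as gr

    with gr.Blocks() as toy_stream:
        logs = gr.Textbox(label="logs", lines=8)
        btn = gr.Button("run")

        def run_job():
            # Each yield replaces the textbox contents, producing a live log effect.
            for i in range(3):
                yield f"step {i} done"
                time.sleep(0.1)

        btn.click(run_job, inputs=[], outputs=[logs])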