Tonic committed
Commit 4b9e23f · 1 Parent(s): a8fc78d

adds interface.py improvements to the flow and tab

Files changed (1)
  1. interface.py +538 -26
interface.py CHANGED
@@ -598,6 +598,8 @@ class PipelineInputs:
     scheduler_override: Optional[str]
     min_lr: Optional[float]
     min_lr_rate: Optional[float]
+    # Optional override config path generated from Advanced tab
+    override_config_path: Optional[str] = None


 def make_defaults(model_family: str) -> Tuple[str, str]:
@@ -641,17 +643,28 @@ def run_pipeline(params: PipelineInputs) -> Generator[str, None, None]:
         yield ("✅ " if ok else "⚠️ ") + msg
         dataset_repo = rid

-    # Resolve config file and model name
+    # Resolve config file and model name (allow override from Advanced tab)
     conf_map = get_config_map(params.model_family)
-    if params.config_choice not in conf_map:
-        yield f"❌ Unknown config choice: {params.config_choice}"
-        return
-    config_file = PROJECT_ROOT / conf_map[params.config_choice]["config_file"]
-    base_model_fallback = conf_map[params.config_choice]["default_model"]
-    if not config_file.exists():
-        yield f"❌ Config file not found: {config_file}"
-        return
-    cfg_obj = import_config_object(config_file)
+    if params.override_config_path:
+        config_file = Path(params.override_config_path)
+        if not config_file.exists():
+            yield f"❌ Generated config file not found: {config_file}"
+            return
+        # Best-effort to infer base model from generated config
+        cfg_obj = import_config_object(config_file)
+        base_model_fallback = getattr(cfg_obj, "model_name", None) or (
+            conf_map.get(params.config_choice, {}).get("default_model", "")
+        )
+    else:
+        if params.config_choice not in conf_map:
+            yield f"❌ Unknown config choice: {params.config_choice}"
+            return
+        config_file = PROJECT_ROOT / conf_map[params.config_choice]["config_file"]
+        base_model_fallback = conf_map[params.config_choice]["default_model"]
+        if not config_file.exists():
+            yield f"❌ Config file not found: {config_file}"
+            return
+        cfg_obj = import_config_object(config_file)
     base_model = getattr(cfg_obj, "model_name", base_model_fallback) if cfg_obj else base_model_fallback
     dataset_name = getattr(cfg_obj, "dataset_name", None) if cfg_obj else None
     batch_size = getattr(cfg_obj, "batch_size", None) if cfg_obj else None
@@ -681,6 +694,26 @@ def run_pipeline(params: PipelineInputs) -> Generator[str, None, None]:
         for line in run_command_stream(args, env, cwd=PROJECT_ROOT / "scripts/trackio_tonic"):
             yield line

+    # Dataset setup and Trackio configuration (mirror launch.sh) when monitoring is enabled
+    if params.monitoring_mode != "none":
+        # Ensure HF Dataset structure
+        yield f"\n=== Setting up HF Dataset: {dataset_repo} ==="
+        ds_args = [
+            str(PROJECT_ROOT / "scripts/dataset_tonic/setup_hf_dataset.py"),
+            write_token,
+        ]
+        for line in run_command_stream(ds_args, env, cwd=PROJECT_ROOT / "scripts/dataset_tonic"):
+            yield line
+        # Configure Trackio Space
+        yield f"\n=== Configuring Trackio Space ({params.trackio_space_name or 'N/A'}) ==="
+        conf_args = [str(PROJECT_ROOT / "scripts/trackio_tonic/configure_trackio.py")]
+        # Use space deploy token (READ for dataset-only; WRITE otherwise)
+        conf_env = env.copy()
+        conf_env["HF_TOKEN"] = space_deploy_token
+        conf_env["HUGGING_FACE_HUB_TOKEN"] = space_deploy_token
+        for line in run_command_stream(conf_args, conf_env, cwd=PROJECT_ROOT / "scripts/trackio_tonic"):
+            yield line
+
     # Training output directory
     out_dir = PROJECT_ROOT / "outputs" / f"{params.experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -856,16 +889,27 @@ def on_family_change(family: str):
         gr.update(visible=True),  # show step 2 (trainer)
         gr.update(visible=False),  # hide step 3 until trainer selected
         gr.update(visible=False),  # hide step 4 until monitoring selected
-        gr.update(visible=(family == "GPT-OSS")),  # advanced (scheduler) visibility
+        gr.update(visible=False),  # GPT-OSS advanced group hidden until enabled
+        gr.update(visible=False),  # SmolLM3 advanced group hidden until enabled
     )


 def on_config_change(family: str, config_choice: str):
-    """When a prebuilt configuration is selected, update dataset info and helpful details."""
+    """When a prebuilt configuration is selected, update dataset info and helpful details.
+
+    Also auto-fill advanced fields with defaults from the selected config.
+    """
     if not config_choice:
         return (
             "",
             gr.update(choices=[], value=None),
+            # Advanced fields (GPT-OSS)
+            "", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
+            None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
+            2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
+            # Advanced fields (SmolLM3)
+            "HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
+            4096, 2, 8, 5e-6, 500, 100, 10,
         )

     conf_map = get_config_map(family)
@@ -896,7 +940,124 @@ def on_config_change(family: str, config_choice: str):
     # dataset selection (allow custom but prefill with the config's dataset if any)
     ds_choices = [dataset_name] if dataset_name else []

-    return training_md, gr.update(choices=ds_choices, value=(dataset_name or None))
+    # Defaults for Advanced (GPT-OSS)
+    adv_dataset_name = dataset_name or ("HuggingFaceH4/Multilingual-Thinking" if family == "GPT-OSS" else (dataset_name or ""))
+    adv_dataset_split = getattr(cfg_obj, "dataset_split", "train") if cfg_obj else "train"
+    # Infer dataset_format heuristically
+    if family == "GPT-OSS":
+        adv_dataset_format = getattr(cfg_obj, "dataset_format", None) or (
+            "messages" if getattr(cfg_obj, "input_field", "") == "messages" else "openhermes_fr"
+        )
+        adv_input_field = getattr(cfg_obj, "input_field", "prompt")
+        adv_target_field = getattr(cfg_obj, "target_field", "accepted_completion") or ""
+        adv_num_train_epochs = float(getattr(cfg_obj, "num_train_epochs", 1.0)) if cfg_obj and hasattr(cfg_obj, "num_train_epochs") else 1.0
+        adv_batch_size = int(getattr(cfg_obj, "batch_size", 4) or 4)
+        adv_gas = int(getattr(cfg_obj, "gradient_accumulation_steps", 4) or 4)
+        adv_lr = float(getattr(cfg_obj, "learning_rate", 2e-4) or 2e-4)
+        adv_min_lr = float(getattr(cfg_obj, "min_lr", 2e-5) or 2e-5)
+        adv_wd = float(getattr(cfg_obj, "weight_decay", 0.01) or 0.01)
+        adv_warmup = float(getattr(cfg_obj, "warmup_ratio", 0.03) or 0.03)
+        adv_msl = int(getattr(cfg_obj, "max_seq_length", 2048) or 2048)
+        lora_cfg = getattr(cfg_obj, "lora_config", {}) or {}
+        adv_lora_r = int(lora_cfg.get("r", 16))
+        adv_lora_alpha = int(lora_cfg.get("lora_alpha", 32))
+        adv_lora_dropout = float(lora_cfg.get("lora_dropout", 0.05))
+        adv_mixed_precision = "bf16" if getattr(cfg_obj, "bf16", True) else ("fp16" if getattr(cfg_obj, "fp16", False) else "fp32")
+        adv_num_workers = int(getattr(cfg_obj, "dataloader_num_workers", 4) or 4)
+        qcfg = getattr(cfg_obj, "quantization_config", {}) or {}
+        if qcfg.get("load_in_4bit", False):
+            adv_quantization_type = "bnb4"
+        elif qcfg.get("dequantize", False):
+            adv_quantization_type = "mxfp4"
+        else:
+            adv_quantization_type = "none"
+        adv_mgn = float(getattr(cfg_obj, "max_grad_norm", 1.0) or 1.0)
+        adv_log = int(getattr(cfg_obj, "logging_steps", 10) or 10)
+        adv_eval = int(getattr(cfg_obj, "eval_steps", 100) or 100)
+        adv_save = int(getattr(cfg_obj, "save_steps", 500) or 500)
+    else:
+        # SmolLM3 defaults for Advanced
+        adv_dataset_format = "openhermes_fr"
+        adv_input_field = getattr(cfg_obj, "input_field", "prompt") if cfg_obj else "prompt"
+        adv_target_field = getattr(cfg_obj, "target_field", "completion") if cfg_obj else "completion"
+        adv_num_train_epochs = 1.0
+        adv_batch_size = int(getattr(cfg_obj, "batch_size", 2) or 2)
+        adv_gas = int(getattr(cfg_obj, "gradient_accumulation_steps", 8) or 8)
+        adv_lr = float(getattr(cfg_obj, "learning_rate", 5e-6) or 5e-6)
+        adv_min_lr = float(getattr(cfg_obj, "min_lr", 1e-6) or 1e-6)
+        adv_wd = float(getattr(cfg_obj, "weight_decay", 0.01) or 0.01)
+        adv_warmup = float(getattr(cfg_obj, "warmup_steps", 100) or 100)  # Smol uses steps
+        adv_msl = int(getattr(cfg_obj, "max_seq_length", 4096) or 4096)
+        adv_lora_r = 16
+        adv_lora_alpha = 32
+        adv_lora_dropout = 0.05
+        adv_mixed_precision = "fp16" if getattr(cfg_obj, "fp16", True) else ("bf16" if getattr(cfg_obj, "bf16", False) else "fp32")
+        adv_num_workers = int(getattr(cfg_obj, "dataloader_num_workers", 4) or 4)
+        adv_quantization_type = "none"
+        adv_mgn = float(getattr(cfg_obj, "max_grad_norm", 1.0) or 1.0)
+        adv_log = int(getattr(cfg_obj, "logging_steps", 10) or 10)
+        adv_eval = int(getattr(cfg_obj, "eval_steps", 100) or 100)
+        adv_save = int(getattr(cfg_obj, "save_steps", 500) or 500)
+
+    # SmolLM3 advanced model/dataset
+    adv_sm_model_name = getattr(cfg_obj, "model_name", "HuggingFaceTB/SmolLM3-3B") if cfg_obj else "HuggingFaceTB/SmolLM3-3B"
+    adv_sm_dataset_name = dataset_name if family == "SmolLM3" else None
+    adv_sm_input_field = adv_input_field
+    adv_sm_target_field = adv_target_field
+    adv_sm_filter_bad = bool(getattr(cfg_obj, "filter_bad_entries", False)) if cfg_obj else False
+    adv_sm_sample_size = getattr(cfg_obj, "sample_size", None)
+    adv_sm_sample_seed = getattr(cfg_obj, "sample_seed", 42)
+
+    return (
+        training_md,
+        gr.update(choices=ds_choices, value=(dataset_name or None)),
+        # Advanced (GPT-OSS)
+        adv_dataset_name,
+        adv_dataset_split,
+        adv_dataset_format,
+        adv_input_field,
+        adv_target_field,
+        getattr(cfg_obj, "system_message", None) if cfg_obj else "",
+        getattr(cfg_obj, "developer_message", None) if cfg_obj else "",
+        getattr(cfg_obj, "chat_template_kwargs", {}).get("model_identity") if cfg_obj and getattr(cfg_obj, "chat_template_kwargs", None) else "",
+        getattr(cfg_obj, "max_samples", None) if cfg_obj else None,
+        int(getattr(cfg_obj, "min_length", 10) or 10) if cfg_obj else 10,
+        getattr(cfg_obj, "max_length", None) if cfg_obj else None,
+        adv_num_train_epochs,
+        adv_batch_size,
+        adv_gas,
+        adv_lr,
+        adv_min_lr,
+        adv_wd,
+        adv_warmup,
+        adv_msl,
+        adv_lora_r,
+        adv_lora_alpha,
+        adv_lora_dropout,
+        adv_mixed_precision,
+        adv_num_workers,
+        adv_quantization_type,
+        adv_mgn,
+        adv_log,
+        adv_eval,
+        adv_save,
+        # Advanced (SmolLM3)
+        adv_sm_model_name,
+        adv_sm_dataset_name,
+        adv_sm_input_field,
+        adv_sm_target_field,
+        adv_sm_filter_bad,
+        adv_sm_sample_size,
+        adv_sm_sample_seed,
+        # SmolLM3 training overrides
+        int(getattr(cfg_obj, "max_seq_length", 4096) or 4096) if family == "SmolLM3" else 4096,
+        int(getattr(cfg_obj, "batch_size", 2) or 2) if family == "SmolLM3" else 2,
+        int(getattr(cfg_obj, "gradient_accumulation_steps", 8) or 8) if family == "SmolLM3" else 8,
+        float(getattr(cfg_obj, "learning_rate", 5e-6) or 5e-6) if family == "SmolLM3" else 5e-6,
+        int(getattr(cfg_obj, "save_steps", 500) or 500) if family == "SmolLM3" else 500,
+        int(getattr(cfg_obj, "eval_steps", 100) or 100) if family == "SmolLM3" else 100,
+        int(getattr(cfg_obj, "logging_steps", 10) or 10) if family == "SmolLM3" else 10,
+    )


 def on_trainer_selected(_: str):
@@ -1070,17 +1231,108 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:

     with gr.Tab("Advanced"):
         # GPT-OSS specific scheduler overrides
-        advanced_scheduler_group = gr.Group(visible=False)
-        with advanced_scheduler_group:
-            scheduler_override = gr.Dropdown(
-                choices=[c for c in SCHEDULER_CHOICES if c is not None],
-                value=None,
-                allow_custom_value=True,
-                label="Scheduler override",
+        advanced_enabled = gr.Checkbox(value=False, label="Use advanced overrides (generate config)")
+
+        # Family-specific advanced groups
+        gpt_oss_advanced_group = gr.Group(visible=False)
+        with gpt_oss_advanced_group:
+            gr.Markdown("Advanced configuration for GPT-OSS")
+            with gr.Accordion("Dataset", open=True):
+                adv_dataset_name = gr.Textbox(value="", label="Dataset name")
+                with gr.Row():
+                    adv_dataset_split = gr.Textbox(value="train", label="Dataset split")
+                    adv_dataset_format = gr.Dropdown(
+                        choices=["openhermes_fr", "messages", "text"],
+                        value="openhermes_fr",
+                        label="Dataset format",
+                    )
+                with gr.Row():
+                    adv_input_field = gr.Textbox(value="prompt", label="Input field")
+                    adv_target_field = gr.Textbox(value="accepted_completion", label="Target field (optional)")
+                with gr.Row():
+                    adv_system_message = gr.Textbox(value="", label="System message (optional)")
+                    adv_developer_message = gr.Textbox(value="", label="Developer message (optional)")
+                    adv_model_identity = gr.Textbox(value="", label="Model identity (optional)")
+                with gr.Row():
+                    adv_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
+                    adv_min_length = gr.Number(value=10, precision=0, label="Min length")
+                    adv_max_length = gr.Number(value=None, precision=0, label="Max length (optional)")
+
+            with gr.Accordion("Training", open=True):
+                with gr.Row():
+                    adv_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
+                    adv_batch_size = gr.Number(value=4, precision=0, label="Batch size")
+                    adv_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
+                with gr.Row():
+                    adv_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
+                    adv_min_lr_num = gr.Number(value=2e-5, precision=6, label="Min LR")
+                    adv_weight_decay = gr.Number(value=0.01, precision=6, label="Weight decay")
+                    adv_warmup_ratio = gr.Number(value=0.03, precision=3, label="Warmup ratio")
+                    adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
+
+            with gr.Accordion("LoRA & Quantization", open=False):
+                with gr.Row():
+                    adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
+                    adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
+                    adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
+                with gr.Row():
+                    adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
+                    adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
+                    adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
+                    adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
+
+            with gr.Accordion("Eval & Logging", open=False):
+                with gr.Row():
+                    adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
+                    adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
+                    adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
+
+            with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
+                scheduler_override = gr.Dropdown(
+                    choices=[c for c in SCHEDULER_CHOICES if c is not None],
+                    value=None,
+                    allow_custom_value=True,
+                    label="Scheduler override",
+                )
+                with gr.Row():
+                    min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
+                    min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
+
+        smollm3_advanced_group = gr.Group(visible=False)
+        with smollm3_advanced_group:
+            gr.Markdown("Advanced configuration for SmolLM3")
+            with gr.Accordion("Dataset", open=True):
+                adv_sm_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
+                adv_sm_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
+                with gr.Row():
+                    adv_sm_input_field = gr.Textbox(value="prompt", label="Input field")
+                    adv_sm_target_field = gr.Textbox(value="completion", label="Target field")
+                with gr.Row():
+                    adv_sm_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
+                    adv_sm_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
+                    adv_sm_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
+            with gr.Accordion("Training", open=True):
+                with gr.Row():
+                    adv_sm_max_seq_length = gr.Number(value=4096, precision=0, label="Max seq length")
+                    adv_sm_batch_size = gr.Number(value=2, precision=0, label="Batch size")
+                    adv_sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
+                    adv_sm_learning_rate = gr.Number(value=5e-6, precision=6, label="Learning rate")
+                with gr.Row():
+                    adv_sm_save_steps = gr.Number(value=500, precision=0, label="Save steps")
+                    adv_sm_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
+                    adv_sm_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
+
+        def _toggle_advanced(enable: bool, family_val: str):
+            return (
+                gr.update(visible=enable and family_val == "GPT-OSS"),
+                gr.update(visible=enable and family_val == "SmolLM3"),
             )
-            with gr.Row():
-                min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
-                min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
+
+        advanced_enabled.change(
+            _toggle_advanced,
+            inputs=[advanced_enabled, model_family],
+            outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
+        )

     # Final action & logs
     start_btn = gr.Button("Run Pipeline", variant="primary")
@@ -1101,7 +1353,8 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
             step2_group,
             step3_group,
             step4_group,
-            advanced_scheduler_group,
+            gpt_oss_advanced_group,  # show advanced for GPT-OSS
+            smollm3_advanced_group,  # show advanced for SmolLM3
         ],
     )

@@ -1116,11 +1369,224 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
     config_choice.change(
         on_config_change,
         inputs=[model_family, config_choice],
-        outputs=[training_info, dataset_choice],
+        outputs=[
+            training_info,
+            dataset_choice,
+            # Advanced (GPT-OSS) outputs
+            adv_dataset_name,
+            adv_dataset_split,
+            adv_dataset_format,
+            adv_input_field,
+            adv_target_field,
+            adv_system_message,
+            adv_developer_message,
+            adv_model_identity,
+            adv_max_samples,
+            adv_min_length,
+            adv_max_length,
+            adv_num_train_epochs,
+            adv_batch_size,
+            adv_gradient_accumulation_steps,
+            adv_learning_rate,
+            adv_min_lr_num,
+            adv_weight_decay,
+            adv_warmup_ratio,
+            adv_max_seq_length,
+            adv_lora_r,
+            adv_lora_alpha,
+            adv_lora_dropout,
+            adv_mixed_precision,
+            adv_num_workers,
+            adv_quantization_type,
+            adv_max_grad_norm,
+            adv_logging_steps,
+            adv_eval_steps,
+            adv_save_steps,
+            # Advanced (SmolLM3)
+            adv_sm_model_name,
+            adv_sm_dataset_name,
+            adv_sm_input_field,
+            adv_sm_target_field,
+            adv_sm_filter_bad_entries,
+            adv_sm_sample_size,
+            adv_sm_sample_seed,
+            adv_sm_max_seq_length,
+            adv_sm_batch_size,
+            adv_sm_gas,
+            adv_sm_learning_rate,
+            adv_sm_save_steps,
+            adv_sm_eval_steps,
+            adv_sm_logging_steps,
+        ],
     )

+    # Keep Advanced dataset fields in sync when user selects a different dataset
+    def _sync_dataset_fields(ds_value: Optional[str]):
+        ds_text = ds_value or ""
+        return ds_text, ds_text
+
+    dataset_choice.change(
+        _sync_dataset_fields,
+        inputs=[dataset_choice],
+        outputs=[adv_dataset_name, adv_sm_dataset_name],
+    )
+
+    def _start_with_overrides(
+        model_family_v,
+        config_choice_v,
+        trainer_type_v,
+        monitoring_mode_v,
+        experiment_name_v,
+        repo_short_v,
+        author_name_v,
+        model_description_v,
+        trackio_space_name_v,
+        deploy_trackio_space_v,
+        create_dataset_repo_v,
+        push_to_hub_v,
+        switch_to_read_after_v,
+        scheduler_override_v,
+        min_lr_v,
+        min_lr_rate_v,
+        advanced_enabled_v,
+        # GPT-OSS advanced
+        adv_dataset_name_v,
+        adv_dataset_split_v,
+        adv_dataset_format_v,
+        adv_input_field_v,
+        adv_target_field_v,
+        adv_system_message_v,
+        adv_developer_message_v,
+        adv_model_identity_v,
+        adv_max_samples_v,
+        adv_min_length_v,
+        adv_max_length_v,
+        adv_num_train_epochs_v,
+        adv_batch_size_v,
+        adv_gas_v,
+        adv_lr_v,
+        adv_min_lr_num_v,
+        adv_wd_v,
+        adv_warmup_ratio_v,
+        adv_max_seq_length_v,
+        adv_lora_r_v,
+        adv_lora_alpha_v,
+        adv_lora_dropout_v,
+        adv_mixed_precision_v,
+        adv_num_workers_v,
+        adv_quantization_type_v,
+        adv_max_grad_norm_v,
+        adv_logging_steps_v,
+        adv_eval_steps_v,
+        adv_save_steps_v,
+        # SmolLM3 advanced
+        adv_sm_model_name_v,
+        adv_sm_dataset_name_v,
+        adv_sm_input_field_v,
+        adv_sm_target_field_v,
+        adv_sm_filter_bad_entries_v,
+        adv_sm_sample_size_v,
+        adv_sm_sample_seed_v,
+        adv_sm_max_seq_length_v,
+        adv_sm_batch_size_v,
+        adv_sm_gas_v,
+        adv_sm_learning_rate_v,
+        adv_sm_save_steps_v,
+        adv_sm_eval_steps_v,
+        adv_sm_logging_steps_v,
+    ):
+        # If advanced overrides enabled, generate a config file and pass its path
+        override_path: Optional[str] = None
+        if advanced_enabled_v:
+            try:
+                if model_family_v == "GPT-OSS":
+                    cfg_path = generate_gpt_oss_custom_config_file(
+                        dataset_name=str(adv_dataset_name_v or ""),
+                        dataset_split=str(adv_dataset_split_v or "train"),
+                        dataset_format=str(adv_dataset_format_v or "openhermes_fr"),
+                        input_field=str(adv_input_field_v or "prompt"),
+                        target_field=(str(adv_target_field_v) if adv_target_field_v else None),
+                        system_message=(str(adv_system_message_v) if adv_system_message_v else None),
+                        developer_message=(str(adv_developer_message_v) if adv_developer_message_v else None),
+                        model_identity=(str(adv_model_identity_v) if adv_model_identity_v else None),
+                        max_samples=(int(adv_max_samples_v) if adv_max_samples_v else None),
+                        min_length=int(adv_min_length_v or 10),
+                        max_length=(int(adv_max_length_v) if adv_max_length_v else None),
+                        num_train_epochs=float(adv_num_train_epochs_v or 1.0),
+                        batch_size=int(adv_batch_size_v or 4),
+                        gradient_accumulation_steps=int(adv_gas_v or 4),
+                        learning_rate=float(adv_lr_v or 2e-4),
+                        min_lr=float(adv_min_lr_num_v or 2e-5),
+                        weight_decay=float(adv_wd_v or 0.01),
+                        warmup_ratio=float(adv_warmup_ratio_v or 0.03),
+                        max_seq_length=int(adv_max_seq_length_v or 2048),
+                        lora_r=int(adv_lora_r_v or 16),
+                        lora_alpha=int(adv_lora_alpha_v or 32),
+                        lora_dropout=float(adv_lora_dropout_v or 0.05),
+                        mixed_precision=str(adv_mixed_precision_v or "bf16"),
+                        num_workers=int(adv_num_workers_v or 4),
+                        quantization_type=str(adv_quantization_type_v or "mxfp4"),
+                        max_grad_norm=float(adv_max_grad_norm_v or 1.0),
+                        logging_steps=int(adv_logging_steps_v or 10),
+                        eval_steps=int(adv_eval_steps_v or 100),
+                        save_steps=int(adv_save_steps_v or 500),
+                    )
+                else:
+                    cfg_path = generate_smollm3_custom_config_file(
+                        model_name=str(adv_sm_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
+                        dataset_name=(str(adv_sm_dataset_name_v) if adv_sm_dataset_name_v else None),
+                        max_seq_length=int(adv_sm_max_seq_length_v or 4096),
+                        batch_size=int(adv_sm_batch_size_v or 2),
+                        gradient_accumulation_steps=int(adv_sm_gas_v or 8),
+                        learning_rate=float(adv_sm_learning_rate_v or 5e-6),
+                        save_steps=int(adv_sm_save_steps_v or 500),
+                        eval_steps=int(adv_sm_eval_steps_v or 100),
+                        logging_steps=int(adv_sm_logging_steps_v or 10),
+                        filter_bad_entries=bool(adv_sm_filter_bad_entries_v),
+                        input_field=str(adv_sm_input_field_v or "prompt"),
+                        target_field=str(adv_sm_target_field_v or "completion"),
+                        sample_size=(int(adv_sm_sample_size_v) if adv_sm_sample_size_v else None),
+                        sample_seed=int(adv_sm_sample_seed_v or 42),
+                        trainer_type=str(trainer_type_v).lower(),
+                    )
+                override_path = str(cfg_path)
+            except Exception as e:
+                # Surface error in logs via generator
+                def _err_gen():
+                    yield f"❌ Failed to generate advanced config: {e}"
+                return _err_gen()
+
+        def _gen():
+            params = PipelineInputs(
+                model_family=model_family_v,
+                config_choice=config_choice_v,
+                trainer_type=trainer_type_v,
+                monitoring_mode=monitoring_mode_v,
+                experiment_name=experiment_name_v,
+                repo_short=repo_short_v,
+                author_name=author_name_v,
+                model_description=model_description_v,
+                trackio_space_name=trackio_space_name_v or None,
+                deploy_trackio_space=bool(deploy_trackio_space_v),
+                create_dataset_repo=bool(create_dataset_repo_v),
+                push_to_hub=bool(push_to_hub_v),
+                switch_to_read_after=bool(switch_to_read_after_v),
+                scheduler_override=(scheduler_override_v or None),
+                min_lr=min_lr_v,
+                min_lr_rate=min_lr_rate_v,
+                override_config_path=override_path,
+            )
+            write_token = os.environ.get("HF_WRITE_TOKEN") or os.environ.get("HF_TOKEN")
+            read_token = os.environ.get("HF_READ_TOKEN")
+            yield f"HF_WRITE_TOKEN: {mask_token(write_token)}"
+            yield f"HF_READ_TOKEN: {mask_token(read_token)}"
+            for line in run_pipeline(params):
+                yield line
+                time.sleep(0.01)
+        return _gen()
+
     start_btn.click(
-        start_pipeline,
+        _start_with_overrides,
         inputs=[
             model_family,
             config_choice,
@@ -1138,6 +1604,52 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
             scheduler_override,
             min_lr,
             min_lr_rate,
+            advanced_enabled,
+            # GPT-OSS advanced
+            adv_dataset_name,
+            adv_dataset_split,
+            adv_dataset_format,
+            adv_input_field,
+            adv_target_field,
+            adv_system_message,
+            adv_developer_message,
+            adv_model_identity,
+            adv_max_samples,
+            adv_min_length,
+            adv_max_length,
+            adv_num_train_epochs,
+            adv_batch_size,
+            adv_gradient_accumulation_steps,
+            adv_learning_rate,
+            adv_min_lr_num,
+            adv_weight_decay,
+            adv_warmup_ratio,
+            adv_max_seq_length,
+            adv_lora_r,
+            adv_lora_alpha,
+            adv_lora_dropout,
+            adv_mixed_precision,
+            adv_num_workers,
+            adv_quantization_type,
+            adv_max_grad_norm,
+            adv_logging_steps,
+            adv_eval_steps,
+            adv_save_steps,
+            # SmolLM3 advanced
+            adv_sm_model_name,
+            adv_sm_dataset_name,
+            adv_sm_input_field,
+            adv_sm_target_field,
+            adv_sm_filter_bad_entries,
+            adv_sm_sample_size,
+            adv_sm_sample_seed,
+            adv_sm_max_seq_length,
+            adv_sm_batch_size,
+            adv_sm_gas,
+            adv_sm_learning_rate,
+            adv_sm_save_steps,
+            adv_sm_eval_steps,
+            adv_sm_logging_steps,
+        ],
        outputs=[logs],
    )
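
Note: the hunks above call generate_gpt_oss_custom_config_file and generate_smollm3_custom_config_file, whose definitions fall outside the ranges shown. A minimal sketch of the pattern they appear to implement: serialize the Advanced-tab values into a temporary flat Python config module whose path becomes PipelineInputs.override_config_path. The helper name, parameter set, and defaults below are illustrative assumptions, not the committed implementation.

# Hypothetical sketch (not the committed implementation): write Advanced-tab
# values to a throwaway Python config module that import_config_object can
# later load via PipelineInputs.override_config_path.
import tempfile
from pathlib import Path
from typing import Optional


def generate_custom_config_file_sketch(
    model_name: str = "HuggingFaceTB/SmolLM3-3B",
    dataset_name: Optional[str] = None,
    max_seq_length: int = 4096,
    batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
    learning_rate: float = 5e-6,
    trainer_type: str = "sft",
) -> Path:
    """Serialize keyword values as a flat Python config module; return its path."""
    fields = {
        "model_name": model_name,
        "dataset_name": dataset_name,
        "max_seq_length": max_seq_length,
        "batch_size": batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "trainer_type": trainer_type,
    }
    # repr() round-trips strings, numbers, and None as valid Python literals.
    body = "\n".join(f"{key} = {value!r}" for key, value in fields.items())
    tmp = tempfile.NamedTemporaryFile(
        mode="w", prefix="custom_config_", suffix=".py", delete=False
    )
    with tmp:
        tmp.write(body + "\n")
    return Path(tmp.name)


if __name__ == "__main__":
    # "username/my-dataset" is a placeholder, not a real repo.
    path = generate_custom_config_file_sketch(dataset_name="username/my-dataset")
    print(path.read_text())

Under this reading, run_pipeline then treats the generated file exactly like a prebuilt config: Path(params.override_config_path) is loaded with import_config_object and attributes such as model_name and batch_size are read off the resulting module, as the second hunk above shows.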