Tonic committed
Commit cb30cda · Parent: 943cfba

adds tabbed interface, advanced mode, connectors

Files changed (1)
  1. interface.py +423 -127
interface.py CHANGED
@@ -470,6 +470,94 @@ config = SmolLM3GeneratedConfig(
470
  """
471
  return _write_generated_config("_generated_smollm3_custom.py", py)
472
 
473
  def ensure_dataset_repo(username: str, dataset_name: str, token: str) -> Tuple[str, bool, str]:
474
  """Create or ensure dataset repo exists. Returns (repo_id, created_or_exists, message)."""
475
  from huggingface_hub import create_repo # type: ignore
@@ -907,6 +995,8 @@ def on_config_change(family: str, config_choice: str):
907
  "", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
908
  None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
909
  2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
910
  # Advanced fields (SmolLM3)
911
  "HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
912
  4096, 2, 8, 5e-6, 500, 100, 10,
@@ -1041,6 +1131,15 @@ def on_config_change(family: str, config_choice: str):
1041
  adv_log,
1042
  adv_eval,
1043
  adv_save,
1044
  # Advanced (SmolLM3)
1045
  adv_sm_model_name,
1046
  adv_sm_dataset_name,
@@ -1237,90 +1336,170 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1237
  gpt_oss_advanced_group = gr.Group(visible=False)
1238
  with gpt_oss_advanced_group:
1239
  gr.Markdown("Advanced configuration for GPT-OSS")
1240
- with gr.Accordion("Dataset", open=True):
1241
- adv_dataset_name = gr.Textbox(value="", label="Dataset name")
1242
- with gr.Row():
1243
- adv_dataset_split = gr.Textbox(value="train", label="Dataset split")
1244
- adv_dataset_format = gr.Dropdown(
1245
- choices=["openhermes_fr", "messages", "text"],
1246
- value="openhermes_fr",
1247
- label="Dataset format",
1248
  )
1249
- with gr.Row():
1250
- adv_input_field = gr.Textbox(value="prompt", label="Input field")
1251
- adv_target_field = gr.Textbox(value="accepted_completion", label="Target field (optional)")
1252
- with gr.Row():
1253
- adv_system_message = gr.Textbox(value="", label="System message (optional)")
1254
- adv_developer_message = gr.Textbox(value="", label="Developer message (optional)")
1255
- adv_model_identity = gr.Textbox(value="", label="Model identity (optional)")
1256
- with gr.Row():
1257
- adv_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
1258
- adv_min_length = gr.Number(value=10, precision=0, label="Min length")
1259
- adv_max_length = gr.Number(value=None, precision=0, label="Max length (optional)")
1260
-
1261
- with gr.Accordion("Training", open=True):
1262
- with gr.Row():
1263
- adv_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
1264
- adv_batch_size = gr.Number(value=4, precision=0, label="Batch size")
1265
- adv_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
1266
- with gr.Row():
1267
- adv_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
1268
- adv_min_lr_num = gr.Number(value=2e-5, precision=6, label="Min LR")
1269
- adv_weight_decay = gr.Number(value=0.01, precision=6, label="Weight decay")
1270
- adv_warmup_ratio = gr.Number(value=0.03, precision=3, label="Warmup ratio")
1271
- adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
1272
-
1273
- with gr.Accordion("LoRA & Quantization", open=False):
1274
- with gr.Row():
1275
- adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
1276
- adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
1277
- adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
1278
- with gr.Row():
1279
- adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
1280
- adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
1281
- adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
1282
- adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
1283
-
1284
- with gr.Accordion("Eval & Logging", open=False):
1285
- with gr.Row():
1286
- adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
1287
- adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
1288
- adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
1289
-
1290
- with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
1291
- scheduler_override = gr.Dropdown(
1292
- choices=[c for c in SCHEDULER_CHOICES if c is not None],
1293
- value=None,
1294
- allow_custom_value=True,
1295
- label="Scheduler override",
1296
- )
1297
- with gr.Row():
1298
- min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
1299
- min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
1300
 
1301
  smollm3_advanced_group = gr.Group(visible=False)
1302
  with smollm3_advanced_group:
1303
  gr.Markdown("Advanced configuration for SmolLM3")
1304
- with gr.Accordion("Dataset", open=True):
1305
- adv_sm_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
1306
- adv_sm_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
1307
- with gr.Row():
1308
- adv_sm_input_field = gr.Textbox(value="prompt", label="Input field")
1309
- adv_sm_target_field = gr.Textbox(value="completion", label="Target field")
1310
- with gr.Row():
1311
- adv_sm_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
1312
- adv_sm_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
1313
- adv_sm_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
1314
- with gr.Accordion("Training", open=True):
1315
- with gr.Row():
1316
- adv_sm_max_seq_length = gr.Number(value=4096, precision=0, label="Max seq length")
1317
- adv_sm_batch_size = gr.Number(value=2, precision=0, label="Batch size")
1318
- adv_sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
1319
- adv_sm_learning_rate = gr.Number(value=5e-6, precision=6, label="Learning rate")
1320
- with gr.Row():
1321
- adv_sm_save_steps = gr.Number(value=500, precision=0, label="Save steps")
1322
- adv_sm_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
1323
- adv_sm_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
1324
 
1325
  def _toggle_advanced(enable: bool, family_val: str):
1326
  return (
@@ -1334,6 +1513,19 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1334
  outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
1335
  )
1336
 
1337
  # Final action & logs
1338
  start_btn = gr.Button("Run Pipeline", variant="primary")
1339
  logs = gr.Textbox(value="", label="Logs", lines=20)
@@ -1402,6 +1594,15 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1402
  adv_logging_steps,
1403
  adv_eval_steps,
1404
  adv_save_steps,
1405
  # Advanced (SmolLM3)
1406
  adv_sm_model_name,
1407
  adv_sm_dataset_name,
@@ -1449,6 +1650,7 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1449
  min_lr_v,
1450
  min_lr_rate_v,
1451
  advanced_enabled_v,
1452
  # GPT-OSS advanced
1453
  adv_dataset_name_v,
1454
  adv_dataset_split_v,
@@ -1479,7 +1681,17 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1479
  adv_logging_steps_v,
1480
  adv_eval_steps_v,
1481
  adv_save_steps_v,
1482
  # SmolLM3 advanced
1483
  adv_sm_model_name_v,
1484
  adv_sm_dataset_name_v,
1485
  adv_sm_input_field_v,
@@ -1494,61 +1706,115 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1494
  adv_sm_save_steps_v,
1495
  adv_sm_eval_steps_v,
1496
  adv_sm_logging_steps_v,
1497
  ):
1498
  # If advanced overrides enabled, generate a config file and pass its path
1499
  override_path: Optional[str] = None
1500
  if advanced_enabled_v:
1501
  try:
1502
  if model_family_v == "GPT-OSS":
1503
- cfg_path = generate_gpt_oss_custom_config_file(
1504
- dataset_name=str(adv_dataset_name_v or ""),
1505
- dataset_split=str(adv_dataset_split_v or "train"),
1506
- dataset_format=str(adv_dataset_format_v or "openhermes_fr"),
1507
- input_field=str(adv_input_field_v or "prompt"),
1508
- target_field=(str(adv_target_field_v) if adv_target_field_v else None),
1509
- system_message=(str(adv_system_message_v) if adv_system_message_v else None),
1510
- developer_message=(str(adv_developer_message_v) if adv_developer_message_v else None),
1511
- model_identity=(str(adv_model_identity_v) if adv_model_identity_v else None),
1512
- max_samples=(int(adv_max_samples_v) if adv_max_samples_v else None),
1513
- min_length=int(adv_min_length_v or 10),
1514
- max_length=(int(adv_max_length_v) if adv_max_length_v else None),
1515
- num_train_epochs=float(adv_num_train_epochs_v or 1.0),
1516
- batch_size=int(adv_batch_size_v or 4),
1517
- gradient_accumulation_steps=int(adv_gas_v or 4),
1518
- learning_rate=float(adv_lr_v or 2e-4),
1519
- min_lr=float(adv_min_lr_num_v or 2e-5),
1520
- weight_decay=float(adv_wd_v or 0.01),
1521
- warmup_ratio=float(adv_warmup_ratio_v or 0.03),
1522
- max_seq_length=int(adv_max_seq_length_v or 2048),
1523
- lora_r=int(adv_lora_r_v or 16),
1524
- lora_alpha=int(adv_lora_alpha_v or 32),
1525
- lora_dropout=float(adv_lora_dropout_v or 0.05),
1526
- mixed_precision=str(adv_mixed_precision_v or "bf16"),
1527
- num_workers=int(adv_num_workers_v or 4),
1528
- quantization_type=str(adv_quantization_type_v or "mxfp4"),
1529
- max_grad_norm=float(adv_max_grad_norm_v or 1.0),
1530
- logging_steps=int(adv_logging_steps_v or 10),
1531
- eval_steps=int(adv_eval_steps_v or 100),
1532
- save_steps=int(adv_save_steps_v or 500),
1533
- )
1534
  else:
1535
- cfg_path = generate_smollm3_custom_config_file(
1536
- model_name=str(adv_sm_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
1537
- dataset_name=(str(adv_sm_dataset_name_v) if adv_sm_dataset_name_v else None),
1538
- max_seq_length=int(adv_sm_max_seq_length_v or 4096),
1539
- batch_size=int(adv_sm_batch_size_v or 2),
1540
- gradient_accumulation_steps=int(adv_sm_gas_v or 8),
1541
- learning_rate=float(adv_sm_learning_rate_v or 5e-6),
1542
- save_steps=int(adv_sm_save_steps_v or 500),
1543
- eval_steps=int(adv_sm_eval_steps_v or 100),
1544
- logging_steps=int(adv_sm_logging_steps_v or 10),
1545
- filter_bad_entries=bool(adv_sm_filter_bad_entries_v),
1546
- input_field=str(adv_sm_input_field_v or "prompt"),
1547
- target_field=str(adv_sm_target_field_v or "completion"),
1548
- sample_size=(int(adv_sm_sample_size_v) if adv_sm_sample_size_v else None),
1549
- sample_seed=int(adv_sm_sample_seed_v or 42),
1550
- trainer_type=str(trainer_type_v).lower(),
1551
- )
1552
  override_path = str(cfg_path)
1553
  except Exception as e:
1554
  # Surface error in logs via generator
@@ -1605,6 +1871,7 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1605
  min_lr,
1606
  min_lr_rate,
1607
  advanced_enabled,
1608
  # GPT-OSS advanced
1609
  adv_dataset_name,
1610
  adv_dataset_split,
@@ -1635,7 +1902,17 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1635
  adv_logging_steps,
1636
  adv_eval_steps,
1637
  adv_save_steps,
1638
  # SmolLM3 advanced
1639
  adv_sm_model_name,
1640
  adv_sm_dataset_name,
1641
  adv_sm_input_field,
@@ -1650,6 +1927,25 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
1650
  adv_sm_save_steps,
1651
  adv_sm_eval_steps,
1652
  adv_sm_logging_steps,
1653
  ],
1654
  outputs=[logs],
1655
  )
 
470
  """
471
  return _write_generated_config("_generated_smollm3_custom.py", py)
472
 
473
+
474
+ def generate_smollm3_long_context_config_file(
475
+ model_name: str,
476
+ dataset_name: Optional[str],
477
+ input_field: str,
478
+ target_field: str,
479
+ filter_bad_entries: bool,
480
+ sample_size: Optional[int],
481
+ sample_seed: int,
482
+ max_seq_length: int,
483
+ batch_size: int,
484
+ gradient_accumulation_steps: int,
485
+ learning_rate: float,
486
+ warmup_steps: int,
487
+ max_iters: int,
488
+ save_steps: int,
489
+ eval_steps: int,
490
+ logging_steps: int,
491
+ use_chat_template: bool,
492
+ no_think_system_message: bool,
493
+ trainer_type: str,
494
+ ) -> Path:
495
+ """Create a SmolLM3 long-context config file with optional dataset fields."""
496
+ def _bool(b: bool) -> str:
497
+ return "True" if b else "False"
498
+
499
+ ds_section = """
500
+ # HF Dataset configuration
501
+ dataset_name={}
502
+ dataset_split="train"
503
+ input_field={}
504
+ target_field={}
505
+ filter_bad_entries={}
506
+ bad_entry_field="bad_entry"
507
+ sample_size={}
508
+ sample_seed={}
509
+ """.format(
510
+ repr(dataset_name) if dataset_name else "None",
511
+ repr(input_field),
512
+ repr(target_field),
513
+ _bool(filter_bad_entries),
514
+ repr(sample_size) if sample_size is not None else "None",
515
+ sample_seed,
516
+ )
517
+
518
+ py = f"""
519
+ from dataclasses import dataclass
520
+ from typing import Optional
521
+ from config.train_smollm3 import SmolLM3Config
522
+
523
+ @dataclass
524
+ class SmolLM3LongContextGeneratedConfig(SmolLM3Config):
525
+ {ds_section}
526
+
527
+ config = SmolLM3LongContextGeneratedConfig(
528
+ trainer_type={repr(trainer_type.lower())},
529
+ model_name={repr(model_name)},
530
+ max_seq_length={max_seq_length},
531
+ use_flash_attention=True,
532
+ use_gradient_checkpointing=True,
533
+
534
+ batch_size={batch_size},
535
+ gradient_accumulation_steps={gradient_accumulation_steps},
536
+ learning_rate={learning_rate},
537
+ weight_decay=0.01,
538
+ warmup_steps={warmup_steps},
539
+ max_iters={max_iters},
540
+
541
+ fp16=True,
542
+ bf16=False,
543
+ save_steps={save_steps},
544
+ eval_steps={eval_steps},
545
+ logging_steps={logging_steps},
546
+ save_total_limit=3,
547
+ eval_strategy="steps",
548
+ metric_for_best_model="eval_loss",
549
+ greater_is_better=False,
550
+ load_best_model_at_end=True,
551
+
552
+ use_chat_template={_bool(use_chat_template)},
553
+ chat_template_kwargs={{
554
+ "add_generation_prompt": True,
555
+ "no_think_system_message": {_bool(no_think_system_message)}
556
+ }}
557
+ )
558
+ """
559
+ return _write_generated_config("_generated_smollm3_long_context.py", py)
560
+
561
  def ensure_dataset_repo(username: str, dataset_name: str, token: str) -> Tuple[str, bool, str]:
562
  """Create or ensure dataset repo exists. Returns (repo_id, created_or_exists, message)."""
563
  from huggingface_hub import create_repo # type: ignore
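For reference, the new long-context generator added above is exercised later in this commit from the Run Pipeline handler. As a minimal sketch, not part of the commit itself (it assumes interface.py and its dependencies are importable from the working directory, and that "sft" is a valid lower-cased trainer type), calling it with the UI defaults introduced further down looks like this:

# Minimal sketch: invoke the new generator with the long-context UI defaults.
from interface import generate_smollm3_long_context_config_file

cfg_path = generate_smollm3_long_context_config_file(
    model_name="HuggingFaceTB/SmolLM3-3B",
    dataset_name=None,               # optional HF dataset repo id
    input_field="prompt",
    target_field="completion",
    filter_bad_entries=False,
    sample_size=None,
    sample_seed=42,
    max_seq_length=131072,           # long-context default from the UI
    batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    warmup_steps=200,
    max_iters=500,
    save_steps=100,
    eval_steps=50,
    logging_steps=10,
    use_chat_template=True,
    no_think_system_message=True,
    trainer_type="sft",              # lower-cased trainer type from the UI (assumed value)
)
print(cfg_path)  # path returned by _write_generated_config("_generated_smollm3_long_context.py", ...)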
 
995
  "", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
996
  None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
997
  2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
998
+ # GPT-OSS Medical o1 SFT defaults
999
+ "default", "", "", 1.0, 4, 4, 2e-4, 2048,
1000
  # Advanced fields (SmolLM3)
1001
  "HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
1002
  4096, 2, 8, 5e-6, 500, 100, 10,
 
1131
  adv_log,
1132
  adv_eval,
1133
  adv_save,
1134
+ # GPT-OSS Medical o1 SFT defaults
1135
+ "default",
1136
+ "",
1137
+ "",
1138
+ 1.0,
1139
+ 4,
1140
+ 4,
1141
+ 2e-4,
1142
+ 2048,
1143
  # Advanced (SmolLM3)
1144
  adv_sm_model_name,
1145
  adv_sm_dataset_name,
 
1336
  gpt_oss_advanced_group = gr.Group(visible=False)
1337
  with gpt_oss_advanced_group:
1338
  gr.Markdown("Advanced configuration for GPT-OSS")
1339
+ adv_gpt_mode = gr.Radio(
1340
+ choices=["custom", "medical_o1_sft"],
1341
+ value="custom",
1342
+ label="Advanced mode",
1343
+ )
1344
+
1345
+ # --- GPT-OSS Custom advanced controls ---
1346
+ gpt_oss_custom_group = gr.Group(visible=True)
1347
+ with gpt_oss_custom_group:
1348
+ with gr.Accordion("Dataset", open=True):
1349
+ adv_dataset_name = gr.Textbox(value="", label="Dataset name")
1350
+ with gr.Row():
1351
+ adv_dataset_split = gr.Textbox(value="train", label="Dataset split")
1352
+ adv_dataset_format = gr.Dropdown(
1353
+ choices=["openhermes_fr", "messages", "text"],
1354
+ value="openhermes_fr",
1355
+ label="Dataset format",
1356
+ )
1357
+ with gr.Row():
1358
+ adv_input_field = gr.Textbox(value="prompt", label="Input field")
1359
+ adv_target_field = gr.Textbox(value="accepted_completion", label="Target field (optional)")
1360
+ with gr.Row():
1361
+ adv_system_message = gr.Textbox(value="", label="System message (optional)")
1362
+ adv_developer_message = gr.Textbox(value="", label="Developer message (optional)")
1363
+ adv_model_identity = gr.Textbox(value="", label="Model identity (optional)")
1364
+ with gr.Row():
1365
+ adv_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
1366
+ adv_min_length = gr.Number(value=10, precision=0, label="Min length")
1367
+ adv_max_length = gr.Number(value=None, precision=0, label="Max length (optional)")
1368
+
1369
+ with gr.Accordion("Training", open=True):
1370
+ with gr.Row():
1371
+ adv_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
1372
+ adv_batch_size = gr.Number(value=4, precision=0, label="Batch size")
1373
+ adv_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
1374
+ with gr.Row():
1375
+ adv_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
1376
+ adv_min_lr_num = gr.Number(value=2e-5, precision=6, label="Min LR")
1377
+ adv_weight_decay = gr.Number(value=0.01, precision=6, label="Weight decay")
1378
+ adv_warmup_ratio = gr.Number(value=0.03, precision=3, label="Warmup ratio")
1379
+ adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
1380
+
1381
+ with gr.Accordion("LoRA & Quantization", open=False):
1382
+ with gr.Row():
1383
+ adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
1384
+ adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
1385
+ adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
1386
+ with gr.Row():
1387
+ adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
1388
+ adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
1389
+ adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
1390
+ adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
1391
+
1392
+ with gr.Accordion("Eval & Logging", open=False):
1393
+ with gr.Row():
1394
+ adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
1395
+ adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
1396
+ adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
1397
+
1398
+ with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
1399
+ scheduler_override = gr.Dropdown(
1400
+ choices=[c for c in SCHEDULER_CHOICES if c is not None],
1401
+ value=None,
1402
+ allow_custom_value=True,
1403
+ label="Scheduler override",
1404
  )
1405
+ with gr.Row():
1406
+ min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
1407
+ min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
1408
+
1409
+ # --- GPT-OSS Medical o1 SFT controls ---
1410
+ gpt_oss_medical_group = gr.Group(visible=False)
1411
+ with gpt_oss_medical_group:
1412
+ gr.Markdown("Build a Medical o1 SFT configuration (dataset fixed to FreedomIntelligence/medical-o1-reasoning-SFT)")
1413
+ with gr.Accordion("Dataset", open=True):
1414
+ adv_med_dataset_config = gr.Textbox(value="default", label="Dataset config (subset)")
1415
+ with gr.Accordion("Context (optional)", open=False):
1416
+ with gr.Row():
1417
+ adv_med_system_message = gr.Textbox(value="", label="System message")
1418
+ adv_med_developer_message = gr.Textbox(value="", label="Developer message")
1419
+ with gr.Accordion("Training", open=True):
1420
+ with gr.Row():
1421
+ adv_med_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
1422
+ adv_med_batch_size = gr.Number(value=4, precision=0, label="Batch size")
1423
+ adv_med_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
1424
+ with gr.Row():
1425
+ adv_med_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
1426
+ adv_med_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
1427
 
1428
  smollm3_advanced_group = gr.Group(visible=False)
1429
  with smollm3_advanced_group:
1430
  gr.Markdown("Advanced configuration for SmolLM3")
1431
+ adv_sm_mode = gr.Radio(
1432
+ choices=["custom", "long_context"],
1433
+ value="custom",
1434
+ label="Advanced mode",
1435
+ )
1436
+ # --- SmolLM3 Custom ---
1437
+ sm_custom_group = gr.Group(visible=True)
1438
+ with sm_custom_group:
1439
+ with gr.Accordion("Dataset", open=True):
1440
+ adv_sm_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
1441
+ adv_sm_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
1442
+ with gr.Row():
1443
+ adv_sm_input_field = gr.Textbox(value="prompt", label="Input field")
1444
+ adv_sm_target_field = gr.Textbox(value="completion", label="Target field")
1445
+ with gr.Row():
1446
+ adv_sm_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
1447
+ adv_sm_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
1448
+ adv_sm_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
1449
+ with gr.Accordion("Training", open=True):
1450
+ with gr.Row():
1451
+ adv_sm_max_seq_length = gr.Number(value=4096, precision=0, label="Max seq length")
1452
+ adv_sm_batch_size = gr.Number(value=2, precision=0, label="Batch size")
1453
+ adv_sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
1454
+ adv_sm_learning_rate = gr.Number(value=5e-6, precision=6, label="Learning rate")
1455
+ with gr.Row():
1456
+ adv_sm_save_steps = gr.Number(value=500, precision=0, label="Save steps")
1457
+ adv_sm_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
1458
+ adv_sm_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
1459
+
1460
+ # --- SmolLM3 Long-Context ---
1461
+ sm_long_group = gr.Group(visible=False)
1462
+ with sm_long_group:
1463
+ gr.Markdown("Generate a Long-Context SmolLM3 config")
1464
+ with gr.Accordion("Dataset", open=True):
1465
+ adv_sm_lc_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
1466
+ adv_sm_lc_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
1467
+ with gr.Row():
1468
+ adv_sm_lc_input_field = gr.Textbox(value="prompt", label="Input field")
1469
+ adv_sm_lc_target_field = gr.Textbox(value="completion", label="Target field")
1470
+ with gr.Row():
1471
+ adv_sm_lc_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
1472
+ adv_sm_lc_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
1473
+ adv_sm_lc_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
1474
+ with gr.Accordion("Training", open=True):
1475
+ with gr.Row():
1476
+ adv_sm_lc_max_seq_length = gr.Number(value=131072, precision=0, label="Max seq length (up to 131072)")
1477
+ adv_sm_lc_batch_size = gr.Number(value=1, precision=0, label="Batch size")
1478
+ adv_sm_lc_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
1479
+ adv_sm_lc_learning_rate = gr.Number(value=1e-5, precision=6, label="Learning rate")
1480
+ with gr.Row():
1481
+ adv_sm_lc_warmup_steps = gr.Number(value=200, precision=0, label="Warmup steps")
1482
+ adv_sm_lc_max_iters = gr.Number(value=500, precision=0, label="Max iters")
1483
+ with gr.Row():
1484
+ adv_sm_lc_save_steps = gr.Number(value=100, precision=0, label="Save steps")
1485
+ adv_sm_lc_eval_steps = gr.Number(value=50, precision=0, label="Eval steps")
1486
+ adv_sm_lc_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
1487
+ with gr.Accordion("Chat Template", open=False):
1488
+ with gr.Row():
1489
+ adv_sm_lc_use_chat_template = gr.Checkbox(value=True, label="Use chat template")
1490
+ adv_sm_lc_no_think_system_message = gr.Checkbox(value=True, label="No-think system message")
1491
+
1492
+ def _toggle_sm_mode(mode: str):
1493
+ return (
1494
+ gr.update(visible=mode == "custom"),
1495
+ gr.update(visible=mode == "long_context"),
1496
+ )
1497
+
1498
+ adv_sm_mode.change(
1499
+ _toggle_sm_mode,
1500
+ inputs=[adv_sm_mode],
1501
+ outputs=[sm_custom_group, sm_long_group],
1502
+ )
1503
 
1504
  def _toggle_advanced(enable: bool, family_val: str):
1505
  return (
 
1513
  outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
1514
  )
1515
 
1516
+ # Toggle between GPT-OSS Custom and Medical modes
1517
+ def _toggle_gpt_oss_mode(mode: str):
1518
+ return (
1519
+ gr.update(visible=mode == "custom"),
1520
+ gr.update(visible=mode == "medical_o1_sft"),
1521
+ )
1522
+
1523
+ adv_gpt_mode.change(
1524
+ _toggle_gpt_oss_mode,
1525
+ inputs=[adv_gpt_mode],
1526
+ outputs=[gpt_oss_custom_group, gpt_oss_medical_group],
1527
+ )
1528
+
1529
  # Final action & logs
1530
  start_btn = gr.Button("Run Pipeline", variant="primary")
1531
  logs = gr.Textbox(value="", label="Logs", lines=20)
 
1594
  adv_logging_steps,
1595
  adv_eval_steps,
1596
  adv_save_steps,
1597
+ # GPT-OSS Medical o1 SFT outputs (prefill defaults)
1598
+ adv_med_dataset_config,
1599
+ adv_med_system_message,
1600
+ adv_med_developer_message,
1601
+ adv_med_num_train_epochs,
1602
+ adv_med_batch_size,
1603
+ adv_med_gradient_accumulation_steps,
1604
+ adv_med_learning_rate,
1605
+ adv_med_max_seq_length,
1606
  # Advanced (SmolLM3)
1607
  adv_sm_model_name,
1608
  adv_sm_dataset_name,
 
1650
  min_lr_v,
1651
  min_lr_rate_v,
1652
  advanced_enabled_v,
1653
+ adv_gpt_mode_v,
1654
  # GPT-OSS advanced
1655
  adv_dataset_name_v,
1656
  adv_dataset_split_v,
 
1681
  adv_logging_steps_v,
1682
  adv_eval_steps_v,
1683
  adv_save_steps_v,
1684
+ # GPT-OSS Medical o1 SFT
1685
+ adv_med_dataset_config_v,
1686
+ adv_med_system_message_v,
1687
+ adv_med_developer_message_v,
1688
+ adv_med_num_train_epochs_v,
1689
+ adv_med_batch_size_v,
1690
+ adv_med_gradient_accumulation_steps_v,
1691
+ adv_med_learning_rate_v,
1692
+ adv_med_max_seq_length_v,
1693
  # SmolLM3 advanced
1694
+ adv_sm_mode_v,
1695
  adv_sm_model_name_v,
1696
  adv_sm_dataset_name_v,
1697
  adv_sm_input_field_v,
 
1706
  adv_sm_save_steps_v,
1707
  adv_sm_eval_steps_v,
1708
  adv_sm_logging_steps_v,
1709
+ # SmolLM3 long context
1710
+ adv_sm_lc_model_name_v,
1711
+ adv_sm_lc_dataset_name_v,
1712
+ adv_sm_lc_input_field_v,
1713
+ adv_sm_lc_target_field_v,
1714
+ adv_sm_lc_filter_bad_entries_v,
1715
+ adv_sm_lc_sample_size_v,
1716
+ adv_sm_lc_sample_seed_v,
1717
+ adv_sm_lc_max_seq_length_v,
1718
+ adv_sm_lc_batch_size_v,
1719
+ adv_sm_lc_gas_v,
1720
+ adv_sm_lc_learning_rate_v,
1721
+ adv_sm_lc_warmup_steps_v,
1722
+ adv_sm_lc_max_iters_v,
1723
+ adv_sm_lc_save_steps_v,
1724
+ adv_sm_lc_eval_steps_v,
1725
+ adv_sm_lc_logging_steps_v,
1726
+ adv_sm_lc_use_chat_template_v,
1727
+ adv_sm_lc_no_think_system_message_v,
1728
  ):
1729
  # If advanced overrides enabled, generate a config file and pass its path
1730
  override_path: Optional[str] = None
1731
  if advanced_enabled_v:
1732
  try:
1733
  if model_family_v == "GPT-OSS":
1734
+ if str(adv_gpt_mode_v) == "medical_o1_sft":
1735
+ cfg_path = generate_medical_o1_config_file(
1736
+ dataset_config=str(adv_med_dataset_config_v or "default"),
1737
+ system_message=(str(adv_med_system_message_v) if adv_med_system_message_v else None),
1738
+ developer_message=(str(adv_med_developer_message_v) if adv_med_developer_message_v else None),
1739
+ num_train_epochs=float(adv_med_num_train_epochs_v or 1.0),
1740
+ batch_size=int(adv_med_batch_size_v or 4),
1741
+ gradient_accumulation_steps=int(adv_med_gradient_accumulation_steps_v or 4),
1742
+ learning_rate=float(adv_med_learning_rate_v or 2e-4),
1743
+ max_seq_length=int(adv_med_max_seq_length_v or 2048),
1744
+ )
1745
+ else:
1746
+ cfg_path = generate_gpt_oss_custom_config_file(
1747
+ dataset_name=str(adv_dataset_name_v or ""),
1748
+ dataset_split=str(adv_dataset_split_v or "train"),
1749
+ dataset_format=str(adv_dataset_format_v or "openhermes_fr"),
1750
+ input_field=str(adv_input_field_v or "prompt"),
1751
+ target_field=(str(adv_target_field_v) if adv_target_field_v else None),
1752
+ system_message=(str(adv_system_message_v) if adv_system_message_v else None),
1753
+ developer_message=(str(adv_developer_message_v) if adv_developer_message_v else None),
1754
+ model_identity=(str(adv_model_identity_v) if adv_model_identity_v else None),
1755
+ max_samples=(int(adv_max_samples_v) if adv_max_samples_v else None),
1756
+ min_length=int(adv_min_length_v or 10),
1757
+ max_length=(int(adv_max_length_v) if adv_max_length_v else None),
1758
+ num_train_epochs=float(adv_num_train_epochs_v or 1.0),
1759
+ batch_size=int(adv_batch_size_v or 4),
1760
+ gradient_accumulation_steps=int(adv_gas_v or 4),
1761
+ learning_rate=float(adv_lr_v or 2e-4),
1762
+ min_lr=float(adv_min_lr_num_v or 2e-5),
1763
+ weight_decay=float(adv_wd_v or 0.01),
1764
+ warmup_ratio=float(adv_warmup_ratio_v or 0.03),
1765
+ max_seq_length=int(adv_max_seq_length_v or 2048),
1766
+ lora_r=int(adv_lora_r_v or 16),
1767
+ lora_alpha=int(adv_lora_alpha_v or 32),
1768
+ lora_dropout=float(adv_lora_dropout_v or 0.05),
1769
+ mixed_precision=str(adv_mixed_precision_v or "bf16"),
1770
+ num_workers=int(adv_num_workers_v or 4),
1771
+ quantization_type=str(adv_quantization_type_v or "mxfp4"),
1772
+ max_grad_norm=float(adv_max_grad_norm_v or 1.0),
1773
+ logging_steps=int(adv_logging_steps_v or 10),
1774
+ eval_steps=int(adv_eval_steps_v or 100),
1775
+ save_steps=int(adv_save_steps_v or 500),
1776
+ )
1777
  else:
1778
+ if str(adv_sm_mode_v) == "long_context":
1779
+ cfg_path = generate_smollm3_long_context_config_file(
1780
+ model_name=str(adv_sm_lc_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
1781
+ dataset_name=(str(adv_sm_lc_dataset_name_v) if adv_sm_lc_dataset_name_v else None),
1782
+ input_field=str(adv_sm_lc_input_field_v or "prompt"),
1783
+ target_field=str(adv_sm_lc_target_field_v or "completion"),
1784
+ filter_bad_entries=bool(adv_sm_lc_filter_bad_entries_v),
1785
+ sample_size=(int(adv_sm_lc_sample_size_v) if adv_sm_lc_sample_size_v else None),
1786
+ sample_seed=int(adv_sm_lc_sample_seed_v or 42),
1787
+ max_seq_length=int(adv_sm_lc_max_seq_length_v or 131072),
1788
+ batch_size=int(adv_sm_lc_batch_size_v or 1),
1789
+ gradient_accumulation_steps=int(adv_sm_lc_gas_v or 8),
1790
+ learning_rate=float(adv_sm_lc_learning_rate_v or 1e-5),
1791
+ warmup_steps=int(adv_sm_lc_warmup_steps_v or 200),
1792
+ max_iters=int(adv_sm_lc_max_iters_v or 500),
1793
+ save_steps=int(adv_sm_lc_save_steps_v or 100),
1794
+ eval_steps=int(adv_sm_lc_eval_steps_v or 50),
1795
+ logging_steps=int(adv_sm_lc_logging_steps_v or 10),
1796
+ use_chat_template=bool(adv_sm_lc_use_chat_template_v),
1797
+ no_think_system_message=bool(adv_sm_lc_no_think_system_message_v),
1798
+ trainer_type=str(trainer_type_v).lower(),
1799
+ )
1800
+ else:
1801
+ cfg_path = generate_smollm3_custom_config_file(
1802
+ model_name=str(adv_sm_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
1803
+ dataset_name=(str(adv_sm_dataset_name_v) if adv_sm_dataset_name_v else None),
1804
+ max_seq_length=int(adv_sm_max_seq_length_v or 4096),
1805
+ batch_size=int(adv_sm_batch_size_v or 2),
1806
+ gradient_accumulation_steps=int(adv_sm_gas_v or 8),
1807
+ learning_rate=float(adv_sm_learning_rate_v or 5e-6),
1808
+ save_steps=int(adv_sm_save_steps_v or 500),
1809
+ eval_steps=int(adv_sm_eval_steps_v or 100),
1810
+ logging_steps=int(adv_sm_logging_steps_v or 10),
1811
+ filter_bad_entries=bool(adv_sm_filter_bad_entries_v),
1812
+ input_field=str(adv_sm_input_field_v or "prompt"),
1813
+ target_field=str(adv_sm_target_field_v or "completion"),
1814
+ sample_size=(int(adv_sm_sample_size_v) if adv_sm_sample_size_v else None),
1815
+ sample_seed=int(adv_sm_sample_seed_v or 42),
1816
+ trainer_type=str(trainer_type_v).lower(),
1817
+ )
1818
  override_path = str(cfg_path)
1819
  except Exception as e:
1820
  # Surface error in logs via generator
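Downstream, the handler passes override_path on to the pipeline; the commit does not show how that file is consumed. Purely as an illustration (an assumption, not the pipeline's loading code), a generated override such as _generated_smollm3_long_context.py could be loaded and inspected like this:

# Illustration only: load a generated override file and read its module-level
# `config` object (every generator above defines one).
import importlib.util

def load_generated_config(path: str):
    spec = importlib.util.spec_from_file_location("generated_override", path)
    assert spec is not None and spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the generated Python file
    return module.config

# Hypothetical usage:
# cfg = load_generated_config(override_path)
# print(cfg.model_name, cfg.max_seq_length, cfg.trainer_type)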
 
1871
  min_lr,
1872
  min_lr_rate,
1873
  advanced_enabled,
1874
+ adv_gpt_mode,
1875
  # GPT-OSS advanced
1876
  adv_dataset_name,
1877
  adv_dataset_split,
 
1902
  adv_logging_steps,
1903
  adv_eval_steps,
1904
  adv_save_steps,
1905
+ # GPT-OSS Medical o1 SFT
1906
+ adv_med_dataset_config,
1907
+ adv_med_system_message,
1908
+ adv_med_developer_message,
1909
+ adv_med_num_train_epochs,
1910
+ adv_med_batch_size,
1911
+ adv_med_gradient_accumulation_steps,
1912
+ adv_med_learning_rate,
1913
+ adv_med_max_seq_length,
1914
  # SmolLM3 advanced
1915
+ adv_sm_mode,
1916
  adv_sm_model_name,
1917
  adv_sm_dataset_name,
1918
  adv_sm_input_field,
 
1927
  adv_sm_save_steps,
1928
  adv_sm_eval_steps,
1929
  adv_sm_logging_steps,
1930
+ # SmolLM3 long context
1931
+ adv_sm_lc_model_name,
1932
+ adv_sm_lc_dataset_name,
1933
+ adv_sm_lc_input_field,
1934
+ adv_sm_lc_target_field,
1935
+ adv_sm_lc_filter_bad_entries,
1936
+ adv_sm_lc_sample_size,
1937
+ adv_sm_lc_sample_seed,
1938
+ adv_sm_lc_max_seq_length,
1939
+ adv_sm_lc_batch_size,
1940
+ adv_sm_lc_gas,
1941
+ adv_sm_lc_learning_rate,
1942
+ adv_sm_lc_warmup_steps,
1943
+ adv_sm_lc_max_iters,
1944
+ adv_sm_lc_save_steps,
1945
+ adv_sm_lc_eval_steps,
1946
+ adv_sm_lc_logging_steps,
1947
+ adv_sm_lc_use_chat_template,
1948
+ adv_sm_lc_no_think_system_message,
1949
  ],
1950
  outputs=[logs],
1951
  )