Spaces:
Running
Running
adds tabbed interface, advanced mode, connectors
Browse files- interface.py +423 -127
interface.py
CHANGED
@@ -470,6 +470,94 @@ config = SmolLM3GeneratedConfig(
|
|
470 |
"""
|
471 |
return _write_generated_config("_generated_smollm3_custom.py", py)
|
472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
473 |
def ensure_dataset_repo(username: str, dataset_name: str, token: str) -> Tuple[str, bool, str]:
|
474 |
"""Create or ensure dataset repo exists. Returns (repo_id, created_or_exists, message)."""
|
475 |
from huggingface_hub import create_repo # type: ignore
|
@@ -907,6 +995,8 @@ def on_config_change(family: str, config_choice: str):
|
|
907 |
"", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
|
908 |
None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
|
909 |
2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
|
|
|
|
|
910 |
# Advanced fields (SmolLM3)
|
911 |
"HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
|
912 |
4096, 2, 8, 5e-6, 500, 100, 10,
|
@@ -1041,6 +1131,15 @@ def on_config_change(family: str, config_choice: str):
|
|
1041 |
adv_log,
|
1042 |
adv_eval,
|
1043 |
adv_save,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1044 |
# Advanced (SmolLM3)
|
1045 |
adv_sm_model_name,
|
1046 |
adv_sm_dataset_name,
|
@@ -1237,90 +1336,170 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1237 |
gpt_oss_advanced_group = gr.Group(visible=False)
|
1238 |
with gpt_oss_advanced_group:
|
1239 |
gr.Markdown("Advanced configuration for GPT-OSS")
|
1240 |
-
|
1241 |
-
|
1242 |
-
|
1243 |
-
|
1244 |
-
|
1245 |
-
|
1246 |
-
|
1247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1248 |
)
|
1249 |
-
|
1250 |
-
|
1251 |
-
|
1252 |
-
|
1253 |
-
|
1254 |
-
|
1255 |
-
|
1256 |
-
|
1257 |
-
|
1258 |
-
|
1259 |
-
|
1260 |
-
|
1261 |
-
|
1262 |
-
|
1263 |
-
|
1264 |
-
|
1265 |
-
|
1266 |
-
|
1267 |
-
|
1268 |
-
|
1269 |
-
|
1270 |
-
|
1271 |
-
adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
|
1272 |
-
|
1273 |
-
with gr.Accordion("LoRA & Quantization", open=False):
|
1274 |
-
with gr.Row():
|
1275 |
-
adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
|
1276 |
-
adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
|
1277 |
-
adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
|
1278 |
-
with gr.Row():
|
1279 |
-
adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
|
1280 |
-
adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
|
1281 |
-
adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
|
1282 |
-
adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
|
1283 |
-
|
1284 |
-
with gr.Accordion("Eval & Logging", open=False):
|
1285 |
-
with gr.Row():
|
1286 |
-
adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
|
1287 |
-
adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
|
1288 |
-
adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
|
1289 |
-
|
1290 |
-
with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
|
1291 |
-
scheduler_override = gr.Dropdown(
|
1292 |
-
choices=[c for c in SCHEDULER_CHOICES if c is not None],
|
1293 |
-
value=None,
|
1294 |
-
allow_custom_value=True,
|
1295 |
-
label="Scheduler override",
|
1296 |
-
)
|
1297 |
-
with gr.Row():
|
1298 |
-
min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
|
1299 |
-
min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
|
1300 |
|
1301 |
smollm3_advanced_group = gr.Group(visible=False)
|
1302 |
with smollm3_advanced_group:
|
1303 |
gr.Markdown("Advanced configuration for SmolLM3")
|
1304 |
-
|
1305 |
-
|
1306 |
-
|
1307 |
-
|
1308 |
-
|
1309 |
-
|
1310 |
-
|
1311 |
-
|
1312 |
-
|
1313 |
-
|
1314 |
-
|
1315 |
-
|
1316 |
-
|
1317 |
-
|
1318 |
-
|
1319 |
-
|
1320 |
-
|
1321 |
-
|
1322 |
-
|
1323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1324 |
|
1325 |
def _toggle_advanced(enable: bool, family_val: str):
|
1326 |
return (
|
@@ -1334,6 +1513,19 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1334 |
outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
|
1335 |
)
|
1336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1337 |
# Final action & logs
|
1338 |
start_btn = gr.Button("Run Pipeline", variant="primary")
|
1339 |
logs = gr.Textbox(value="", label="Logs", lines=20)
|
@@ -1402,6 +1594,15 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1402 |
adv_logging_steps,
|
1403 |
adv_eval_steps,
|
1404 |
adv_save_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1405 |
# Advanced (SmolLM3)
|
1406 |
adv_sm_model_name,
|
1407 |
adv_sm_dataset_name,
|
@@ -1449,6 +1650,7 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1449 |
min_lr_v,
|
1450 |
min_lr_rate_v,
|
1451 |
advanced_enabled_v,
|
|
|
1452 |
# GPT-OSS advanced
|
1453 |
adv_dataset_name_v,
|
1454 |
adv_dataset_split_v,
|
@@ -1479,7 +1681,17 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1479 |
adv_logging_steps_v,
|
1480 |
adv_eval_steps_v,
|
1481 |
adv_save_steps_v,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1482 |
# SmolLM3 advanced
|
|
|
1483 |
adv_sm_model_name_v,
|
1484 |
adv_sm_dataset_name_v,
|
1485 |
adv_sm_input_field_v,
|
@@ -1494,61 +1706,115 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1494 |
adv_sm_save_steps_v,
|
1495 |
adv_sm_eval_steps_v,
|
1496 |
adv_sm_logging_steps_v,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1497 |
):
|
1498 |
# If advanced overrides enabled, generate a config file and pass its path
|
1499 |
override_path: Optional[str] = None
|
1500 |
if advanced_enabled_v:
|
1501 |
try:
|
1502 |
if model_family_v == "GPT-OSS":
|
1503 |
-
|
1504 |
-
|
1505 |
-
|
1506 |
-
|
1507 |
-
|
1508 |
-
|
1509 |
-
|
1510 |
-
|
1511 |
-
|
1512 |
-
|
1513 |
-
|
1514 |
-
|
1515 |
-
|
1516 |
-
|
1517 |
-
|
1518 |
-
|
1519 |
-
|
1520 |
-
|
1521 |
-
|
1522 |
-
|
1523 |
-
|
1524 |
-
|
1525 |
-
|
1526 |
-
|
1527 |
-
|
1528 |
-
|
1529 |
-
|
1530 |
-
|
1531 |
-
|
1532 |
-
|
1533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1534 |
else:
|
1535 |
-
|
1536 |
-
|
1537 |
-
|
1538 |
-
|
1539 |
-
|
1540 |
-
|
1541 |
-
|
1542 |
-
|
1543 |
-
|
1544 |
-
|
1545 |
-
|
1546 |
-
|
1547 |
-
|
1548 |
-
|
1549 |
-
|
1550 |
-
|
1551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1552 |
override_path = str(cfg_path)
|
1553 |
except Exception as e:
|
1554 |
# Surface error in logs via generator
|
@@ -1605,6 +1871,7 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1605 |
min_lr,
|
1606 |
min_lr_rate,
|
1607 |
advanced_enabled,
|
|
|
1608 |
# GPT-OSS advanced
|
1609 |
adv_dataset_name,
|
1610 |
adv_dataset_split,
|
@@ -1635,7 +1902,17 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1635 |
adv_logging_steps,
|
1636 |
adv_eval_steps,
|
1637 |
adv_save_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1638 |
# SmolLM3 advanced
|
|
|
1639 |
adv_sm_model_name,
|
1640 |
adv_sm_dataset_name,
|
1641 |
adv_sm_input_field,
|
@@ -1650,6 +1927,25 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
|
|
1650 |
adv_sm_save_steps,
|
1651 |
adv_sm_eval_steps,
|
1652 |
adv_sm_logging_steps,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1653 |
],
|
1654 |
outputs=[logs],
|
1655 |
)
|
|
|
470 |
"""
|
471 |
return _write_generated_config("_generated_smollm3_custom.py", py)
|
472 |
|
473 |
+
|
474 |
+
def generate_smollm3_long_context_config_file(
    model_name: str,
    dataset_name: Optional[str],
    input_field: str,
    target_field: str,
    filter_bad_entries: bool,
    sample_size: Optional[int],
    sample_seed: int,
    max_seq_length: int,
    batch_size: int,
    gradient_accumulation_steps: int,
    learning_rate: float,
    warmup_steps: int,
    max_iters: int,
    save_steps: int,
    eval_steps: int,
    logging_steps: int,
    use_chat_template: bool,
    no_think_system_message: bool,
    trainer_type: str,
) -> Path:
    """Create a SmolLM3 long-context config file with optional dataset fields.

    Renders the source of a small Python module that subclasses
    ``SmolLM3Config`` and instantiates it as ``config``, then hands the text
    to ``_write_generated_config`` under the file name
    ``_generated_smollm3_long_context.py``.

    Args:
        model_name: Model identifier written into the generated config.
        dataset_name: Dataset identifier; a falsy value is rendered as ``None``.
        input_field: Name of the dataset column used as input.
        target_field: Name of the dataset column used as target.
        filter_bad_entries: Rendered as the ``filter_bad_entries`` flag.
        sample_size: Optional sample size; ``None`` is rendered as ``None``.
        sample_seed: Seed rendered next to ``sample_size``.
        max_seq_length: Rendered as ``max_seq_length``.
        batch_size: Rendered as ``batch_size``.
        gradient_accumulation_steps: Rendered as ``gradient_accumulation_steps``.
        learning_rate: Rendered as ``learning_rate``.
        warmup_steps: Rendered as ``warmup_steps``.
        max_iters: Rendered as ``max_iters``.
        save_steps: Rendered as ``save_steps``.
        eval_steps: Rendered as ``eval_steps``.
        logging_steps: Rendered as ``logging_steps``.
        use_chat_template: Rendered as ``use_chat_template``.
        no_think_system_message: Rendered inside ``chat_template_kwargs``.
        trainer_type: Trainer name; lower-cased before being rendered.

    Returns:
        Whatever ``_write_generated_config`` returns for the generated module
        (annotated as a ``Path``).
    """
    def _bool(b: bool) -> str:
        # Render a truthy/falsy value as a Python bool literal.
        return "True" if b else "False"

    # The dataset attributes form the body of the generated class, so every
    # line is indented one level. Each attribute carries a type annotation:
    # dataclasses only treat *annotated* class variables as fields, so a bare
    # ``name = value`` would not change the default of a field inherited from
    # the parent config and the instance would keep the parent's value.
    ds_section = """
    # HF Dataset configuration
    dataset_name: Optional[str] = {}
    dataset_split: str = "train"
    input_field: str = {}
    target_field: str = {}
    filter_bad_entries: bool = {}
    bad_entry_field: str = "bad_entry"
    sample_size: Optional[int] = {}
    sample_seed: int = {}
""".format(
        repr(dataset_name) if dataset_name else "None",
        repr(input_field),
        repr(target_field),
        _bool(filter_bad_entries),
        repr(sample_size) if sample_size is not None else "None",
        sample_seed,
    )

    # The template text is module-level source, hence written at column 0.
    # Doubled braces are literal braces in the generated dict.
    py = f"""
from dataclasses import dataclass
from typing import Optional
from config.train_smollm3 import SmolLM3Config

@dataclass
class SmolLM3LongContextGeneratedConfig(SmolLM3Config):
{ds_section}

config = SmolLM3LongContextGeneratedConfig(
    trainer_type={repr(trainer_type.lower())},
    model_name={repr(model_name)},
    max_seq_length={max_seq_length},
    use_flash_attention=True,
    use_gradient_checkpointing=True,

    batch_size={batch_size},
    gradient_accumulation_steps={gradient_accumulation_steps},
    learning_rate={learning_rate},
    weight_decay=0.01,
    warmup_steps={warmup_steps},
    max_iters={max_iters},

    fp16=True,
    bf16=False,
    save_steps={save_steps},
    eval_steps={eval_steps},
    logging_steps={logging_steps},
    save_total_limit=3,
    eval_strategy="steps",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,

    use_chat_template={_bool(use_chat_template)},
    chat_template_kwargs={{
        "add_generation_prompt": True,
        "no_think_system_message": {_bool(no_think_system_message)}
    }}
)
"""
    return _write_generated_config("_generated_smollm3_long_context.py", py)
|
560 |
+
|
561 |
def ensure_dataset_repo(username: str, dataset_name: str, token: str) -> Tuple[str, bool, str]:
|
562 |
"""Create or ensure dataset repo exists. Returns (repo_id, created_or_exists, message)."""
|
563 |
from huggingface_hub import create_repo # type: ignore
|
|
|
995 |
"", "train", "openhermes_fr", "prompt", "accepted_completion", "", "", "",
|
996 |
None, 10, None, 1.0, 4, 4, 2e-4, 2e-5, 0.01, 0.03,
|
997 |
2048, 16, 32, 0.05, "bf16", 4, "mxfp4", 1.0, 10, 100, 500,
|
998 |
+
# GPT-OSS Medical o1 SFT defaults
|
999 |
+
"default", "", "", 1.0, 4, 4, 2e-4, 2048,
|
1000 |
# Advanced fields (SmolLM3)
|
1001 |
"HuggingFaceTB/SmolLM3-3B", None, "prompt", "completion", False, None, 42,
|
1002 |
4096, 2, 8, 5e-6, 500, 100, 10,
|
|
|
1131 |
adv_log,
|
1132 |
adv_eval,
|
1133 |
adv_save,
|
1134 |
+
# GPT-OSS Medical o1 SFT defaults
|
1135 |
+
"default",
|
1136 |
+
"",
|
1137 |
+
"",
|
1138 |
+
1.0,
|
1139 |
+
4,
|
1140 |
+
4,
|
1141 |
+
2e-4,
|
1142 |
+
2048,
|
1143 |
# Advanced (SmolLM3)
|
1144 |
adv_sm_model_name,
|
1145 |
adv_sm_dataset_name,
|
|
|
1336 |
gpt_oss_advanced_group = gr.Group(visible=False)
|
1337 |
with gpt_oss_advanced_group:
|
1338 |
gr.Markdown("Advanced configuration for GPT-OSS")
|
1339 |
+
adv_gpt_mode = gr.Radio(
|
1340 |
+
choices=["custom", "medical_o1_sft"],
|
1341 |
+
value="custom",
|
1342 |
+
label="Advanced mode",
|
1343 |
+
)
|
1344 |
+
|
1345 |
+
# --- GPT-OSS Custom advanced controls ---
|
1346 |
+
gpt_oss_custom_group = gr.Group(visible=True)
|
1347 |
+
with gpt_oss_custom_group:
|
1348 |
+
with gr.Accordion("Dataset", open=True):
|
1349 |
+
adv_dataset_name = gr.Textbox(value="", label="Dataset name")
|
1350 |
+
with gr.Row():
|
1351 |
+
adv_dataset_split = gr.Textbox(value="train", label="Dataset split")
|
1352 |
+
adv_dataset_format = gr.Dropdown(
|
1353 |
+
choices=["openhermes_fr", "messages", "text"],
|
1354 |
+
value="openhermes_fr",
|
1355 |
+
label="Dataset format",
|
1356 |
+
)
|
1357 |
+
with gr.Row():
|
1358 |
+
adv_input_field = gr.Textbox(value="prompt", label="Input field")
|
1359 |
+
adv_target_field = gr.Textbox(value="accepted_completion", label="Target field (optional)")
|
1360 |
+
with gr.Row():
|
1361 |
+
adv_system_message = gr.Textbox(value="", label="System message (optional)")
|
1362 |
+
adv_developer_message = gr.Textbox(value="", label="Developer message (optional)")
|
1363 |
+
adv_model_identity = gr.Textbox(value="", label="Model identity (optional)")
|
1364 |
+
with gr.Row():
|
1365 |
+
adv_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
|
1366 |
+
adv_min_length = gr.Number(value=10, precision=0, label="Min length")
|
1367 |
+
adv_max_length = gr.Number(value=None, precision=0, label="Max length (optional)")
|
1368 |
+
|
1369 |
+
with gr.Accordion("Training", open=True):
|
1370 |
+
with gr.Row():
|
1371 |
+
adv_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
|
1372 |
+
adv_batch_size = gr.Number(value=4, precision=0, label="Batch size")
|
1373 |
+
adv_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
|
1374 |
+
with gr.Row():
|
1375 |
+
adv_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
|
1376 |
+
adv_min_lr_num = gr.Number(value=2e-5, precision=6, label="Min LR")
|
1377 |
+
adv_weight_decay = gr.Number(value=0.01, precision=6, label="Weight decay")
|
1378 |
+
adv_warmup_ratio = gr.Number(value=0.03, precision=3, label="Warmup ratio")
|
1379 |
+
adv_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
|
1380 |
+
|
1381 |
+
with gr.Accordion("LoRA & Quantization", open=False):
|
1382 |
+
with gr.Row():
|
1383 |
+
adv_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
|
1384 |
+
adv_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
|
1385 |
+
adv_lora_dropout = gr.Number(value=0.05, precision=3, label="LoRA dropout")
|
1386 |
+
with gr.Row():
|
1387 |
+
adv_mixed_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
|
1388 |
+
adv_num_workers = gr.Number(value=4, precision=0, label="Data workers")
|
1389 |
+
adv_quantization_type = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
|
1390 |
+
adv_max_grad_norm = gr.Number(value=1.0, precision=3, label="Max grad norm")
|
1391 |
+
|
1392 |
+
with gr.Accordion("Eval & Logging", open=False):
|
1393 |
+
with gr.Row():
|
1394 |
+
adv_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
|
1395 |
+
adv_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
|
1396 |
+
adv_save_steps = gr.Number(value=500, precision=0, label="Save steps")
|
1397 |
+
|
1398 |
+
with gr.Accordion("Scheduler (GPT-OSS only)", open=False):
|
1399 |
+
scheduler_override = gr.Dropdown(
|
1400 |
+
choices=[c for c in SCHEDULER_CHOICES if c is not None],
|
1401 |
+
value=None,
|
1402 |
+
allow_custom_value=True,
|
1403 |
+
label="Scheduler override",
|
1404 |
)
|
1405 |
+
with gr.Row():
|
1406 |
+
min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
|
1407 |
+
min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
|
1408 |
+
|
1409 |
+
# --- GPT-OSS Medical o1 SFT controls ---
|
1410 |
+
gpt_oss_medical_group = gr.Group(visible=False)
|
1411 |
+
with gpt_oss_medical_group:
|
1412 |
+
gr.Markdown("Build a Medical o1 SFT configuration (dataset fixed to FreedomIntelligence/medical-o1-reasoning-SFT)")
|
1413 |
+
with gr.Accordion("Dataset", open=True):
|
1414 |
+
adv_med_dataset_config = gr.Textbox(value="default", label="Dataset config (subset)")
|
1415 |
+
with gr.Accordion("Context (optional)", open=False):
|
1416 |
+
with gr.Row():
|
1417 |
+
adv_med_system_message = gr.Textbox(value="", label="System message")
|
1418 |
+
adv_med_developer_message = gr.Textbox(value="", label="Developer message")
|
1419 |
+
with gr.Accordion("Training", open=True):
|
1420 |
+
with gr.Row():
|
1421 |
+
adv_med_num_train_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
|
1422 |
+
adv_med_batch_size = gr.Number(value=4, precision=0, label="Batch size")
|
1423 |
+
adv_med_gradient_accumulation_steps = gr.Number(value=4, precision=0, label="Grad accumulation")
|
1424 |
+
with gr.Row():
|
1425 |
+
adv_med_learning_rate = gr.Number(value=2e-4, precision=6, label="Learning rate")
|
1426 |
+
adv_med_max_seq_length = gr.Number(value=2048, precision=0, label="Max seq length")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1427 |
|
1428 |
smollm3_advanced_group = gr.Group(visible=False)
|
1429 |
with smollm3_advanced_group:
|
1430 |
gr.Markdown("Advanced configuration for SmolLM3")
|
1431 |
+
adv_sm_mode = gr.Radio(
|
1432 |
+
choices=["custom", "long_context"],
|
1433 |
+
value="custom",
|
1434 |
+
label="Advanced mode",
|
1435 |
+
)
|
1436 |
+
# --- SmolLM3 Custom ---
|
1437 |
+
sm_custom_group = gr.Group(visible=True)
|
1438 |
+
with sm_custom_group:
|
1439 |
+
with gr.Accordion("Dataset", open=True):
|
1440 |
+
adv_sm_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
|
1441 |
+
adv_sm_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
|
1442 |
+
with gr.Row():
|
1443 |
+
adv_sm_input_field = gr.Textbox(value="prompt", label="Input field")
|
1444 |
+
adv_sm_target_field = gr.Textbox(value="completion", label="Target field")
|
1445 |
+
with gr.Row():
|
1446 |
+
adv_sm_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
|
1447 |
+
adv_sm_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
|
1448 |
+
adv_sm_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
|
1449 |
+
with gr.Accordion("Training", open=True):
|
1450 |
+
with gr.Row():
|
1451 |
+
adv_sm_max_seq_length = gr.Number(value=4096, precision=0, label="Max seq length")
|
1452 |
+
adv_sm_batch_size = gr.Number(value=2, precision=0, label="Batch size")
|
1453 |
+
adv_sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
|
1454 |
+
adv_sm_learning_rate = gr.Number(value=5e-6, precision=6, label="Learning rate")
|
1455 |
+
with gr.Row():
|
1456 |
+
adv_sm_save_steps = gr.Number(value=500, precision=0, label="Save steps")
|
1457 |
+
adv_sm_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
|
1458 |
+
adv_sm_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
|
1459 |
+
|
1460 |
+
# --- SmolLM3 Long-Context ---
|
1461 |
+
sm_long_group = gr.Group(visible=False)
|
1462 |
+
with sm_long_group:
|
1463 |
+
gr.Markdown("Generate a Long-Context SmolLM3 config")
|
1464 |
+
with gr.Accordion("Dataset", open=True):
|
1465 |
+
adv_sm_lc_model_name = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
|
1466 |
+
adv_sm_lc_dataset_name = gr.Textbox(value="", label="Dataset name (optional)")
|
1467 |
+
with gr.Row():
|
1468 |
+
adv_sm_lc_input_field = gr.Textbox(value="prompt", label="Input field")
|
1469 |
+
adv_sm_lc_target_field = gr.Textbox(value="completion", label="Target field")
|
1470 |
+
with gr.Row():
|
1471 |
+
adv_sm_lc_filter_bad_entries = gr.Checkbox(value=False, label="Filter bad entries")
|
1472 |
+
adv_sm_lc_sample_size = gr.Number(value=None, precision=0, label="Sample size (optional)")
|
1473 |
+
adv_sm_lc_sample_seed = gr.Number(value=42, precision=0, label="Sample seed")
|
1474 |
+
with gr.Accordion("Training", open=True):
|
1475 |
+
with gr.Row():
|
1476 |
+
adv_sm_lc_max_seq_length = gr.Number(value=131072, precision=0, label="Max seq length (up to 131072)")
|
1477 |
+
adv_sm_lc_batch_size = gr.Number(value=1, precision=0, label="Batch size")
|
1478 |
+
adv_sm_lc_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
|
1479 |
+
adv_sm_lc_learning_rate = gr.Number(value=1e-5, precision=6, label="Learning rate")
|
1480 |
+
with gr.Row():
|
1481 |
+
adv_sm_lc_warmup_steps = gr.Number(value=200, precision=0, label="Warmup steps")
|
1482 |
+
adv_sm_lc_max_iters = gr.Number(value=500, precision=0, label="Max iters")
|
1483 |
+
with gr.Row():
|
1484 |
+
adv_sm_lc_save_steps = gr.Number(value=100, precision=0, label="Save steps")
|
1485 |
+
adv_sm_lc_eval_steps = gr.Number(value=50, precision=0, label="Eval steps")
|
1486 |
+
adv_sm_lc_logging_steps = gr.Number(value=10, precision=0, label="Logging steps")
|
1487 |
+
with gr.Accordion("Chat Template", open=False):
|
1488 |
+
with gr.Row():
|
1489 |
+
adv_sm_lc_use_chat_template = gr.Checkbox(value=True, label="Use chat template")
|
1490 |
+
adv_sm_lc_no_think_system_message = gr.Checkbox(value=True, label="No-think system message")
|
1491 |
+
|
1492 |
+
def _toggle_sm_mode(mode: str):
    """Return visibility updates for the two SmolLM3 advanced panels.

    The first update is visible only for the ``"custom"`` mode and the second
    only for ``"long_context"``; any other value hides both.
    """
    show_custom = mode == "custom"
    show_long_context = mode == "long_context"
    return gr.update(visible=show_custom), gr.update(visible=show_long_context)
|
1497 |
+
|
1498 |
+
adv_sm_mode.change(
|
1499 |
+
_toggle_sm_mode,
|
1500 |
+
inputs=[adv_sm_mode],
|
1501 |
+
outputs=[sm_custom_group, sm_long_group],
|
1502 |
+
)
|
1503 |
|
1504 |
def _toggle_advanced(enable: bool, family_val: str):
|
1505 |
return (
|
|
|
1513 |
outputs=[gpt_oss_advanced_group, smollm3_advanced_group],
|
1514 |
)
|
1515 |
|
1516 |
+
# Toggle between GPT-OSS Custom and Medical modes
|
1517 |
+
def _toggle_gpt_oss_mode(mode: str):
    """Return visibility updates for the two GPT-OSS advanced panels.

    The first update is visible only for the ``"custom"`` mode and the second
    only for ``"medical_o1_sft"``; any other value hides both.
    """
    show_custom = mode == "custom"
    show_medical = mode == "medical_o1_sft"
    return gr.update(visible=show_custom), gr.update(visible=show_medical)
|
1522 |
+
|
1523 |
+
adv_gpt_mode.change(
|
1524 |
+
_toggle_gpt_oss_mode,
|
1525 |
+
inputs=[adv_gpt_mode],
|
1526 |
+
outputs=[gpt_oss_custom_group, gpt_oss_medical_group],
|
1527 |
+
)
|
1528 |
+
|
1529 |
# Final action & logs
|
1530 |
start_btn = gr.Button("Run Pipeline", variant="primary")
|
1531 |
logs = gr.Textbox(value="", label="Logs", lines=20)
|
|
|
1594 |
adv_logging_steps,
|
1595 |
adv_eval_steps,
|
1596 |
adv_save_steps,
|
1597 |
+
# GPT-OSS Medical o1 SFT outputs (prefill defaults)
|
1598 |
+
adv_med_dataset_config,
|
1599 |
+
adv_med_system_message,
|
1600 |
+
adv_med_developer_message,
|
1601 |
+
adv_med_num_train_epochs,
|
1602 |
+
adv_med_batch_size,
|
1603 |
+
adv_med_gradient_accumulation_steps,
|
1604 |
+
adv_med_learning_rate,
|
1605 |
+
adv_med_max_seq_length,
|
1606 |
# Advanced (SmolLM3)
|
1607 |
adv_sm_model_name,
|
1608 |
adv_sm_dataset_name,
|
|
|
1650 |
min_lr_v,
|
1651 |
min_lr_rate_v,
|
1652 |
advanced_enabled_v,
|
1653 |
+
adv_gpt_mode_v,
|
1654 |
# GPT-OSS advanced
|
1655 |
adv_dataset_name_v,
|
1656 |
adv_dataset_split_v,
|
|
|
1681 |
adv_logging_steps_v,
|
1682 |
adv_eval_steps_v,
|
1683 |
adv_save_steps_v,
|
1684 |
+
# GPT-OSS Medical o1 SFT
|
1685 |
+
adv_med_dataset_config_v,
|
1686 |
+
adv_med_system_message_v,
|
1687 |
+
adv_med_developer_message_v,
|
1688 |
+
adv_med_num_train_epochs_v,
|
1689 |
+
adv_med_batch_size_v,
|
1690 |
+
adv_med_gradient_accumulation_steps_v,
|
1691 |
+
adv_med_learning_rate_v,
|
1692 |
+
adv_med_max_seq_length_v,
|
1693 |
# SmolLM3 advanced
|
1694 |
+
adv_sm_mode_v,
|
1695 |
adv_sm_model_name_v,
|
1696 |
adv_sm_dataset_name_v,
|
1697 |
adv_sm_input_field_v,
|
|
|
1706 |
adv_sm_save_steps_v,
|
1707 |
adv_sm_eval_steps_v,
|
1708 |
adv_sm_logging_steps_v,
|
1709 |
+
# SmolLM3 long context
|
1710 |
+
adv_sm_lc_model_name_v,
|
1711 |
+
adv_sm_lc_dataset_name_v,
|
1712 |
+
adv_sm_lc_input_field_v,
|
1713 |
+
adv_sm_lc_target_field_v,
|
1714 |
+
adv_sm_lc_filter_bad_entries_v,
|
1715 |
+
adv_sm_lc_sample_size_v,
|
1716 |
+
adv_sm_lc_sample_seed_v,
|
1717 |
+
adv_sm_lc_max_seq_length_v,
|
1718 |
+
adv_sm_lc_batch_size_v,
|
1719 |
+
adv_sm_lc_gas_v,
|
1720 |
+
adv_sm_lc_learning_rate_v,
|
1721 |
+
adv_sm_lc_warmup_steps_v,
|
1722 |
+
adv_sm_lc_max_iters_v,
|
1723 |
+
adv_sm_lc_save_steps_v,
|
1724 |
+
adv_sm_lc_eval_steps_v,
|
1725 |
+
adv_sm_lc_logging_steps_v,
|
1726 |
+
adv_sm_lc_use_chat_template_v,
|
1727 |
+
adv_sm_lc_no_think_system_message_v,
|
1728 |
):
|
1729 |
# If advanced overrides enabled, generate a config file and pass its path
|
1730 |
override_path: Optional[str] = None
|
1731 |
if advanced_enabled_v:
|
1732 |
try:
|
1733 |
if model_family_v == "GPT-OSS":
|
1734 |
+
if str(adv_gpt_mode_v) == "medical_o1_sft":
|
1735 |
+
cfg_path = generate_medical_o1_config_file(
|
1736 |
+
dataset_config=str(adv_med_dataset_config_v or "default"),
|
1737 |
+
system_message=(str(adv_med_system_message_v) if adv_med_system_message_v else None),
|
1738 |
+
developer_message=(str(adv_med_developer_message_v) if adv_med_developer_message_v else None),
|
1739 |
+
num_train_epochs=float(adv_med_num_train_epochs_v or 1.0),
|
1740 |
+
batch_size=int(adv_med_batch_size_v or 4),
|
1741 |
+
gradient_accumulation_steps=int(adv_med_gradient_accumulation_steps_v or 4),
|
1742 |
+
learning_rate=float(adv_med_learning_rate_v or 2e-4),
|
1743 |
+
max_seq_length=int(adv_med_max_seq_length_v or 2048),
|
1744 |
+
)
|
1745 |
+
else:
|
1746 |
+
cfg_path = generate_gpt_oss_custom_config_file(
|
1747 |
+
dataset_name=str(adv_dataset_name_v or ""),
|
1748 |
+
dataset_split=str(adv_dataset_split_v or "train"),
|
1749 |
+
dataset_format=str(adv_dataset_format_v or "openhermes_fr"),
|
1750 |
+
input_field=str(adv_input_field_v or "prompt"),
|
1751 |
+
target_field=(str(adv_target_field_v) if adv_target_field_v else None),
|
1752 |
+
system_message=(str(adv_system_message_v) if adv_system_message_v else None),
|
1753 |
+
developer_message=(str(adv_developer_message_v) if adv_developer_message_v else None),
|
1754 |
+
model_identity=(str(adv_model_identity_v) if adv_model_identity_v else None),
|
1755 |
+
max_samples=(int(adv_max_samples_v) if adv_max_samples_v else None),
|
1756 |
+
min_length=int(adv_min_length_v or 10),
|
1757 |
+
max_length=(int(adv_max_length_v) if adv_max_length_v else None),
|
1758 |
+
num_train_epochs=float(adv_num_train_epochs_v or 1.0),
|
1759 |
+
batch_size=int(adv_batch_size_v or 4),
|
1760 |
+
gradient_accumulation_steps=int(adv_gas_v or 4),
|
1761 |
+
learning_rate=float(adv_lr_v or 2e-4),
|
1762 |
+
min_lr=float(adv_min_lr_num_v or 2e-5),
|
1763 |
+
weight_decay=float(adv_wd_v or 0.01),
|
1764 |
+
warmup_ratio=float(adv_warmup_ratio_v or 0.03),
|
1765 |
+
max_seq_length=int(adv_max_seq_length_v or 2048),
|
1766 |
+
lora_r=int(adv_lora_r_v or 16),
|
1767 |
+
lora_alpha=int(adv_lora_alpha_v or 32),
|
1768 |
+
lora_dropout=float(adv_lora_dropout_v or 0.05),
|
1769 |
+
mixed_precision=str(adv_mixed_precision_v or "bf16"),
|
1770 |
+
num_workers=int(adv_num_workers_v or 4),
|
1771 |
+
quantization_type=str(adv_quantization_type_v or "mxfp4"),
|
1772 |
+
max_grad_norm=float(adv_max_grad_norm_v or 1.0),
|
1773 |
+
logging_steps=int(adv_logging_steps_v or 10),
|
1774 |
+
eval_steps=int(adv_eval_steps_v or 100),
|
1775 |
+
save_steps=int(adv_save_steps_v or 500),
|
1776 |
+
)
|
1777 |
else:
|
1778 |
+
if str(adv_sm_mode_v) == "long_context":
|
1779 |
+
cfg_path = generate_smollm3_long_context_config_file(
|
1780 |
+
model_name=str(adv_sm_lc_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
|
1781 |
+
dataset_name=(str(adv_sm_lc_dataset_name_v) if adv_sm_lc_dataset_name_v else None),
|
1782 |
+
input_field=str(adv_sm_lc_input_field_v or "prompt"),
|
1783 |
+
target_field=str(adv_sm_lc_target_field_v or "completion"),
|
1784 |
+
filter_bad_entries=bool(adv_sm_lc_filter_bad_entries_v),
|
1785 |
+
sample_size=(int(adv_sm_lc_sample_size_v) if adv_sm_lc_sample_size_v else None),
|
1786 |
+
sample_seed=int(adv_sm_lc_sample_seed_v or 42),
|
1787 |
+
max_seq_length=int(adv_sm_lc_max_seq_length_v or 131072),
|
1788 |
+
batch_size=int(adv_sm_lc_batch_size_v or 1),
|
1789 |
+
gradient_accumulation_steps=int(adv_sm_lc_gas_v or 8),
|
1790 |
+
learning_rate=float(adv_sm_lc_learning_rate_v or 1e-5),
|
1791 |
+
warmup_steps=int(adv_sm_lc_warmup_steps_v or 200),
|
1792 |
+
max_iters=int(adv_sm_lc_max_iters_v or 500),
|
1793 |
+
save_steps=int(adv_sm_lc_save_steps_v or 100),
|
1794 |
+
eval_steps=int(adv_sm_lc_eval_steps_v or 50),
|
1795 |
+
logging_steps=int(adv_sm_lc_logging_steps_v or 10),
|
1796 |
+
use_chat_template=bool(adv_sm_lc_use_chat_template_v),
|
1797 |
+
no_think_system_message=bool(adv_sm_lc_no_think_system_message_v),
|
1798 |
+
trainer_type=str(trainer_type_v).lower(),
|
1799 |
+
)
|
1800 |
+
else:
|
1801 |
+
cfg_path = generate_smollm3_custom_config_file(
|
1802 |
+
model_name=str(adv_sm_model_name_v or "HuggingFaceTB/SmolLM3-3B"),
|
1803 |
+
dataset_name=(str(adv_sm_dataset_name_v) if adv_sm_dataset_name_v else None),
|
1804 |
+
max_seq_length=int(adv_sm_max_seq_length_v or 4096),
|
1805 |
+
batch_size=int(adv_sm_batch_size_v or 2),
|
1806 |
+
gradient_accumulation_steps=int(adv_sm_gas_v or 8),
|
1807 |
+
learning_rate=float(adv_sm_learning_rate_v or 5e-6),
|
1808 |
+
save_steps=int(adv_sm_save_steps_v or 500),
|
1809 |
+
eval_steps=int(adv_sm_eval_steps_v or 100),
|
1810 |
+
logging_steps=int(adv_sm_logging_steps_v or 10),
|
1811 |
+
filter_bad_entries=bool(adv_sm_filter_bad_entries_v),
|
1812 |
+
input_field=str(adv_sm_input_field_v or "prompt"),
|
1813 |
+
target_field=str(adv_sm_target_field_v or "completion"),
|
1814 |
+
sample_size=(int(adv_sm_sample_size_v) if adv_sm_sample_size_v else None),
|
1815 |
+
sample_seed=int(adv_sm_sample_seed_v or 42),
|
1816 |
+
trainer_type=str(trainer_type_v).lower(),
|
1817 |
+
)
|
1818 |
override_path = str(cfg_path)
|
1819 |
except Exception as e:
|
1820 |
# Surface error in logs via generator
|
|
|
1871 |
min_lr,
|
1872 |
min_lr_rate,
|
1873 |
advanced_enabled,
|
1874 |
+
adv_gpt_mode,
|
1875 |
# GPT-OSS advanced
|
1876 |
adv_dataset_name,
|
1877 |
adv_dataset_split,
|
|
|
1902 |
adv_logging_steps,
|
1903 |
adv_eval_steps,
|
1904 |
adv_save_steps,
|
1905 |
+
# GPT-OSS Medical o1 SFT
|
1906 |
+
adv_med_dataset_config,
|
1907 |
+
adv_med_system_message,
|
1908 |
+
adv_med_developer_message,
|
1909 |
+
adv_med_num_train_epochs,
|
1910 |
+
adv_med_batch_size,
|
1911 |
+
adv_med_gradient_accumulation_steps,
|
1912 |
+
adv_med_learning_rate,
|
1913 |
+
adv_med_max_seq_length,
|
1914 |
# SmolLM3 advanced
|
1915 |
+
adv_sm_mode,
|
1916 |
adv_sm_model_name,
|
1917 |
adv_sm_dataset_name,
|
1918 |
adv_sm_input_field,
|
|
|
1927 |
adv_sm_save_steps,
|
1928 |
adv_sm_eval_steps,
|
1929 |
adv_sm_logging_steps,
|
1930 |
+
# SmolLM3 long context
|
1931 |
+
adv_sm_lc_model_name,
|
1932 |
+
adv_sm_lc_dataset_name,
|
1933 |
+
adv_sm_lc_input_field,
|
1934 |
+
adv_sm_lc_target_field,
|
1935 |
+
adv_sm_lc_filter_bad_entries,
|
1936 |
+
adv_sm_lc_sample_size,
|
1937 |
+
adv_sm_lc_sample_seed,
|
1938 |
+
adv_sm_lc_max_seq_length,
|
1939 |
+
adv_sm_lc_batch_size,
|
1940 |
+
adv_sm_lc_gas,
|
1941 |
+
adv_sm_lc_learning_rate,
|
1942 |
+
adv_sm_lc_warmup_steps,
|
1943 |
+
adv_sm_lc_max_iters,
|
1944 |
+
adv_sm_lc_save_steps,
|
1945 |
+
adv_sm_lc_eval_steps,
|
1946 |
+
adv_sm_lc_logging_steps,
|
1947 |
+
adv_sm_lc_use_chat_template,
|
1948 |
+
adv_sm_lc_no_think_system_message,
|
1949 |
],
|
1950 |
outputs=[logs],
|
1951 |
)
|