Tonic commited on
Commit
a8fc78d
·
1 Parent(s): b867691

improves interface flow

Browse files
Files changed (1) hide show
  1. interface.py +181 -233
interface.py CHANGED
@@ -829,10 +829,91 @@ joinus = """
829
  """
830
 
831
 
832
- def on_family_change(family: str) -> Tuple[list[str], str, str, str, str]:
 
 
 
 
 
 
833
  confs = list(get_config_map(family).keys())
834
  exp, repo_short, desc, space = ui_defaults(family)
835
- return confs, confs[0] if confs else "", exp, repo_short, desc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836
 
837
 
838
  def start_pipeline(
@@ -932,243 +1013,110 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
932
  )
933
  gr.Markdown(joinus)
934
 
935
- with gr.Row():
936
- model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="Model family")
937
- trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="Trainer type")
938
- monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="both", label="Monitoring mode")
939
-
940
- config_choice = gr.Dropdown(choices=list(get_config_map("SmolLM3").keys()), value="Basic Training", label="Training configuration")
941
-
942
- exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
943
- with gr.Row():
944
- experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
945
- repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
946
-
947
- with gr.Row():
948
- author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
949
- model_description = gr.Textbox(value=desc_default, label="Model description")
950
-
951
- with gr.Row():
952
- trackio_space_name = gr.Textbox(value=trackio_space_default, label="Trackio Space name (used when monitoring != none)")
953
- deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space")
954
- create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo")
955
-
956
- with gr.Row():
957
- push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
958
- switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
959
-
960
- with gr.Tabs():
961
- with gr.Tab("Run"):
962
- with gr.Row():
963
- model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="Model family")
964
- trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="Trainer type")
965
- monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="both", label="Monitoring mode")
966
-
967
- config_choice = gr.Dropdown(choices=list(get_config_map("SmolLM3").keys()), value="Basic Training", label="Training configuration")
968
-
969
- exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
970
- with gr.Row():
971
- experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
972
- repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
973
-
974
- with gr.Row():
975
- author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
976
- model_description = gr.Textbox(value=desc_default, label="Model description")
977
-
978
- with gr.Row():
979
- trackio_space_name = gr.Textbox(value=trackio_space_default, label="Trackio Space name (used when monitoring != none)")
980
- deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space")
981
- create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo")
982
-
983
- with gr.Row():
984
- push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
985
- switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
986
-
987
- gr.Markdown("### Medical SFT (GPT-OSS o1)")
988
- gr.Markdown("Configure GPT-OSS Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT)")
989
- med_dataset_config = gr.Dropdown(choices=["en", "en_mix", "zh", "zh_mix"], value="en", label="Dataset config")
990
- med_system = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="System message", lines=2)
991
- med_developer = gr.Textbox(value="You are are GPT-Tonic, an intelligent assistant that always answers health-related queries scientifically.", label="Developer message", lines=3)
992
- with gr.Row():
993
- med_epochs = gr.Number(value=2.0, precision=2, label="Epochs")
994
- med_bs = gr.Number(value=4, precision=0, label="Batch size")
995
- med_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
996
- med_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
997
- med_msl = gr.Number(value=2048, precision=0, label="Max seq length")
998
- med_generate = gr.Button("Generate Medical Config")
999
- med_status = gr.Textbox(label="Generated config path", interactive=False)
1000
-
1001
- logs = gr.Textbox(value="", label="Logs", lines=20)
1002
- start_btn = gr.Button("Run Pipeline")
1003
-
1004
- with gr.Tab("Advanced Config"):
1005
- with gr.Accordion("GPT-OSS Scheduler Overrides", open=False):
1006
- scheduler_override = gr.Dropdown(choices=[c for c in SCHEDULER_CHOICES if c is not None], value=None, allow_custom_value=True, label="Scheduler override")
1007
- min_lr = gr.Number(value=None, precision=6, label="min_lr (when cosine_with_min_lr)")
1008
- min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (when cosine_with_min_lr)")
1009
-
1010
- gr.Markdown("### GPT-OSS Custom Dataset")
1011
- with gr.Row():
1012
- cds_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset name")
1013
- cds_split = gr.Textbox(value="train", label="Split")
1014
- cds_format = gr.Dropdown(choices=["openhermes_fr", "messages", "text", "medical_o1_sft", "custom", "preference"], value="openhermes_fr", label="Format")
1015
- with gr.Row():
1016
- cds_input = gr.Textbox(value="prompt", label="Input field")
1017
- cds_target = gr.Textbox(value="accepted_completion", label="Target field (optional, blank for None)")
1018
- with gr.Row():
1019
- cds_sys = gr.Textbox(value="", label="System message (optional)")
1020
- cds_dev = gr.Textbox(value="", label="Developer message (optional)")
1021
- with gr.Row():
1022
- cds_identity = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="Model identity (chat_template_kwargs.model_identity)")
1023
- with gr.Row():
1024
- cds_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
1025
- cds_min_len = gr.Number(value=10, precision=0, label="Min length")
1026
- cds_max_len = gr.Number(value=None, precision=0, label="Max length (optional)")
1027
- gr.Markdown("#### Training Hyperparameters")
1028
- with gr.Row():
1029
- cds_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
1030
- cds_bs = gr.Number(value=4, precision=0, label="Batch size")
1031
- cds_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
1032
- cds_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
1033
- cds_minlr = gr.Number(value=2e-5, precision=6, label="Min LR")
1034
- with gr.Row():
1035
- cds_wd = gr.Number(value=0.01, precision=6, label="Weight decay")
1036
- cds_warm = gr.Number(value=0.03, precision=6, label="Warmup ratio")
1037
- cds_msl = gr.Number(value=2048, precision=0, label="Max seq length")
1038
- gr.Markdown("#### LoRA / Precision / Quantization / Perf")
1039
- with gr.Row():
1040
- cds_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
1041
- cds_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
1042
- cds_lora_dropout = gr.Number(value=0.05, precision=4, label="LoRA dropout")
1043
- with gr.Row():
1044
- cds_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
1045
- cds_workers = gr.Number(value=4, precision=0, label="Data workers")
1046
- cds_quant = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
1047
- with gr.Row():
1048
- cds_mgn = gr.Number(value=1.0, precision=4, label="Max grad norm")
1049
- cds_log_steps = gr.Number(value=10, precision=0, label="Logging steps")
1050
- cds_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
1051
- cds_save_steps = gr.Number(value=500, precision=0, label="Save steps")
1052
- cds_generate = gr.Button("Generate GPT-OSS Custom Config")
1053
- cds_status = gr.Textbox(label="Generated config path", interactive=False)
1054
-
1055
- gr.Markdown("### SmolLM3 Custom Configuration")
1056
- with gr.Row():
1057
- sm_model = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
1058
- sm_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset (optional; leave blank for local)")
1059
- with gr.Row():
1060
- sm_msl = gr.Number(value=4096, precision=0, label="Max seq length")
1061
- sm_bs = gr.Number(value=2, precision=0, label="Batch size")
1062
- sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
1063
- sm_lr = gr.Number(value=5e-6, precision=8, label="Learning rate")
1064
- with gr.Row():
1065
- sm_save = gr.Number(value=500, precision=0, label="Save steps")
1066
- sm_eval = gr.Number(value=100, precision=0, label="Eval steps")
1067
- sm_log = gr.Number(value=10, precision=0, label="Logging steps")
1068
- with gr.Row():
1069
- sm_filter = gr.Checkbox(value=False, label="Filter bad entries")
1070
- sm_in = gr.Textbox(value="prompt", label="Input field")
1071
- sm_out = gr.Textbox(value="accepted_completion", label="Target field")
1072
- with gr.Row():
1073
- sm_sample = gr.Number(value=None, precision=0, label="Sample size (optional)")
1074
- sm_seed = gr.Number(value=42, precision=0, label="Sample seed")
1075
- sm_trainer = gr.Dropdown(choices=["SFT", "DPO"], value="SFT", label="Trainer type")
1076
- sm_generate = gr.Button("Generate SmolLM3 Custom Config")
1077
- sm_status = gr.Textbox(label="Generated config path", interactive=False)
1078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1079
  logs = gr.Textbox(value="", label="Logs", lines=20)
1080
 
1081
- start_btn = gr.Button("Run Pipeline")
1082
-
1083
- # Events
1084
- model_family.change(on_family_change, inputs=model_family, outputs=[config_choice, config_choice, experiment_name, repo_short, model_description])
1085
-
1086
- # Generate config handlers
1087
- med_generate.click(
1088
- lambda dc, sysm, devm, ep, bs, gas, lr, msl: str(
1089
- generate_medical_o1_config_file(
1090
- dataset_config=dc,
1091
- system_message=sysm,
1092
- developer_message=devm,
1093
- num_train_epochs=float(ep or 2.0),
1094
- batch_size=int(bs or 4),
1095
- gradient_accumulation_steps=int(gas or 4),
1096
- learning_rate=float(lr or 2e-4),
1097
- max_seq_length=int(msl or 2048),
1098
- )
1099
- ),
1100
- inputs=[med_dataset_config, med_system, med_developer, med_epochs, med_bs, med_gas, med_lr, med_msl],
1101
- outputs=[med_status],
1102
  )
1103
 
1104
- cds_generate.click(
1105
- lambda dname, dsplit, dformat, ifld, tfld, sm, dm, ident, ms, minl, maxl, ep, bs, gas, lr, minlr, wd, warm, msl, lr_, la, ld, prec, nw, q, mgn, logst, evst, savst: str(
1106
- generate_gpt_oss_custom_config_file(
1107
- dataset_name=dname,
1108
- dataset_split=dsplit,
1109
- dataset_format=dformat,
1110
- input_field=ifld,
1111
- target_field=(tfld or None),
1112
- system_message=sm,
1113
- developer_message=dm,
1114
- model_identity=ident,
1115
- max_samples=(int(ms) if ms is not None else None),
1116
- min_length=int(minl or 10),
1117
- max_length=(int(maxl) if maxl is not None else None),
1118
- num_train_epochs=float(ep or 1.0),
1119
- batch_size=int(bs or 4),
1120
- gradient_accumulation_steps=int(gas or 4),
1121
- learning_rate=float(lr or 2e-4),
1122
- min_lr=float(minlr or 2e-5),
1123
- weight_decay=float(wd or 0.01),
1124
- warmup_ratio=float(warm or 0.03),
1125
- max_seq_length=int(msl or 2048),
1126
- lora_r=int(lr_),
1127
- lora_alpha=int(la),
1128
- lora_dropout=float(ld),
1129
- mixed_precision=prec,
1130
- num_workers=int(nw or 4),
1131
- quantization_type=q,
1132
- max_grad_norm=float(mgn or 1.0),
1133
- logging_steps=int(logst or 10),
1134
- eval_steps=int(evst or 100),
1135
- save_steps=int(savst or 500),
1136
- )
1137
- ),
1138
- inputs=[
1139
- cds_dataset, cds_split, cds_format, cds_input, cds_target, cds_sys, cds_dev, cds_identity,
1140
- cds_max_samples, cds_min_len, cds_max_len, cds_epochs, cds_bs, cds_gas, cds_lr, cds_minlr, cds_wd,
1141
- cds_warm, cds_msl, cds_lora_r, cds_lora_alpha, cds_lora_dropout, cds_precision, cds_workers, cds_quant,
1142
- cds_mgn, cds_log_steps, cds_eval_steps, cds_save_steps
1143
- ],
1144
- outputs=[cds_status],
1145
  )
1146
 
1147
- sm_generate.click(
1148
- lambda mn, dn, msl, bs, gas, lr, sst, est, lst, fbe, ifld, tfld, ss, seed, tt: str(
1149
- generate_smollm3_custom_config_file(
1150
- model_name=mn,
1151
- dataset_name=(dn or None),
1152
- max_seq_length=int(msl or 4096),
1153
- batch_size=int(bs or 2),
1154
- gradient_accumulation_steps=int(gas or 8),
1155
- learning_rate=float(lr or 5e-6),
1156
- save_steps=int(sst or 500),
1157
- eval_steps=int(est or 100),
1158
- logging_steps=int(lst or 10),
1159
- filter_bad_entries=bool(fbe),
1160
- input_field=ifld,
1161
- target_field=tfld,
1162
- sample_size=(int(ss) if ss is not None else None),
1163
- sample_seed=int(seed or 42),
1164
- trainer_type=tt,
1165
- )
1166
- ),
1167
- inputs=[
1168
- sm_model, sm_dataset, sm_msl, sm_bs, sm_gas, sm_lr, sm_save, sm_eval, sm_log,
1169
- sm_filter, sm_in, sm_out, sm_sample, sm_seed, sm_trainer,
1170
- ],
1171
- outputs=[sm_status],
1172
  )
1173
 
1174
  start_btn.click(
@@ -1199,6 +1147,6 @@ if __name__ == "__main__":
1199
  # Optional: allow setting server parameters via env
1200
  server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
1201
  server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
1202
- demo.queue().launch(server_name=server_name, server_port=server_port)
1203
 
1204
 
 
829
  """
830
 
831
 
832
+ def on_family_change(family: str):
833
+ """Update UI when the model family changes.
834
+
835
+ - Refresh available prebuilt configuration choices
836
+ - Reset defaults (experiment name, repo short, description, space name)
837
+ - Reveal the next step (trainer type)
838
+ """
839
  confs = list(get_config_map(family).keys())
840
  exp, repo_short, desc, space = ui_defaults(family)
841
+
842
+ # Initial dataset information placeholder until a specific config is chosen
843
+ training_md = (
844
+ f"Select a training configuration for {family} to see details (dataset, batch size, etc.)."
845
+ )
846
+
847
+ # Update objects:
848
+ return (
849
+ gr.update(choices=confs, value=(confs[0] if confs else None)),
850
+ exp,
851
+ repo_short,
852
+ desc,
853
+ space,
854
+ training_md,
855
+ gr.update(choices=[], value=None),
856
+ gr.update(visible=True), # show step 2 (trainer)
857
+ gr.update(visible=False), # hide step 3 until trainer selected
858
+ gr.update(visible=False), # hide step 4 until monitoring selected
859
+ gr.update(visible=(family == "GPT-OSS")), # advanced (scheduler) visibility
860
+ )
861
+
862
+
863
+ def on_config_change(family: str, config_choice: str):
864
+ """When a prebuilt configuration is selected, update dataset info and helpful details."""
865
+ if not config_choice:
866
+ return (
867
+ "",
868
+ gr.update(choices=[], value=None),
869
+ )
870
+
871
+ conf_map = get_config_map(family)
872
+ cfg_path = PROJECT_ROOT / conf_map[config_choice]["config_file"]
873
+ cfg_obj = import_config_object(cfg_path)
874
+
875
+ dataset_name = getattr(cfg_obj, "dataset_name", None) if cfg_obj else None
876
+ batch_size = getattr(cfg_obj, "batch_size", None) if cfg_obj else None
877
+ learning_rate = getattr(cfg_obj, "learning_rate", None) if cfg_obj else None
878
+ max_seq_length = getattr(cfg_obj, "max_seq_length", None) if cfg_obj else None
879
+ base_model = conf_map[config_choice]["default_model"]
880
+
881
+ md_lines = [
882
+ f"**Configuration**: {config_choice}",
883
+ f"**Base model**: {base_model}",
884
+ ]
885
+ if dataset_name:
886
+ md_lines.append(f"**Dataset**: `{dataset_name}`")
887
+ if batch_size is not None:
888
+ md_lines.append(f"**Batch size**: {batch_size}")
889
+ if learning_rate is not None:
890
+ md_lines.append(f"**Learning rate**: {learning_rate}")
891
+ if max_seq_length is not None:
892
+ md_lines.append(f"**Max seq length**: {max_seq_length}")
893
+
894
+ training_md = "\n".join(md_lines)
895
+
896
+ # dataset selection (allow custom but prefill with the config's dataset if any)
897
+ ds_choices = [dataset_name] if dataset_name else []
898
+
899
+ return training_md, gr.update(choices=ds_choices, value=(dataset_name or None))
900
+
901
+
902
+ def on_trainer_selected(_: str):
903
+ """Reveal monitoring step once trainer type is chosen."""
904
+ return gr.update(visible=True)
905
+
906
+
907
+ def on_monitoring_change(mode: str):
908
+ """Reveal configuration/details step and adjust Trackio-related visibility by mode."""
909
+ show_trackio = mode in ("both", "trackio")
910
+ show_dataset_repo = mode != "none"
911
+ return (
912
+ gr.update(visible=True),
913
+ gr.update(visible=show_trackio), # trackio space name
914
+ gr.update(visible=show_trackio), # deploy trackio space
915
+ gr.update(visible=show_dataset_repo), # create dataset repo
916
+ )
917
 
918
 
919
  def start_pipeline(
 
1013
  )
1014
  gr.Markdown(joinus)
1015
 
1016
+ # --- Progressive interface --------------------------------------------------------
1017
+ gr.Markdown("### Configure your run in simple steps")
1018
+
1019
+ # Step 1: Model family
1020
+ with gr.Group():
1021
+ model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="1) Model family")
1022
+
1023
+ # Step 2: Trainer (revealed after family)
1024
+ step2_group = gr.Group(visible=False)
1025
+ with step2_group:
1026
+ trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="2) Trainer type")
1027
+
1028
+ # Step 3: Monitoring (revealed after trainer)
1029
+ step3_group = gr.Group(visible=False)
1030
+ with step3_group:
1031
+ monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="dataset", label="3) Monitoring mode")
1032
+
1033
+ # Step 4: Config & details (revealed after monitoring)
1034
+ step4_group = gr.Group(visible=False)
1035
+ with step4_group:
1036
+ # Defaults based on initial family selection
1037
+ exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
1038
+
1039
+ config_choice = gr.Dropdown(
1040
+ choices=list(get_config_map("SmolLM3").keys()),
1041
+ value="Basic Training",
1042
+ label="4) Training configuration",
1043
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1044
 
1045
+ with gr.Tabs():
1046
+ with gr.Tab("Overview"):
1047
+ training_info = gr.Markdown("Select a training configuration to see details.")
1048
+ dataset_choice = gr.Dropdown(
1049
+ choices=[],
1050
+ value=None,
1051
+ allow_custom_value=True,
1052
+ label="Dataset (from config; optional)",
1053
+ )
1054
+ with gr.Row():
1055
+ experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
1056
+ repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
1057
+ with gr.Row():
1058
+ author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
1059
+ model_description = gr.Textbox(value=desc_default, label="Model description")
1060
+ trackio_space_name = gr.Textbox(
1061
+ value=trackio_space_default,
1062
+ label="Trackio Space name (used when monitoring != none)",
1063
+ visible=False,
1064
+ )
1065
+ deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space", visible=False)
1066
+ create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo", visible=True)
1067
+ with gr.Row():
1068
+ push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
1069
+ switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
1070
+
1071
+ with gr.Tab("Advanced"):
1072
+ # GPT-OSS specific scheduler overrides
1073
+ advanced_scheduler_group = gr.Group(visible=False)
1074
+ with advanced_scheduler_group:
1075
+ scheduler_override = gr.Dropdown(
1076
+ choices=[c for c in SCHEDULER_CHOICES if c is not None],
1077
+ value=None,
1078
+ allow_custom_value=True,
1079
+ label="Scheduler override",
1080
+ )
1081
+ with gr.Row():
1082
+ min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
1083
+ min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
1084
+
1085
+ # Final action & logs
1086
+ start_btn = gr.Button("Run Pipeline", variant="primary")
1087
  logs = gr.Textbox(value="", label="Logs", lines=20)
1088
 
1089
+ # --- Events ---------------------------------------------------------------------
1090
+ model_family.change(
1091
+ on_family_change,
1092
+ inputs=model_family,
1093
+ outputs=[
1094
+ config_choice,
1095
+ experiment_name,
1096
+ repo_short,
1097
+ model_description,
1098
+ trackio_space_name,
1099
+ training_info,
1100
+ dataset_choice,
1101
+ step2_group,
1102
+ step3_group,
1103
+ step4_group,
1104
+ advanced_scheduler_group,
1105
+ ],
 
 
 
 
1106
  )
1107
 
1108
+ trainer_type.change(on_trainer_selected, inputs=trainer_type, outputs=step3_group)
1109
+
1110
+ monitoring_mode.change(
1111
+ on_monitoring_change,
1112
+ inputs=monitoring_mode,
1113
+ outputs=[step4_group, trackio_space_name, deploy_trackio_space, create_dataset_repo],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1114
  )
1115
 
1116
+ config_choice.change(
1117
+ on_config_change,
1118
+ inputs=[model_family, config_choice],
1119
+ outputs=[training_info, dataset_choice],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
  )
1121
 
1122
  start_btn.click(
 
1147
  # Optional: allow setting server parameters via env
1148
  server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
1149
  server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
1150
+ demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)
1151
 
1152