improves interface flow
interface.py  (+181 -233)  CHANGED
@@ -829,10 +829,91 @@ joinus = """
 """


-def on_family_change(family: str):
+def on_family_change(family: str):
+    """Update UI when the model family changes.
+
+    - Refresh available prebuilt configuration choices
+    - Reset defaults (experiment name, repo short, description, space name)
+    - Reveal the next step (trainer type)
+    """
     confs = list(get_config_map(family).keys())
     exp, repo_short, desc, space = ui_defaults(family)
-    [removed line 835: content not preserved in this view]
+
+    # Initial dataset information placeholder until a specific config is chosen
+    training_md = (
+        f"Select a training configuration for {family} to see details (dataset, batch size, etc.)."
+    )
+
+    # Update objects:
+    return (
+        gr.update(choices=confs, value=(confs[0] if confs else None)),
+        exp,
+        repo_short,
+        desc,
+        space,
+        training_md,
+        gr.update(choices=[], value=None),
+        gr.update(visible=True),   # show step 2 (trainer)
+        gr.update(visible=False),  # hide step 3 until trainer selected
+        gr.update(visible=False),  # hide step 4 until monitoring selected
+        gr.update(visible=(family == "GPT-OSS")),  # advanced (scheduler) visibility
+    )
+
+
+def on_config_change(family: str, config_choice: str):
+    """When a prebuilt configuration is selected, update dataset info and helpful details."""
+    if not config_choice:
+        return (
+            "",
+            gr.update(choices=[], value=None),
+        )
+
+    conf_map = get_config_map(family)
+    cfg_path = PROJECT_ROOT / conf_map[config_choice]["config_file"]
+    cfg_obj = import_config_object(cfg_path)
+
+    dataset_name = getattr(cfg_obj, "dataset_name", None) if cfg_obj else None
+    batch_size = getattr(cfg_obj, "batch_size", None) if cfg_obj else None
+    learning_rate = getattr(cfg_obj, "learning_rate", None) if cfg_obj else None
+    max_seq_length = getattr(cfg_obj, "max_seq_length", None) if cfg_obj else None
+    base_model = conf_map[config_choice]["default_model"]
+
+    md_lines = [
+        f"**Configuration**: {config_choice}",
+        f"**Base model**: {base_model}",
+    ]
+    if dataset_name:
+        md_lines.append(f"**Dataset**: `{dataset_name}`")
+    if batch_size is not None:
+        md_lines.append(f"**Batch size**: {batch_size}")
+    if learning_rate is not None:
+        md_lines.append(f"**Learning rate**: {learning_rate}")
+    if max_seq_length is not None:
+        md_lines.append(f"**Max seq length**: {max_seq_length}")
+
+    training_md = "\n".join(md_lines)
+
+    # dataset selection (allow custom but prefill with the config's dataset if any)
+    ds_choices = [dataset_name] if dataset_name else []
+
+    return training_md, gr.update(choices=ds_choices, value=(dataset_name or None))
+
+
+def on_trainer_selected(_: str):
+    """Reveal monitoring step once trainer type is chosen."""
+    return gr.update(visible=True)
+
+
+def on_monitoring_change(mode: str):
+    """Reveal configuration/details step and adjust Trackio-related visibility by mode."""
+    show_trackio = mode in ("both", "trackio")
+    show_dataset_repo = mode != "none"
+    return (
+        gr.update(visible=True),
+        gr.update(visible=show_trackio),       # trackio space name
+        gr.update(visible=show_trackio),       # deploy trackio space
+        gr.update(visible=show_dataset_repo),  # create dataset repo
+    )


 def start_pipeline(
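Reviewer note on the hunk above: the new handlers drive the step-by-step reveal purely by returning gr.update(...) objects, one per component listed in the event's outputs=. A minimal, self-contained sketch of that pattern, assuming nothing from interface.py (widget names below are hypothetical):

import gradio as gr

def reveal_next(choice: str):
    # gr.update(...) changes properties of an existing component; here any
    # truthy selection makes the hidden group visible.
    return gr.update(visible=bool(choice))

with gr.Blocks() as demo:
    family = gr.Dropdown(choices=["SmolLM3", "GPT-OSS"], label="1) Model family")
    step2 = gr.Group(visible=False)
    with step2:
        trainer = gr.Radio(choices=["SFT", "DPO"], label="2) Trainer type")
    family.change(reveal_next, inputs=family, outputs=step2)

if __name__ == "__main__":
    demo.launch()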
@@ -932,243 +1013,110 @@ with gr.Blocks(title="SmolLM3 / GPT-OSS Fine-tuning Pipeline") as demo:
     )
     gr.Markdown(joinus)

-    [removed lines 935-962: content not preserved in this view]
-    model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="Model family")
-    trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="Trainer type")
-    monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="both", label="Monitoring mode")
-
-    config_choice = gr.Dropdown(choices=list(get_config_map("SmolLM3").keys()), value="Basic Training", label="Training configuration")
-
-    exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
-    with gr.Row():
-        experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
-        repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
-
-    with gr.Row():
-        author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
-        model_description = gr.Textbox(value=desc_default, label="Model description")
-
-    with gr.Row():
-        trackio_space_name = gr.Textbox(value=trackio_space_default, label="Trackio Space name (used when monitoring != none)")
-        deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space")
-        create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo")
-
-    with gr.Row():
-        push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
-        switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
-
-    gr.Markdown("### Medical SFT (GPT-OSS o1)")
-    gr.Markdown("Configure GPT-OSS Medical o1 SFT (FreedomIntelligence/medical-o1-reasoning-SFT)")
-    med_dataset_config = gr.Dropdown(choices=["en", "en_mix", "zh", "zh_mix"], value="en", label="Dataset config")
-    med_system = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="System message", lines=2)
-    med_developer = gr.Textbox(value="You are are GPT-Tonic, an intelligent assistant that always answers health-related queries scientifically.", label="Developer message", lines=3)
-    with gr.Row():
-        med_epochs = gr.Number(value=2.0, precision=2, label="Epochs")
-        med_bs = gr.Number(value=4, precision=0, label="Batch size")
-        med_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
-        med_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
-        med_msl = gr.Number(value=2048, precision=0, label="Max seq length")
-    med_generate = gr.Button("Generate Medical Config")
-    med_status = gr.Textbox(label="Generated config path", interactive=False)
-
-    logs = gr.Textbox(value="", label="Logs", lines=20)
-    start_btn = gr.Button("Run Pipeline")
-
-    with gr.Tab("Advanced Config"):
-        with gr.Accordion("GPT-OSS Scheduler Overrides", open=False):
-            scheduler_override = gr.Dropdown(choices=[c for c in SCHEDULER_CHOICES if c is not None], value=None, allow_custom_value=True, label="Scheduler override")
-            min_lr = gr.Number(value=None, precision=6, label="min_lr (when cosine_with_min_lr)")
-            min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (when cosine_with_min_lr)")
-
-        gr.Markdown("### GPT-OSS Custom Dataset")
-        with gr.Row():
-            cds_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset name")
-            cds_split = gr.Textbox(value="train", label="Split")
-            cds_format = gr.Dropdown(choices=["openhermes_fr", "messages", "text", "medical_o1_sft", "custom", "preference"], value="openhermes_fr", label="Format")
-        with gr.Row():
-            cds_input = gr.Textbox(value="prompt", label="Input field")
-            cds_target = gr.Textbox(value="accepted_completion", label="Target field (optional, blank for None)")
-        with gr.Row():
-            cds_sys = gr.Textbox(value="", label="System message (optional)")
-            cds_dev = gr.Textbox(value="", label="Developer message (optional)")
-        with gr.Row():
-            cds_identity = gr.Textbox(value="You are GPT-Tonic, a large language model trained by TonicAI.", label="Model identity (chat_template_kwargs.model_identity)")
-        with gr.Row():
-            cds_max_samples = gr.Number(value=None, precision=0, label="Max samples (optional)")
-            cds_min_len = gr.Number(value=10, precision=0, label="Min length")
-            cds_max_len = gr.Number(value=None, precision=0, label="Max length (optional)")
-        gr.Markdown("#### Training Hyperparameters")
-        with gr.Row():
-            cds_epochs = gr.Number(value=1.0, precision=2, label="Epochs")
-            cds_bs = gr.Number(value=4, precision=0, label="Batch size")
-            cds_gas = gr.Number(value=4, precision=0, label="Grad accumulation")
-            cds_lr = gr.Number(value=2e-4, precision=6, label="Learning rate")
-            cds_minlr = gr.Number(value=2e-5, precision=6, label="Min LR")
-        with gr.Row():
-            cds_wd = gr.Number(value=0.01, precision=6, label="Weight decay")
-            cds_warm = gr.Number(value=0.03, precision=6, label="Warmup ratio")
-            cds_msl = gr.Number(value=2048, precision=0, label="Max seq length")
-        gr.Markdown("#### LoRA / Precision / Quantization / Perf")
-        with gr.Row():
-            cds_lora_r = gr.Number(value=16, precision=0, label="LoRA r")
-            cds_lora_alpha = gr.Number(value=32, precision=0, label="LoRA alpha")
-            cds_lora_dropout = gr.Number(value=0.05, precision=4, label="LoRA dropout")
-        with gr.Row():
-            cds_precision = gr.Dropdown(choices=["bf16", "fp16", "fp32"], value="bf16", label="Mixed precision")
-            cds_workers = gr.Number(value=4, precision=0, label="Data workers")
-            cds_quant = gr.Dropdown(choices=["mxfp4", "bnb4", "none"], value="mxfp4", label="Quantization")
-        with gr.Row():
-            cds_mgn = gr.Number(value=1.0, precision=4, label="Max grad norm")
-            cds_log_steps = gr.Number(value=10, precision=0, label="Logging steps")
-            cds_eval_steps = gr.Number(value=100, precision=0, label="Eval steps")
-            cds_save_steps = gr.Number(value=500, precision=0, label="Save steps")
-        cds_generate = gr.Button("Generate GPT-OSS Custom Config")
-        cds_status = gr.Textbox(label="Generated config path", interactive=False)
-
-        gr.Markdown("### SmolLM3 Custom Configuration")
-        with gr.Row():
-            sm_model = gr.Textbox(value="HuggingFaceTB/SmolLM3-3B", label="Model name")
-            sm_dataset = gr.Textbox(value="legmlai/openhermes-fr", label="Dataset (optional; leave blank for local)")
-        with gr.Row():
-            sm_msl = gr.Number(value=4096, precision=0, label="Max seq length")
-            sm_bs = gr.Number(value=2, precision=0, label="Batch size")
-            sm_gas = gr.Number(value=8, precision=0, label="Grad accumulation")
-            sm_lr = gr.Number(value=5e-6, precision=8, label="Learning rate")
-        with gr.Row():
-            sm_save = gr.Number(value=500, precision=0, label="Save steps")
-            sm_eval = gr.Number(value=100, precision=0, label="Eval steps")
-            sm_log = gr.Number(value=10, precision=0, label="Logging steps")
-        with gr.Row():
-            sm_filter = gr.Checkbox(value=False, label="Filter bad entries")
-            sm_in = gr.Textbox(value="prompt", label="Input field")
-            sm_out = gr.Textbox(value="accepted_completion", label="Target field")
-        with gr.Row():
-            sm_sample = gr.Number(value=None, precision=0, label="Sample size (optional)")
-            sm_seed = gr.Number(value=42, precision=0, label="Sample seed")
-            sm_trainer = gr.Dropdown(choices=["SFT", "DPO"], value="SFT", label="Trainer type")
-        sm_generate = gr.Button("Generate SmolLM3 Custom Config")
-        sm_status = gr.Textbox(label="Generated config path", interactive=False)
+    # --- Progressive interface --------------------------------------------------------
+    gr.Markdown("### Configure your run in simple steps")
+
+    # Step 1: Model family
+    with gr.Group():
+        model_family = gr.Dropdown(choices=MODEL_FAMILIES, value="SmolLM3", label="1) Model family")
+
+    # Step 2: Trainer (revealed after family)
+    step2_group = gr.Group(visible=False)
+    with step2_group:
+        trainer_type = gr.Radio(choices=TRAINER_CHOICES, value="SFT", label="2) Trainer type")
+
+    # Step 3: Monitoring (revealed after trainer)
+    step3_group = gr.Group(visible=False)
+    with step3_group:
+        monitoring_mode = gr.Dropdown(choices=MONITORING_CHOICES, value="dataset", label="3) Monitoring mode")
+
+    # Step 4: Config & details (revealed after monitoring)
+    step4_group = gr.Group(visible=False)
+    with step4_group:
+        # Defaults based on initial family selection
+        exp_default, repo_default, desc_default, trackio_space_default = ui_defaults("SmolLM3")
+
+        config_choice = gr.Dropdown(
+            choices=list(get_config_map("SmolLM3").keys()),
+            value="Basic Training",
+            label="4) Training configuration",
+        )

+        with gr.Tabs():
+            with gr.Tab("Overview"):
+                training_info = gr.Markdown("Select a training configuration to see details.")
+                dataset_choice = gr.Dropdown(
+                    choices=[],
+                    value=None,
+                    allow_custom_value=True,
+                    label="Dataset (from config; optional)",
+                )
+                with gr.Row():
+                    experiment_name = gr.Textbox(value=exp_default, label="Experiment name")
+                    repo_short = gr.Textbox(value=repo_default, label="Model repo (short name)")
+                with gr.Row():
+                    author_name = gr.Textbox(value=os.environ.get("HF_USERNAME", ""), label="Author name")
+                    model_description = gr.Textbox(value=desc_default, label="Model description")
+                trackio_space_name = gr.Textbox(
+                    value=trackio_space_default,
+                    label="Trackio Space name (used when monitoring != none)",
+                    visible=False,
+                )
+                deploy_trackio_space = gr.Checkbox(value=True, label="Deploy Trackio Space", visible=False)
+                create_dataset_repo = gr.Checkbox(value=True, label="Create/ensure HF Dataset repo", visible=True)
+                with gr.Row():
+                    push_to_hub = gr.Checkbox(value=True, label="Push model to Hugging Face Hub")
+                    switch_to_read_after = gr.Checkbox(value=True, label="Switch Space token to READ after training")
+
+            with gr.Tab("Advanced"):
+                # GPT-OSS specific scheduler overrides
+                advanced_scheduler_group = gr.Group(visible=False)
+                with advanced_scheduler_group:
+                    scheduler_override = gr.Dropdown(
+                        choices=[c for c in SCHEDULER_CHOICES if c is not None],
+                        value=None,
+                        allow_custom_value=True,
+                        label="Scheduler override",
+                    )
+                    with gr.Row():
+                        min_lr = gr.Number(value=None, precision=6, label="min_lr (cosine_with_min_lr)")
+                        min_lr_rate = gr.Number(value=None, precision=6, label="min_lr_rate (cosine_with_min_lr)")
+
+    # Final action & logs
+    start_btn = gr.Button("Run Pipeline", variant="primary")
     logs = gr.Textbox(value="", label="Logs", lines=20)

-    [removed lines 1081-1097: content not preserved in this view]
-            )
-        ),
-        inputs=[med_dataset_config, med_system, med_developer, med_epochs, med_bs, med_gas, med_lr, med_msl],
-        outputs=[med_status],
+    # --- Events ---------------------------------------------------------------------
+    model_family.change(
+        on_family_change,
+        inputs=model_family,
+        outputs=[
+            config_choice,
+            experiment_name,
+            repo_short,
+            model_description,
+            trackio_space_name,
+            training_info,
+            dataset_choice,
+            step2_group,
+            step3_group,
+            step4_group,
+            advanced_scheduler_group,
+        ],
     )

-    [removed lines 1104-1109: content not preserved in this view]
-            input_field=ifld,
-            target_field=(tfld or None),
-            system_message=sm,
-            developer_message=dm,
-            model_identity=ident,
-            max_samples=(int(ms) if ms is not None else None),
-            min_length=int(minl or 10),
-            max_length=(int(maxl) if maxl is not None else None),
-            num_train_epochs=float(ep or 1.0),
-            batch_size=int(bs or 4),
-            gradient_accumulation_steps=int(gas or 4),
-            learning_rate=float(lr or 2e-4),
-            min_lr=float(minlr or 2e-5),
-            weight_decay=float(wd or 0.01),
-            warmup_ratio=float(warm or 0.03),
-            max_seq_length=int(msl or 2048),
-            lora_r=int(lr_),
-            lora_alpha=int(la),
-            lora_dropout=float(ld),
-            mixed_precision=prec,
-            num_workers=int(nw or 4),
-            quantization_type=q,
-            max_grad_norm=float(mgn or 1.0),
-            logging_steps=int(logst or 10),
-            eval_steps=int(evst or 100),
-            save_steps=int(savst or 500),
-        )
-        ),
-        inputs=[
-            cds_dataset, cds_split, cds_format, cds_input, cds_target, cds_sys, cds_dev, cds_identity,
-            cds_max_samples, cds_min_len, cds_max_len, cds_epochs, cds_bs, cds_gas, cds_lr, cds_minlr, cds_wd,
-            cds_warm, cds_msl, cds_lora_r, cds_lora_alpha, cds_lora_dropout, cds_precision, cds_workers, cds_quant,
-            cds_mgn, cds_log_steps, cds_eval_steps, cds_save_steps
-        ],
-        outputs=[cds_status],
+    trainer_type.change(on_trainer_selected, inputs=trainer_type, outputs=step3_group)
+
+    monitoring_mode.change(
+        on_monitoring_change,
+        inputs=monitoring_mode,
+        outputs=[step4_group, trackio_space_name, deploy_trackio_space, create_dataset_repo],
     )

-    [removed lines 1147-1150: content not preserved in this view]
-            dataset_name=(dn or None),
-            max_seq_length=int(msl or 4096),
-            batch_size=int(bs or 2),
-            gradient_accumulation_steps=int(gas or 8),
-            learning_rate=float(lr or 5e-6),
-            save_steps=int(sst or 500),
-            eval_steps=int(est or 100),
-            logging_steps=int(lst or 10),
-            filter_bad_entries=bool(fbe),
-            input_field=ifld,
-            target_field=tfld,
-            sample_size=(int(ss) if ss is not None else None),
-            sample_seed=int(seed or 42),
-            trainer_type=tt,
-        )
-        ),
-        inputs=[
-            sm_model, sm_dataset, sm_msl, sm_bs, sm_gas, sm_lr, sm_save, sm_eval, sm_log,
-            sm_filter, sm_in, sm_out, sm_sample, sm_seed, sm_trainer,
-        ],
-        outputs=[sm_status],
+    config_choice.change(
+        on_config_change,
+        inputs=[model_family, config_choice],
+        outputs=[training_info, dataset_choice],
     )

     start_btn.click(
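One detail worth calling out in the event wiring above: when a handler feeds several outputs, it must return exactly one value per entry in outputs=, in the same order, otherwise the wrong components get updated. A trimmed-down sketch of that contract, mirroring on_monitoring_change (component names here are illustrative, not the real ones from this file):

import gradio as gr

def on_mode_change(mode: str):
    show_trackio = mode in ("both", "trackio")
    # One return value per output component, in outputs= order.
    return (
        gr.update(visible=True),          # step group
        gr.update(visible=show_trackio),  # space-name textbox
        gr.update(visible=show_trackio),  # deploy checkbox
    )

with gr.Blocks() as demo:
    mode = gr.Dropdown(choices=["none", "dataset", "trackio", "both"], value="dataset", label="Monitoring mode")
    step = gr.Group(visible=False)
    with step:
        space_name = gr.Textbox(label="Trackio Space name", visible=False)
        deploy = gr.Checkbox(label="Deploy Trackio Space", visible=False)
    mode.change(on_mode_change, inputs=mode, outputs=[step, space_name, deploy])

if __name__ == "__main__":
    demo.launch()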
@@ -1199,6 +1147,6 @@ if __name__ == "__main__":
     # Optional: allow setting server parameters via env
     server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
     server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
-    demo.queue().launch(server_name=server_name, server_port=server_port)
+    demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)

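The only functional change in this last hunk is the extra mcp_server=True flag, which asks Gradio to expose the app's callable endpoints over MCP alongside the normal web UI. That keyword is only accepted by Gradio builds with MCP support (for example the gradio[mcp] extra), so the snippet below is a sketch under that version assumption, with the surrounding demo reduced to a stub:

import os
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("stub UI")  # stands in for the real interface

if __name__ == "__main__":
    # Env-configurable host/port, as in the hunk above; mcp_server=True assumes
    # the installed Gradio version accepts that launch() keyword.
    server_port = int(os.environ.get("INTERFACE_PORT", "7860"))
    server_name = os.environ.get("INTERFACE_HOST", "0.0.0.0")
    demo.queue().launch(server_name=server_name, server_port=server_port, mcp_server=True)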