burtenshaw committed
Commit 8f5cc68 · Parent: d6ee53d

add push functionality and note about duplication

Files changed (1):
  1. app.py +132 -14
app.py CHANGED
@@ -68,9 +68,23 @@ def create_autotrain_params(
     epochs: int,
     batch_size: int,
     learning_rate: float,
+    push_to_hub: bool,
+    hub_repo_id: str = "",
     **kwargs,
 ):
     """Create AutoTrain parameter object based on task type"""
+    # Hub configuration
+    hub_config = {}
+    if push_to_hub:
+        hub_config = {
+            "push_to_hub": True,
+            "username": os.environ.get("HF_USERNAME", ""),
+            "token": os.environ.get("HF_TOKEN", ""),
+        }
+        # If a custom repo_id is provided, use it; otherwise use project_name
+        if hub_repo_id:
+            hub_config["repo_id"] = hub_repo_id
+
     common_params = {
         "model": base_model,
         "project_name": project_name,
@@ -94,6 +108,7 @@ def create_autotrain_params(
         "mixed_precision": "no",
         "save_total_limit": 1,
         "eval_strategy": "epoch",
+        **hub_config,  # Add hub configuration
     }
 
     if task == "text-classification":
@@ -114,12 +129,15 @@ def create_autotrain_params(
             "llm-reward": "reward",
         }
 
+        # For LLM tasks, exclude some parameters that don't apply
+        llm_params = {
+            k: v
+            for k, v in common_params.items()
+            if k not in ["early_stopping_patience", "early_stopping_threshold"]
+        }
+
         return LLMTrainingParams(
-            **{
-                k: v
-                for k, v in common_params.items()
-                if k not in ["early_stopping_patience", "early_stopping_threshold"]
-            },
+            **llm_params,
             text_column=kwargs.get("text_column", "messages"),
             block_size=kwargs.get("block_size", 2048),
             peft=kwargs.get("use_peft", True),
@@ -245,6 +263,8 @@ def start_training_job(
     batch_size: str = "8",
     learning_rate: str = "2e-5",
     backend: str = "local",
+    push_to_hub: str = "false",
+    hub_repo_id: str = "",
 ) -> str:
     """
     Start a new AutoTrain training job.
@@ -260,6 +280,8 @@ def start_training_job(
         batch_size: Training batch size (default: 16)
         learning_rate: Learning rate for training (default: 2e-5)
         backend: Training backend to use (default: local)
+        push_to_hub: Whether to push final model to Hub (true/false)
+        hub_repo_id: Custom repository ID for Hub (optional)
 
     Returns:
         Status message with run ID and details
@@ -269,6 +291,7 @@ def start_training_job(
         epochs_int = int(epochs)
         batch_size_int = int(batch_size)
         learning_rate_float = float(learning_rate)
+        push_to_hub_bool = push_to_hub.lower() == "true"
 
         # Generate run ID
         run_id = str(uuid.uuid4())
@@ -283,12 +306,16 @@ def start_training_job(
             "status": "pending",
             "created_at": datetime.utcnow().isoformat(),
             "updated_at": datetime.utcnow().isoformat(),
+            "push_to_hub": push_to_hub_bool,
+            "hub_repo_id": hub_repo_id,
             "config": {
                 "task": task,
                 "epochs": epochs_int,
                 "batch_size": batch_size_int,
                 "learning_rate": learning_rate_float,
                 "backend": backend,
+                "push_to_hub": push_to_hub_bool,
+                "hub_repo_id": hub_repo_id,
             },
         }
 
@@ -306,6 +333,8 @@ def start_training_job(
             epochs=epochs_int,
             batch_size=batch_size_int,
             learning_rate=learning_rate_float,
+            push_to_hub=push_to_hub_bool,
+            hub_repo_id=hub_repo_id,
         )
 
         # Start training in background
@@ -315,7 +344,8 @@ def start_training_job(
         thread.daemon = True
         thread.start()
 
-        return f"""✅ Training job submitted successfully!
+        # Build result message
+        result_msg = f"""✅ Training job submitted successfully!
 
 Run ID: {run_id}
 Project: {project_name}
@@ -327,7 +357,18 @@ Configuration:
 • Epochs: {epochs}
 • Batch Size: {batch_size}
 • Learning Rate: {learning_rate}
-• Backend: {backend}
+• Backend: {backend}"""
+
+        if push_to_hub_bool:
+            final_repo = hub_repo_id if hub_repo_id else project_name
+            result_msg += f"""
+• Push to Hub: ✅ Enabled
+• Repository: {final_repo}
+• Requires: HF_USERNAME and HF_TOKEN environment variables"""
+        else:
+            result_msg += "\n• Push to Hub: ❌ Disabled"
+
+        result_msg += """
 
 🔗 Monitor progress:
 • Gradio UI: http://localhost:7860
@@ -335,6 +376,8 @@ Configuration:
 
 💡 Use get_training_runs() to check status"""
 
+        return result_msg
+
     except Exception as e:
         return f"❌ Error submitting job: {str(e)}"
 
@@ -449,6 +492,18 @@ def get_run_details(run_id: str) -> str:
         details_text += f"\n• Learning Rate: {config.get('learning_rate')}"
         details_text += f"\n• Backend: {config.get('backend')}"
 
+        # Hub configuration
+        if config.get("push_to_hub"):
+            details_text += "\n• Push to Hub: ✅ Enabled"
+            if config.get("hub_repo_id"):
+                details_text += f"\n• Hub Repository: {config.get('hub_repo_id')}"
+            else:
+                details_text += (
+                    f"\n• Hub Repository: {run['project_name']} (default)"
+                )
+        else:
+            details_text += "\n• Push to Hub: ❌ Disabled"
+
         return details_text
 
     except Exception as e:
@@ -656,6 +711,8 @@ def submit_training_job_ui(
     batch_size,
     learning_rate,
     backend,
+    push_to_hub,
+    hub_repo_id,
 ):
     """Submit training job from web UI"""
     if not all([task, project_name, base_model, dataset_path]):
@@ -670,6 +727,8 @@ def submit_training_job_ui(
         batch_size=str(batch_size),
         learning_rate=str(learning_rate),
         backend=backend,
+        push_to_hub=str(push_to_hub).lower(),
+        hub_repo_id=hub_repo_id,
     )
 
     return result, fetch_runs_for_ui()
@@ -685,14 +744,42 @@ with gr.Blocks(
     }
     """,
 ) as app:
     gr.Markdown("""
     # 🚀 AutoTrain Gradio MCP Server
 
-    **All-in-One Solution:** Web UI + MCP Server + AutoTrain Integration
+    Get your AI models to train your AI models!
+
+    This space is an MCP server that you can use in Claude Desktop, Cursor, VS Code, etc. to train your AI models.
+
+    ⚠️ To train models you will need to duplicate this space!
+    **MCP Server**: AI assistants can use tools at http://SPACE_URL/gradio_api/mcp/sse
+
+    Connect to it like this:
+
+    ```json
+    {"mcpServers": {"autotrain": {"url": "http://SPACE_URL/gradio_api/mcp/sse",
+            "headers": {"Authorization": "Bearer <YOUR-HUGGING-FACE-TOKEN>"
+            }
+        }
+    }
+    }
+    ```
+
+    Or like this for Claude Desktop:
+
+    ```json
+    {"mcpServers": {"hf-mcp-server": {"command": "npx",
+            "args": [
+                "mcp-remote",
+                "http://SPACE_URL/gradio_api/mcp/sse",
+                "--header",
+                "Authorization: Bearer <YOUR-HUGGING-FACE-TOKEN>"
+            ]
+        }
+    }
+    }
+    ```
 
-    • **Web Interface**: Manage training jobs through this UI
-    • **MCP Server**: AI assistants can use tools at `http://localhost:7860/gradio_api/mcp/sse`
-    • **Direct Integration**: No FastAPI needed - everything runs in Gradio
     """)
 
     with gr.Tabs():
@@ -716,6 +803,11 @@ with gr.Blocks(
         with gr.Tab("🏃 Start Training"):
             gr.Markdown("## Submit New Training Job")
 
+            gr.Markdown("""
+            💡 **Hub Integration**: Enable "Push to Hub" to automatically upload your trained model to the Hugging Face Hub.
+            Requires `HF_USERNAME` and `HF_TOKEN` environment variables.
+            """)
+
             with gr.Row():
                 with gr.Column():
                     task_dropdown = gr.Dropdown(
@@ -750,6 +842,13 @@ with gr.Blocks(
                         value="local",
                     )
 
+            with gr.Row():
+                with gr.Column():
+                    push_to_hub = gr.Checkbox(label="Push to Hub", value=False)
+                    hub_repo_id = gr.Textbox(
+                        label="Hub Repository ID", placeholder="your-repo-id"
+                    )
+
             submit_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
             submit_output = gr.Textbox(label="Status", interactive=False, lines=10)
@@ -765,13 +864,27 @@ with gr.Blocks(
 
     ### Available MCP Tools:
 
-    - `start_training_job` - Submit new training jobs
+    - `start_training_job` - Submit new training jobs (includes Hub push)
    - `get_training_runs` - List all runs with status
    - `get_run_details` - Get detailed run information
-    - `delete_training_run` - Delete training runs
    - `get_task_recommendations` - Get training recommendations
    - `get_system_status` - Check system status
 
+    ### 🤗 Hugging Face Hub Integration:
+
+    To push models to the Hub, set these environment variables:
+
+    ```bash
+    export HF_USERNAME="your-hf-username"
+    export HF_TOKEN="your-hf-write-token"
+    ```
+
+    Get your token from: https://huggingface.co/settings/tokens
+
+    **Usage Examples:**
+    - `push_to_hub="true"` - Push to Hub using the project name as the repo
+    - `hub_repo_id="my-org/my-model"` - Push to a custom repository
+
     ### Claude Desktop Configuration:
 
     ```json
@@ -788,6 +901,7 @@ with gr.Blocks(
 
     Total Runs: {len(load_runs())}
     W&B Project: {WANDB_PROJECT}
+    Hub Auth: {"✅ Configured" if os.environ.get("HF_TOKEN") else "❌ Missing HF_TOKEN"}
     """)
 
     # MCP Tools Tab
@@ -825,6 +939,8 @@ with gr.Blocks(
             gr.Textbox(label="batch_size", value="8"),
             gr.Textbox(label="learning_rate", value="2e-5"),
             gr.Textbox(label="backend", value="local"),
+            gr.Textbox(label="push_to_hub", value="false"),
+            gr.Textbox(label="hub_repo_id", placeholder="your-repo-id"),
         ],
         outputs=gr.Textbox(label="Training Job Result"),
         title="start_training_job",
@@ -875,6 +991,8 @@ with gr.Blocks(
             batch_size,
             learning_rate,
             backend,
+            push_to_hub,
+            hub_repo_id,
         ],
         outputs=[submit_output, runs_table],
     )
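
For review purposes, here is a minimal standalone sketch of how the new Hub wiring behaves (the helper name `build_hub_config` and all sample values are hypothetical; the logic mirrors the `hub_config` block added to `create_autotrain_params`):

```python
import os

def build_hub_config(push_to_hub: bool, hub_repo_id: str = "") -> dict:
    """Mirror of the hub_config logic added in this commit (illustrative only)."""
    hub_config = {}
    if push_to_hub:
        hub_config = {
            "push_to_hub": True,
            "username": os.environ.get("HF_USERNAME", ""),
            "token": os.environ.get("HF_TOKEN", ""),
        }
        # A custom repo_id takes precedence; otherwise AutoTrain falls back
        # to the project name.
        if hub_repo_id:
            hub_config["repo_id"] = hub_repo_id
    return hub_config

# MCP tool arguments arrive as strings, so start_training_job parses them
# with a simple case-insensitive comparison:
push_to_hub_arg = "true"  # as an MCP client would send it
push_to_hub_bool = push_to_hub_arg.lower() == "true"

print(build_hub_config(push_to_hub_bool, hub_repo_id="my-org/my-model"))
# With HF_USERNAME/HF_TOKEN unset, this prints:
# {'push_to_hub': True, 'username': '', 'token': '', 'repo_id': 'my-org/my-model'}
```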
 
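Invoking the updated MCP tool with the new arguments might look like the following (all values hypothetical; note that tool arguments arrive as strings, so `push_to_hub` is the string `"true"`, not a boolean):

```python
# Hypothetical invocation of the MCP tool defined in app.py; assumes app.py
# is importable from the working directory.
from app import start_training_job

result = start_training_job(
    task="llm-sft",  # hypothetical task name; see the task mapping in create_autotrain_params
    project_name="my-sft-run",
    base_model="HuggingFaceTB/SmolLM2-135M-Instruct",
    dataset_path="my-org/my-dataset",
    epochs="3",
    batch_size="8",
    learning_rate="2e-5",
    backend="local",
    push_to_hub="true",  # string, parsed via .lower() == "true"
    hub_repo_id="my-org/my-model",  # optional; defaults to project_name
)
print(result)  # "✅ Training job submitted successfully! ..."
```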
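Because the new `hub_config` silently falls back to empty credentials when `HF_USERNAME`/`HF_TOKEN` are missing, a quick pre-flight check before submitting a job with push enabled can save a wasted run. This sketch uses `huggingface_hub.whoami`; the helper itself is hypothetical and not part of app.py:

```python
import os

from huggingface_hub import whoami

def check_hub_credentials() -> bool:
    """Return True if HF_USERNAME/HF_TOKEN look usable for pushing (hypothetical helper)."""
    username = os.environ.get("HF_USERNAME", "")
    token = os.environ.get("HF_TOKEN", "")
    if not (username and token):
        print("❌ Missing HF_USERNAME or HF_TOKEN")
        return False
    try:
        identity = whoami(token=token)  # validates the token against the Hub
    except Exception as err:
        print(f"❌ Token rejected by the Hub: {err}")
        return False
    print(f"✅ Authenticated as {identity['name']}")
    return True

if __name__ == "__main__":
    check_hub_credentials()
```

Note that a write-scoped token is required for pushing; a read-only token will pass this check but fail at upload time.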