# Hugging Face Space - running on A10G
"""
AutoTrain Gradio MCP Server - All-in-One

This single Gradio app:
1. Provides a web interface for managing AutoTrain jobs
2. Automatically exposes MCP tools at /gradio_api/mcp/sse
3. Handles all AutoTrain operations directly (no FastAPI needed)
"""
import os | |
import json | |
import uuid | |
import threading | |
from datetime import datetime | |
from typing import List, Dict, Any | |
import socket | |
import gradio as gr | |
import pandas as pd | |
import wandb | |
from autotrain.project import AutoTrainProject | |
from autotrain.params import ( | |
LLMTrainingParams, | |
TextClassificationParams, | |
ImageClassificationParams, | |
) | |
# Run history lives in a flat JSON file (swap in SQLite if that ever becomes
# limiting).
RUNS_FILE = "training_runs.json"
# W&B project name; overridable through the WANDB_PROJECT environment variable.
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "autotrain-mcp")
def load_runs() -> List[Dict[str, Any]]:
    """Read the stored training runs from ``RUNS_FILE``.

    Returns:
        The decoded JSON content of the file, or an empty list when the file
        does not exist, cannot be read, or does not contain valid JSON.
    """
    if not os.path.exists(RUNS_FILE):
        return []
    try:
        with open(RUNS_FILE, "r") as handle:
            stored = json.load(handle)
    except (json.JSONDecodeError, IOError):
        # Best effort: an unreadable or corrupt file is treated as "no runs".
        return []
    return stored
def save_runs(runs: List[Dict[str, Any]]):
    """Persist training runs to ``RUNS_FILE`` atomically.

    The data is first written to a uniquely named temporary file next to
    ``RUNS_FILE`` and then moved into place with ``os.replace``. Readers in
    other threads therefore see either the previous file or the complete new
    one, never a half-written file (which ``load_runs`` would silently treat
    as "no runs").

    Args:
        runs: The full list of run records to store; replaces the previous
            file content.

    Raises:
        OSError: If the file cannot be written or moved into place.
        TypeError: If ``runs`` contains values that are not JSON serializable.
    """
    # Unique suffix so concurrent writers never share a temp file.
    tmp_path = f"{RUNS_FILE}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(runs, f, indent=2)
        os.replace(tmp_path, RUNS_FILE)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up when writing or moving failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def get_status_emoji(status: str) -> str:
    """Return the marker string for a training status.

    The lookup is case-insensitive; any status outside the known set maps to
    the fallback marker.
    """
    icons = dict(
        pending="β³",
        running="π",
        completed="β ",
        failed="β",
        cancelled="βΉοΈ",
    )
    try:
        return icons[status.lower()]
    except KeyError:
        return "β"
def create_autotrain_params(
    task: str,
    base_model: str,
    project_name: str,
    dataset_path: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    push_to_hub: bool,
    hub_repo_id: str = "",
    **kwargs,
):
    """Create the AutoTrain parameter object for the given task type.

    Args:
        task: One of ``text-classification``, ``image-classification``,
            ``llm-sft``, ``llm-dpo``, ``llm-orpo`` or ``llm-reward``.
        base_model: Model identifier passed to AutoTrain as ``model``.
        project_name: Project name passed to AutoTrain.
        dataset_path: Dataset location passed to AutoTrain as ``data_path``.
        epochs: Number of training epochs.
        batch_size: Training batch size.
        learning_rate: Learning rate, passed to AutoTrain as ``lr``.
        push_to_hub: When true, Hub settings are added using the
            ``HF_USERNAME`` and ``HF_TOKEN`` environment variables.
        hub_repo_id: Optional repository ID, only used with ``push_to_hub``.
        **kwargs: Optional overrides read by this function: ``train_split``,
            ``valid_split``, ``text_column``, ``target_column``,
            ``max_seq_length``, ``block_size``, ``use_peft``,
            ``quantization`` and ``image_column``.

    Returns:
        A ``TextClassificationParams``, ``LLMTrainingParams`` or
        ``ImageClassificationParams`` instance.

    Raises:
        ValueError: If ``task`` is not supported. This includes unknown
            ``llm-*`` variants, which previously surfaced as a bare
            ``KeyError``.
    """
    # Maps the public task name to the AutoTrain LLM trainer name.
    llm_trainers = {
        "llm-sft": "sft",
        "llm-dpo": "dpo",
        "llm-orpo": "orpo",
        "llm-reward": "reward",
    }
    supported_tasks = {"text-classification", "image-classification", *llm_trainers}
    # Validate up-front so every unsupported task fails the same way.
    if task not in supported_tasks:
        raise ValueError(f"Unsupported task type: {task}")

    # Hub configuration
    hub_config = {}
    if push_to_hub:
        hub_config = {
            "push_to_hub": True,
            "username": os.environ.get("HF_USERNAME", ""),
            "token": os.environ.get("HF_TOKEN", ""),
        }
        # A custom repo_id is only forwarded when one was provided.
        if hub_repo_id:
            hub_config["repo_id"] = hub_repo_id

    common_params = {
        "model": base_model,
        "project_name": project_name,
        "data_path": dataset_path,
        "train_split": kwargs.get("train_split", "train"),
        "valid_split": kwargs.get("valid_split"),
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": learning_rate,
        "log": "wandb",
        # Required defaults
        "warmup_ratio": 0.1,
        "gradient_accumulation": 1,
        "optimizer": "adamw_torch",
        "scheduler": "linear",
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,
        "seed": 42,
        "logging_steps": 10,
        "auto_find_batch_size": False,
        "mixed_precision": "no",
        "save_total_limit": 1,
        "eval_strategy": "epoch",
        **hub_config,  # Add hub configuration
    }

    if task == "text-classification":
        return TextClassificationParams(
            **common_params,
            text_column=kwargs.get("text_column", "text"),
            target_column=kwargs.get("target_column", "label"),
            max_seq_length=kwargs.get("max_seq_length", 128),
            early_stopping_patience=3,
            early_stopping_threshold=0.01,
        )

    if task == "image-classification":
        return ImageClassificationParams(
            **common_params,
            image_column=kwargs.get("image_column", "image"),
            target_column=kwargs.get("target_column", "label"),
        )

    # Only the llm-* tasks remain after the validation above. The early
    # stopping keys are never part of common_params, so it is passed as-is.
    return LLMTrainingParams(
        **common_params,
        text_column=kwargs.get("text_column", "messages"),
        block_size=kwargs.get("block_size", 2048),
        peft=kwargs.get("use_peft", True),
        quantization=kwargs.get("quantization", "int4"),
        trainer=llm_trainers[task],
        chat_template="tokenizer",
        # LLM-specific defaults
        add_eos_token=True,
        model_max_length=2048,
        padding="right",
        use_flash_attention_2=False,
        disable_gradient_checkpointing=False,
        target_modules="all-linear",
        merge_adapter=False,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        model_ref=None,
        dpo_beta=0.1,
        max_prompt_length=512,
        max_completion_length=1024,
        prompt_text_column="prompt",
        rejected_text_column="rejected",
        unsloth=False,
        distributed_backend="accelerate",
    )
def _update_run(run_id: str, **fields: Any) -> None:
    """Merge ``fields`` into the stored record for ``run_id`` and save.

    ``updated_at`` is refreshed on every change so the value shown in the run
    details reflects the last state transition. The file is rewritten even
    when no record matches.
    """
    runs = load_runs()
    for run in runs:
        if run["run_id"] == run_id:
            run.update(fields)
            run["updated_at"] = datetime.utcnow().isoformat()
            break
    save_runs(runs)


def run_training_background(run_id: str, params: Any, backend: str):
    """Run a training job and record its progress in the run store.

    Intended as a thread target. The stored record moves to ``running``
    first, then to ``completed`` (with the W&B URL and, if truthy, the
    stringified result) or to ``failed`` (with the error message).

    Args:
        run_id: ID of the run record to update.
        params: AutoTrain parameter object; ``model`` and ``data_path`` are
            read for logging and the object is handed to ``AutoTrainProject``.
        backend: Backend name handed to ``AutoTrainProject``.
    """
    _update_run(
        run_id,
        status="running",
        started_at=datetime.utcnow().isoformat(),
    )

    try:
        # Set W&B environment variables for AutoTrain to use
        os.environ["WANDB_PROJECT"] = WANDB_PROJECT

        print(f"Starting real training for run {run_id}")
        print(f"Model: {params.model}")
        print(f"Dataset: {params.data_path}")
        print(f"Backend: {backend}")

        # Create AutoTrain project - this will handle W&B internally
        project = AutoTrainProject(params=params, backend=backend, process=True)

        # NOTE(review): the original comment states that create() blocks
        # until training completes - confirm for the backend in use.
        print(f"Executing training job for run {run_id}...")
        result = project.create()

        print(f"Training completed successfully for run {run_id}")
        print(f"Result: {result}")

        # Fall back to a project-level URL when no W&B run is active here.
        wandb_url = f"https://wandb.ai/{WANDB_PROJECT}"
        try:
            if wandb.run is not None:
                wandb_url = wandb.run.url
                print(f"Got actual W&B URL: {wandb_url}")
            else:
                print("No active W&B run found, using default URL")
        except Exception as e:
            print(f"Could not get W&B URL: {e}")

        # Record the URL and the final state in a single write.
        completion = {
            "wandb_url": wandb_url,
            "status": "completed",
            "completed_at": datetime.utcnow().isoformat(),
        }
        if result:
            completion["result"] = str(result)
        _update_run(run_id, **completion)

    except Exception as e:
        print(f"Training failed for run {run_id}: {str(e)}")
        import traceback

        traceback.print_exc()

        _update_run(
            run_id,
            status="failed",
            error_message=str(e),
            completed_at=datetime.utcnow().isoformat(),
        )
# MCP Tool Functions (these automatically become MCP tools) | |
def start_training_job(
    task: str = "text-classification",
    project_name: str = "test-project",
    base_model: str = "distilbert-base-uncased",
    dataset_path: str = "imdb",
    epochs: str = "1",
    batch_size: str = "8",
    learning_rate: str = "2e-5",
    backend: str = "local",
    push_to_hub: str = "false",
    hub_repo_id: str = "",
) -> str:
    """
    Start a new AutoTrain training job.

    The training parameters are built (and therefore validated) before the
    run is recorded, so a rejected configuration does not leave a run stuck
    in "pending".

    Args:
        task: Type of training task (text-classification, llm-sft,
            llm-dpo, llm-orpo, llm-reward, image-classification)
        project_name: Name for the training project
        base_model: Base model from Hugging Face Hub
            (e.g., distilbert-base-uncased)
        dataset_path: Dataset path or HF dataset name (e.g., imdb)
        epochs: Number of training epochs (default: 1)
        batch_size: Training batch size (default: 8)
        learning_rate: Learning rate for training (default: 2e-5)
        backend: Training backend to use (default: local)
        push_to_hub: Whether to push final model to Hub (true/false)
        hub_repo_id: Custom repository ID for Hub (optional)

    Returns:
        Status message with run ID and details, or an error message when the
        job could not be submitted
    """
    try:
        # Convert string parameters
        epochs_int = int(epochs)
        batch_size_int = int(batch_size)
        learning_rate_float = float(learning_rate)
        push_to_hub_bool = push_to_hub.strip().lower() == "true"

        # Build the parameters first: an unsupported task or bad value
        # aborts here, before anything is written to the run store.
        params = create_autotrain_params(
            task=task,
            base_model=base_model,
            project_name=project_name,
            dataset_path=dataset_path,
            epochs=epochs_int,
            batch_size=batch_size_int,
            learning_rate=learning_rate_float,
            push_to_hub=push_to_hub_bool,
            hub_repo_id=hub_repo_id,
        )

        # Generate run ID
        run_id = str(uuid.uuid4())
        created_at = datetime.utcnow().isoformat()

        # Create run record
        run_data = {
            "run_id": run_id,
            "project_name": project_name,
            "task": task,
            "base_model": base_model,
            "dataset_path": dataset_path,
            "status": "pending",
            "created_at": created_at,
            "updated_at": created_at,
            "push_to_hub": push_to_hub_bool,
            "hub_repo_id": hub_repo_id,
            "config": {
                "task": task,
                "epochs": epochs_int,
                "batch_size": batch_size_int,
                "learning_rate": learning_rate_float,
                "backend": backend,
                "push_to_hub": push_to_hub_bool,
                "hub_repo_id": hub_repo_id,
            },
        }

        # Save to storage
        runs = load_runs()
        runs.append(run_data)
        save_runs(runs)

        # Start training in background; daemon so it never blocks shutdown.
        thread = threading.Thread(
            target=run_training_background,
            args=(run_id, params, backend),
            daemon=True,
        )
        thread.start()

        # Build result message
        lines = [
            "β Training job submitted successfully!",
            f"Run ID: {run_id}",
            f"Project: {project_name}",
            f"Task: {task}",
            f"Model: {base_model}",
            f"Dataset: {dataset_path}",
            "Configuration:",
            f"β’ Epochs: {epochs}",
            f"β’ Batch Size: {batch_size}",
            f"β’ Learning Rate: {learning_rate}",
            f"β’ Backend: {backend}",
        ]
        if push_to_hub_bool:
            final_repo = hub_repo_id if hub_repo_id else project_name
            lines += [
                "β’ Push to Hub: β Enabled",
                f"β’ Repository: {final_repo}",
                "β’ Requires: HF_USERNAME and HF_TOKEN environment variables",
            ]
        else:
            lines.append("β’ Push to Hub: β Disabled")
        lines += [
            "π Monitor progress:",
            "β’ Gradio UI: http://localhost:7860",
            "β’ W&B tracking will be available once training starts",
            "π‘ Use get_training_runs() to check status",
        ]
        return "\n".join(lines)

    except Exception as e:
        return f"β Error submitting job: {str(e)}"
def get_training_runs(limit: str = "20", status: str = "") -> str:
    """
    Get list of training runs with their status and details.

    Args:
        limit: Maximum number of runs to return (default: 20). Zero or a
            negative number returns no runs.
        status: Filter by run status (pending, running, completed,
            failed, cancelled); matching is case-insensitive

    Returns:
        Formatted list of training runs with status and links, newest first
    """
    try:
        max_runs = int(limit)
        runs = load_runs()

        # Filter by status if provided
        wanted_status = status.strip().lower()
        if wanted_status:
            runs = [
                run
                for run in runs
                if (run.get("status") or "").lower() == wanted_status
            ]

        # Apply limit. A plain runs[-max_runs:] would return *every* run for
        # 0 (since -0 == 0) and slice from the wrong end for negatives.
        runs = runs[-max_runs:] if max_runs > 0 else []

        if not runs:
            return "No training runs found. Start a new training job to see it here!"

        parts = [f"π Training Runs (showing {len(runs)}):\n\n"]
        for run in reversed(runs):  # Show newest first
            status_emoji = get_status_emoji(run["status"])
            parts.append(
                f"{status_emoji} **{run['project_name']}** ({run['run_id'][:8]}...)\n"
            )
            parts.append(f" Task: {run['task']}\n")
            parts.append(f" Model: {run['base_model']}\n")
            parts.append(f" Status: {run['status'].title()}\n")
            parts.append(f" Created: {run['created_at']}\n")
            if run.get("wandb_url"):
                parts.append(f" π W&B: {run['wandb_url']}\n")
            if run.get("error_message"):
                parts.append(f" β Error: {run['error_message']}\n")
            parts.append("\n")
        return "".join(parts)

    except Exception as e:
        return f"β Error fetching runs: {str(e)}"
def get_run_details(run_id: str) -> str:
    """
    Get detailed information about a specific training run.

    Args:
        run_id: ID of the training run (can be partial ID). Surrounding
            whitespace and a trailing "..." (as shown in the runs table)
            are ignored.

    Returns:
        Detailed run information including config and status, or a
        not-found / error message
    """
    try:
        # Run IDs are UUIDs and never contain dots, so trailing dots can only
        # come from the truncated "abcd1234..." display form.
        wanted = run_id.strip().rstrip(".")

        # An empty prefix would match every run (startswith("") is always
        # true) and silently return the first one.
        if not wanted:
            return f"β Training run {run_id} not found"

        runs = load_runs()

        # Prefer an exact match, then fall back to the first prefix match.
        found_run = next((r for r in runs if r["run_id"] == wanted), None)
        if found_run is None:
            found_run = next(
                (r for r in runs if r["run_id"].startswith(wanted)), None
            )
        if not found_run:
            return f"β Training run {run_id} not found"

        run = found_run
        parts = [
            "\n".join(
                [
                    "π Training Run Details",
                    f"**Run ID:** {run['run_id']}",
                    f"**Project:** {run['project_name']}",
                    f"**Task:** {run['task']}",
                    f"**Model:** {run['base_model']}",
                    f"**Dataset:** {run['dataset_path']}",
                    f"**Status:** {run['status'].title()}",
                    "**Timestamps:**",
                    f"β’ Created: {run['created_at']}",
                    f"β’ Updated: {run.get('updated_at', 'N/A')}",
                ]
            )
        ]

        if run.get("started_at"):
            parts.append(f"\nβ’ Started: {run['started_at']}")
        if run.get("completed_at"):
            parts.append(f"\nβ’ Completed: {run['completed_at']}")
        if run.get("wandb_url"):
            parts.append(f"\n\nπ **W&B Dashboard:** {run['wandb_url']}")
        if run.get("error_message"):
            parts.append(f"\n\nβ **Error:** {run['error_message']}")

        if run.get("config"):
            config = run["config"]
            parts.append("\n\nβοΈ **Training Configuration:**")
            parts.append(f"\nβ’ Epochs: {config.get('epochs')}")
            parts.append(f"\nβ’ Batch Size: {config.get('batch_size')}")
            parts.append(f"\nβ’ Learning Rate: {config.get('learning_rate')}")
            parts.append(f"\nβ’ Backend: {config.get('backend')}")

            # Hub configuration
            if config.get("push_to_hub"):
                parts.append("\nβ’ Push to Hub: β Enabled")
                if config.get("hub_repo_id"):
                    parts.append(
                        f"\nβ’ Hub Repository: {config.get('hub_repo_id')}"
                    )
                else:
                    parts.append(
                        f"\nβ’ Hub Repository: {run['project_name']} (default)"
                    )
            else:
                parts.append("\nβ’ Push to Hub: β Disabled")

        return "".join(parts)

    except Exception as e:
        return f"β Error fetching run details: {str(e)}"
def get_task_recommendations(
    task: str = "text-classification", dataset_size: str = "medium"
) -> str:
    """
    Get training recommendations for a specific task type.

    Args:
        task: Task type (text-classification, llm-sft, image-classification)
        dataset_size: Size of dataset (small, medium, large); only echoed in
            the heading

    Returns:
        Recommended models, parameters, backends, and best practices; a
        generic fallback is used for any other task
    """
    catalog = {
        "text-classification": {
            "models": ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"],
            "params": {"batch_size": 16, "learning_rate": 2e-5, "epochs": 3},
            "backends": ["local", "spaces-t4-small"],
            "notes": [
                "Good for sentiment analysis",
                "Works well with IMDB, AG News datasets",
            ],
        },
        "llm-sft": {
            "models": [
                "microsoft/DialoGPT-medium",
                "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            ],
            "params": {"batch_size": 1, "learning_rate": 1e-5, "epochs": 3},
            "backends": ["spaces-t4-medium", "spaces-a10g-large"],
            "notes": ["Use PEFT for efficiency", "Ensure proper chat formatting"],
        },
        "image-classification": {
            "models": ["google/vit-base-patch16-224", "microsoft/resnet-50"],
            "params": {"batch_size": 32, "learning_rate": 2e-5, "epochs": 5},
            "backends": ["local", "spaces-t4-small"],
            "notes": ["Ensure images are preprocessed", "Works with CIFAR, ImageNet"],
        },
    }
    fallback = {
        "models": [],
        "params": {},
        "backends": ["local"],
        "notes": ["No specific recommendations available"],
    }
    chosen = catalog[task] if task in catalog else fallback

    def bullets(items):
        # One bullet per line; an empty sequence yields an empty line.
        return "\n".join(f"β’ {item}" for item in items)

    param_items = [f"{key}: {value}" for key, value in chosen["params"].items()]
    sections = [
        f"π― Training Recommendations for {task.title()} ({dataset_size} dataset)",
        "**Recommended Models:**",
        bullets(chosen["models"]),
        "**Recommended Parameters:**",
        bullets(param_items),
        "**Backend Suggestions:**",
        bullets(chosen["backends"]),
        "**Best Practices:**",
        bullets(chosen["notes"]),
    ]
    return "\n".join(sections)
def get_system_status(random_string: str = "") -> str:
    """
    Get AutoTrain system status.

    Args:
        random_string: Unused placeholder argument

    Returns:
        Markdown report with run statistics, access points, and W&B
        integration status
    """
    try:
        stored_runs = load_runs()

        def tally(wanted: str) -> int:
            # Number of stored runs currently in the given status.
            return sum(1 for entry in stored_runs if entry.get("status") == wanted)

        total_runs = len(stored_runs)
        running_runs = tally("running")
        completed_runs = tally("completed")
        failed_runs = tally("failed")

        # Both W&B rows depend on the same API-key check.
        if os.environ.get("WANDB_API_KEY"):
            wandb_api_status = "β Configured"
            wandb_metrics_status = "β Enabled"
        else:
            wandb_api_status = "β Missing"
            wandb_metrics_status = "β System metrics only"

        report_lines = [
            "## βοΈ System Status",
            "### π Run Statistics",
            "| Metric | Count |",
            "|--------|-------|",
            "| **Server Status** | β Running |",
            f"| **Total Runs** | {total_runs} |",
            f"| **Active Runs** | {running_runs} |",
            f"| **Completed Runs** | {completed_runs} |",
            f"| **Failed Runs** | {failed_runs} |",
            "### π‘ Access Points",
            "| Service | URL |",
            "|---------|-----|",
            "| **Gradio UI** | http://SPACE_URL |",
            "| **MCP Server** | http://SPACE_URL/gradio_api/mcp/sse |",
            "| **MCP Schema** | http://SPACE_URL/gradio_api/mcp/schema |",
            "### π οΈ W&B Integration",
            "| Component | Status |",
            "|-----------|--------|",
            f"| **Project** | {WANDB_PROJECT} |",
            f"| **API Key** | {wandb_api_status} |",
            f"| **Training Metrics** | {wandb_metrics_status} |",
            "π‘ **Note:** Set WANDB_API_KEY for complete training metrics logging",
        ]
        return "\n".join(report_lines)

    except Exception as e:
        return f"β Error getting system status: {str(e)}"
def refresh_data(random_string: str = "") -> str:
    """Return the fixed acknowledgement used for UI refreshes.

    Args:
        random_string: Unused placeholder argument
    """
    acknowledgement = "Data refreshed successfully"
    return acknowledgement
def load_initial_data(random_string: str = "") -> str:
    """Return the fixed acknowledgement used for the initial data load.

    Args:
        random_string: Unused placeholder argument
    """
    acknowledgement = "Initial data loaded successfully"
    return acknowledgement
# Web UI Functions | |
def fetch_runs_for_ui():
    """Build the DataFrame shown in the runs table of the web interface.

    Returns:
        A DataFrame with one row per stored run, newest first. With no runs
        it has the same columns and no rows; on any failure it has a single
        ``Error`` column describing the problem.
    """
    try:
        stored_runs = load_runs()

        if not stored_runs:
            column_names = (
                "Status",
                "W&B Link",
                "Project",
                "Task",
                "Model",
                "Created",
                "Run ID",
            )
            return pd.DataFrame({name: [] for name in column_names})

        rows = [
            {
                "Status": f"{get_status_emoji(entry['status'])} {entry['status'].title()}",
                # Markdown link, left blank until a W&B URL is recorded.
                "W&B Link": (
                    f"[π W&B Run]({entry['wandb_url']})"
                    if entry.get("wandb_url")
                    else ""
                ),
                "Project": entry["project_name"],
                "Task": entry["task"].replace("-", " ").title(),
                "Model": entry["base_model"],
                # Keep "YYYY-MM-DD HH:MM" from the ISO timestamp.
                "Created": entry["created_at"][:16].replace("T", " "),
                "Run ID": entry["run_id"][:8] + "...",
            }
            for entry in reversed(stored_runs)  # Newest first
        ]
        return pd.DataFrame(rows)

    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to fetch runs: {str(e)}"]})
def submit_training_job_ui(
    task,
    project_name,
    base_model,
    dataset_path,
    epochs,
    batch_size,
    learning_rate,
    backend,
    push_to_hub,
    hub_repo_id,
):
    """Submit a training job from the web UI.

    Returns:
        A ``(message, runs_dataframe)`` tuple: the submission result (or a
        validation message when a required field is empty) and the refreshed
        runs table.
    """
    required_fields = (task, project_name, base_model, dataset_path)
    for field_value in required_fields:
        if not field_value:
            return "β Please fill in all required fields", fetch_runs_for_ui()

    # start_training_job takes every value as a string.
    outcome = start_training_job(
        task=task,
        project_name=project_name,
        base_model=base_model,
        dataset_path=dataset_path,
        epochs=str(epochs),
        batch_size=str(batch_size),
        learning_rate=str(learning_rate),
        backend=backend,
        push_to_hub=str(push_to_hub).lower(),
        hub_repo_id=hub_repo_id,
    )
    return outcome, fetch_runs_for_ui()
# Create Gradio Interface | |
# Builds the whole Gradio application: landing text, a runs dashboard, a
# system-status tab, and one gr.Interface per MCP tool function.
with gr.Blocks(
    title="AutoTrain Gradio MCP Server",
    theme=gr.themes.Soft(),
    css="""
.gradio-container {
max-width: 1200px !important;
}
""",
) as app:
    # Landing text with MCP connection snippets. SPACE_URL is a placeholder
    # the reader is expected to replace with the address of their Space.
    gr.Markdown("""
# π AutoTrain MCP Server
Get your AI models to train your AI models!
This space is an MCP server that you can use in Claude Desktop, Cursor, VSCode, etc to train your AI models.
:warning: To train models you will need to duplicate this space!
**MCP Server**: AI assistants can use tools at http://SPACE_URL/gradio_api/mcp/sse
Connect to it like this:
```javascript
{
"mcpServers": {
"autotrain": {
"url": "http://SPACE_URL/gradio_api/mcp/sse",
"headers": {"Authorization": "Bearer <YOUR-HUGGING-FACE-TOKEN>"}
}
}
}
```
Or like this for Claude Desktop:
```javascript
{
"mcpServers": {
"autotrain": {
"command": "npx",
"args": [
"mcp-remote",
"http://SPACE_URL/gradio_api/mcp/sse",
"--header",
"Authorization: Bearer <YOUR-HUGGING-FACE-TOKEN>"
]
}
}
}
```
""")

    with gr.Tabs():
        # Dashboard Tab
        with gr.Tab("π Training Runs"):
            with gr.Row():
                # datatype="markdown" so the W&B link column renders as links.
                runs_table = gr.Dataframe(
                    value=fetch_runs_for_ui(), interactive=False, datatype="markdown"
                )
            with gr.Row():
                refresh_btn = gr.Button("π Refresh", variant="secondary")

        with gr.Tab("π§ System Status"):
            stats = gr.Markdown(value=get_system_status())

        # MCP Tools Tab
        with gr.Tab("π§ MCP Tools"):
            gr.Markdown("## MCP Tool Testing Interface")
            gr.Markdown("These tools are exposed via MCP for Claude Desktop")

            gr.Interface(
                fn=get_system_status,
                inputs=[],
                outputs=gr.Textbox(label="System Status"),
                title="get_system_status",
                description="Get AutoTrain system status and capabilities",
            )

            gr.Interface(
                fn=get_training_runs,
                inputs=[
                    gr.Textbox(label="limit", value="20"),
                    gr.Textbox(label="status", value=""),
                ],
                outputs=gr.Textbox(label="Training Runs"),
                title="get_training_runs",
                description="Get list of training runs with status",
            )

            gr.Interface(
                fn=start_training_job,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="project_name", value="test-project"),
                    gr.Textbox(label="base_model", value="distilbert-base-uncased"),
                    gr.Textbox(label="dataset_path", value="imdb"),
                    gr.Textbox(label="epochs", value="1"),
                    gr.Textbox(label="batch_size", value="8"),
                    gr.Textbox(label="learning_rate", value="2e-5"),
                    gr.Textbox(label="backend", value="local"),
                    gr.Textbox(label="push_to_hub", value="false"),
                    gr.Textbox(label="hub_repo_id", placeholder="your-repo-id"),
                ],
                outputs=gr.Textbox(label="Training Job Result"),
                title="start_training_job",
                description="Start a new AutoTrain training job",
            )

            gr.Interface(
                fn=get_run_details,
                inputs=gr.Textbox(
                    label="run_id", placeholder="Enter run ID or first 8 chars"
                ),
                outputs=gr.Textbox(label="Run Details"),
                title="get_run_details",
                description="Get detailed information about a training run",
            )

            gr.Interface(
                fn=get_task_recommendations,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="dataset_size", value="medium"),
                ],
                outputs=gr.Textbox(label="Recommendations"),
                title="get_task_recommendations",
                description="Get training recommendations for a task",
            )

    # Event handlers with proper function names (not lambda). The two
    # handlers are intentionally separate named functions even though they
    # return the same pair of values.
    def refresh_ui_data():
        """Return fresh values for the runs table and the status panel."""
        return fetch_runs_for_ui(), get_system_status()

    def load_initial_ui_data():
        """Return the values shown when the page first loads."""
        return fetch_runs_for_ui(), get_system_status()

    refresh_btn.click(
        fn=refresh_ui_data,
        outputs=[runs_table, stats],
    )

    # Load initial data
    app.load(
        fn=load_initial_ui_data,
        outputs=[runs_table, stats],
    )
# Helper to find an available port | |
def _find_available_port(start_port: int = 7860, max_tries: int = 20) -> int: | |
"""Return the first available port starting from `start_port`.""" | |
port = start_port | |
for _ in range(max_tries): | |
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |
try: | |
s.bind(("0.0.0.0", port)) | |
return port # Port is free | |
except OSError: | |
port += 1 # Try next port | |
# If no port found, let OS pick one | |
return 0 | |
if __name__ == "__main__":
    # Preferred port comes from GRADIO_SERVER_PORT (7860 when unset); the
    # actual port is the first free one at or above it.
    requested_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    try:
        launch_port = _find_available_port(requested_port)
    except Exception:
        # Port 0 is the fallback value when probing fails for any reason.
        launch_port = 0

    app.launch(
        server_name="0.0.0.0",
        server_port=launch_port,
        mcp_server=True,  # Enable MCP server functionality
    )