# training_space/app.py (FastAPI Backend)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import subprocess
import sys
import os
import uuid
import logging

app = FastAPI()
# Configure Logging
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
# CORS Configuration
origins = [
    "https://Vishwas1-LLMBuilderPro.hf.space",  # Replace with your Gradio frontend Space URL
    "http://localhost",                          # For local testing
    "https://web.postman.co",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the expected payload structure
class TrainingRequest(BaseModel):
    task: str          # 'generation' or 'classification'
    model_params: dict
    model_name: str
    dataset_name: str  # The name of an existing Hugging Face dataset
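# For reference, a request body matching this model might look like the JSON
# below. The model_name and dataset_name values are purely illustrative; the
# model_params values mirror the fallback defaults used in /train further down.
# {
#     "task": "generation",
#     "model_name": "my-tiny-gpt",
#     "dataset_name": "imdb",
#     "model_params": {
#         "num_layers": 12,
#         "attention_heads": 1,
#         "hidden_size": 64,
#         "vocab_size": 30000,
#         "sequence_length": 512
#     }
# }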
# Root Endpoint
@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": "To train a model, send a POST request to /train with the required parameters."
    }
# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)
        # No need to save dataset content; pass the Hugging Face dataset name straight through
        dataset_name = request.dataset_name
        # Define the absolute path to train_model.py
        TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")
        # Build the training command; model_params falls back to defaults for any missing key
        cmd = [
            sys.executable, TRAIN_MODEL_PATH,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset_name", dataset_name,  # Pass the dataset name rather than a file path
            "--num_layers", str(request.model_params.get('num_layers', 12)),
            "--attention_heads", str(request.model_params.get('attention_heads', 1)),
            "--hidden_size", str(request.model_params.get('hidden_size', 64)),
            "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
            "--sequence_length", str(request.model_params.get('sequence_length', 512))
        ]
        # Launch training as a background process, capturing its output in a
        # per-session log file so the /status endpoint can read it back later
        log_path = os.path.join(session_dir, "training.log")
        with open(log_path, "a", encoding="utf-8") as log_file:
            subprocess.Popen(cmd, cwd=os.path.dirname(__file__),
                             stdout=log_file, stderr=subprocess.STDOUT)
        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        logging.error(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
    session_dir = f"./training_sessions/{session_id}"
    log_file = os.path.join(session_dir, "training.log")
    if not os.path.exists(log_file):
        raise HTTPException(status_code=404, detail="Session ID not found.")
    with open(log_file, "r", encoding="utf-8") as f:
        logs = f.read()
    return {"session_id": session_id, "logs": logs}