# LLMTrainingPro / app.py — source: Vishwas1's Hugging Face Space (commit 2de0e9b, verified)
# training_space/app.py (FastAPI Backend)
import logging
import os
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import HfApi, HfFolder
from pydantic import BaseModel
app = FastAPI()
# Configure Logging
logging.basicConfig(
filename='training.log',
filemode='a',
format='%(asctime)s - %(levelname)s - %(message)s',
level=logging.INFO
)
# CORS Configuration
origins = [
"https://Vishwas1-LLMBuilderPro.hf.space", # Replace with your Gradio frontend Space URL
"http://localhost", # For local testing
"https://web.postman.co",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Define the expected payload structure
class TrainingRequest(BaseModel):
task: str # 'generation' or 'classification'
model_params: dict
model_name: str
dataset_name: str # The name of the existing Hugging Face dataset
# Root Endpoint
@app.get("/")
def read_root():
return {
"message": "Welcome to the Training Space API!",
"instructions": "To train a model, send a POST request to /train with the required parameters."
}
# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
try:
logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")
# Create a unique directory for this training session
session_id = str(uuid.uuid4())
session_dir = f"./training_sessions/{session_id}"
os.makedirs(session_dir, exist_ok=True)
# No need to save dataset content; use dataset_name directly
dataset_name = request.dataset_name
# Define the absolute path to train_model.py
TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")
# Prepare the command to run the training script with dataset_name
cmd = [
"python", TRAIN_MODEL_PATH,
"--task", request.task,
"--model_name", request.model_name,
"--dataset_name", dataset_name, # Pass dataset_name instead of dataset file path
"--num_layers", str(request.model_params.get('num_layers', 12)),
"--attention_heads", str(request.model_params.get('attention_heads', 1)),
"--hidden_size", str(request.model_params.get('hidden_size', 64)),
"--vocab_size", str(request.model_params.get('vocab_size', 30000)),
"--sequence_length", str(request.model_params.get('sequence_length', 512))
]
# Start the training process as a background task in the root directory
subprocess.Popen(cmd, cwd=os.path.dirname(__file__))
logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
return {"status": "Training started", "session_id": session_id}
except Exception as e:
logging.error(f"Error during training request: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
except Exception as e:
logging.error(f"Error during training request: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
session_dir = f"./training_sessions/{session_id}"
log_file = os.path.join(session_dir, "training.log")
if not os.path.exists(log_file):
raise HTTPException(status_code=404, detail="Session ID not found.")
with open(log_file, "r", encoding="utf-8") as f:
logs = f.read()
return {"session_id": session_id, "logs": logs}