# training_space/app.py (FastAPI Backend)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import subprocess
import os
import uuid
from huggingface_hub import HfApi, HfFolder
from fastapi.middleware.cors import CORSMiddleware
import logging
app = FastAPI()
# Configure Logging
logging.basicConfig(
    filename='training.log',
    filemode='a',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
# CORS Configuration
origins = [
"https://Vishwas1-LLMBuilderPro.hf.space", # Replace with your Gradio frontend Space URL
"http://localhost", # For local testing
"https://web.postman.co",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the expected payload structure
class TrainingRequest(BaseModel):
    task: str  # 'generation' or 'classification'
    model_params: dict
    model_name: str
    dataset_name: str  # The name of the existing Hugging Face dataset
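# Illustrative request body for /train (a sketch only -- model_name and
# dataset_name are placeholders, and the model_params keys mirror the
# defaults read in the /train handler below):
# {
#     "task": "generation",
#     "model_name": "my-small-transformer",
#     "dataset_name": "username/my-dataset",
#     "model_params": {
#         "num_layers": 12,
#         "attention_heads": 1,
#         "hidden_size": 64,
#         "vocab_size": 30000,
#         "sequence_length": 512
#     }
# }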
# Root Endpoint
@app.get("/")
def read_root():
    return {
        "message": "Welcome to the Training Space API!",
        "instructions": "To train a model, send a POST request to /train with the required parameters."
    }
# Train Endpoint
@app.post("/train")
def train_model(request: TrainingRequest):
    try:
        logging.info(f"Received training request for model: {request.model_name}, Task: {request.task}")

        # Create a unique directory for this training session
        session_id = str(uuid.uuid4())
        session_dir = f"./training_sessions/{session_id}"
        os.makedirs(session_dir, exist_ok=True)

        # No need to save dataset content; use dataset_name directly
        dataset_name = request.dataset_name

        # Define the absolute path to train_model.py
        TRAIN_MODEL_PATH = os.path.join(os.path.dirname(__file__), "train_model.py")

        # Prepare the command to run the training script with dataset_name
        cmd = [
            "python", TRAIN_MODEL_PATH,
            "--task", request.task,
            "--model_name", request.model_name,
            "--dataset_name", dataset_name,  # Pass dataset_name instead of a dataset file path
            "--num_layers", str(request.model_params.get('num_layers', 12)),
            "--attention_heads", str(request.model_params.get('attention_heads', 1)),
            "--hidden_size", str(request.model_params.get('hidden_size', 64)),
            "--vocab_size", str(request.model_params.get('vocab_size', 30000)),
            "--sequence_length", str(request.model_params.get('sequence_length', 512))
        ]

        # Start the training script as a detached background process in the app directory.
        # Popen returns immediately; the request does not wait for training to finish.
        subprocess.Popen(cmd, cwd=os.path.dirname(__file__))

        logging.info(f"Training started for model: {request.model_name}, Session ID: {session_id}")
        return {"status": "Training started", "session_id": session_id}
    except Exception as e:
        logging.error(f"Error during training request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
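# Example client call (an illustrative sketch, not part of the deployed app --
# it assumes the API is reachable at http://localhost:8000, e.g. when run
# locally with `uvicorn app:app --port 8000`), using a body like the example
# sketched above the TrainingRequest model:
#
#     import requests
#     payload = {"task": "generation", "model_name": "my-small-transformer",
#                "dataset_name": "username/my-dataset", "model_params": {}}
#     resp = requests.post("http://localhost:8000/train", json=payload)
#     print(resp.json())  # e.g. {"status": "Training started", "session_id": "<uuid>"}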
# Optional: Status Endpoint
@app.get("/status/{session_id}")
def get_status(session_id: str):
session_dir = f"./training_sessions/{session_id}"
log_file = os.path.join(session_dir, "training.log")
if not os.path.exists(log_file):
raise HTTPException(status_code=404, detail="Session ID not found.")
with open(log_file, "r", encoding="utf-8") as f:
logs = f.read()
return {"session_id": session_id, "logs": logs}