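# NOTE: the text "Spaces: / Sleeping / Sleeping" that appeared at the top of this
# file looks like page-header residue from the hosting site's UI; it is not part
# of this module.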
import asyncio
import json
import logging
import os
import shutil
import sys
import time
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, Dict, List, Union, Any

import jwt
from bson.objectid import ObjectId
from decouple import config
from fastapi import FastAPI, Request, HTTPException, Depends, status, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse, RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from jwt.exceptions import PyJWTError
from motor.motor_asyncio import AsyncIOMotorClient
from passlib.context import CryptContext
from pydantic import BaseModel

# Import the original app modules
from app.models import *
from app.agents.podcast_manager import PodcastManager
from app.agents.researcher import Researcher
from app.agents.debate_agent import DebateAgent
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="PodCraft API")

# Allowed CORS origins, read from the CORS_ORIGINS environment variable as a
# comma-separated list (e.g. "https://a.example,https://b.example"). The default
# "*" keeps the previous allow-any-origin behaviour; blank entries are ignored.
# NOTE(review): the FastAPI/Starlette CORS docs state that "*" cannot be combined
# with allow_credentials=True for credentialed requests; set CORS_ORIGINS to
# explicit origins for production deployments that rely on credentials.
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# MongoDB connection string: an explicit MONGODB_URL environment variable wins,
# otherwise decouple's config lookup is used, falling back to a local server.
MONGODB_URL = os.getenv(
    "MONGODB_URL",
    config("MONGODB_URL", default="mongodb://localhost:27017"),
)

# One Motor client for the module, plus handles to the "podcraft" database
# and its collections.
client = AsyncIOMotorClient(MONGODB_URL)
db = client.podcraft
users, podcasts, agents, workflows = (
    db.users,
    db.podcasts,
    db.agents,
    db.workflows,
)

# Agent objects, constructed once when this module is imported.
podcast_manager = PodcastManager()
researcher = Researcher()
debate_agent = DebateAgent()
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT settings
# Placeholder used only when no SECRET_KEY is configured. Anyone who knows this
# value can forge tokens, so a warning is logged at import time when it is in use.
_INSECURE_DEFAULT_SECRET_KEY = "your-secret-key"
SECRET_KEY = os.getenv(
    "SECRET_KEY", config("SECRET_KEY", default=_INSECURE_DEFAULT_SECRET_KEY)
)
if SECRET_KEY == _INSECURE_DEFAULT_SECRET_KEY:
    logging.getLogger(__name__).warning(
        "SECRET_KEY is not configured; JWTs are signed with the insecure "
        "built-in default. Set SECRET_KEY before deploying."
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 1 week

# OAuth2 bearer scheme; tokenUrl is the relative URL clients use to obtain a token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Keep the original authentication and API routes | |
# Include all the functions and routes from the original main.py | |
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a JWT signed with ``SECRET_KEY``/``ALGORITHM`` carrying ``data`` plus ``exp``.

    Args:
        data: Claims to embed (callers in this module pass ``{"sub": username}``).
            The dict is copied, so the caller's object is not mutated.
        expires_delta: Token lifetime. Defaults to ``ACCESS_TOKEN_EXPIRE_MINUTES``
            minutes, which was the only behaviour before this parameter existed.

    Returns:
        Whatever ``jwt.encode`` returns for the claims.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_password(plain_password, hashed_password):
    """Return whether ``plain_password`` matches the stored ``hashed_password``.

    Delegates to the module-level passlib ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash ``password`` with the module-level passlib ``pwd_context`` and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve a bearer token to the stored user document.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``. Its ``sub`` claim is
    used as the username to look up in the ``users`` collection.

    Returns:
        The user document, with ``_id`` converted to ``str``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            fails to decode, has no ``sub`` claim, or names an unknown user.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except PyJWTError:
        raise unauthorized

    username: str = claims.get("sub")
    if username is None:
        raise unauthorized

    account = await users.find_one({"username": username})
    if account is None:
        raise unauthorized

    # ObjectId is not JSON-serializable, so expose the id as a plain string.
    account["_id"] = str(account["_id"])
    return account
# Route for health check | |
async def health():
    """Return a static health payload with the service status and API version.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered.
    """
    return dict(status="healthy", version="1.0.0")
# API routes | |
async def signup(user: UserCreate):
    """Register a new user and return an access token for them.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered.

    Args:
        user: Sign-up payload; its ``username`` and ``password`` fields are read.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 400 if the username is already registered.
    """
    # NOTE(review): this check-then-insert is not atomic. Two concurrent sign-ups
    # with the same username can both pass it unless a unique index on
    # users.username exists; confirm one is configured.
    if await users.find_one({"username": user.username}):
        raise HTTPException(status_code=400, detail="Username already registered")

    # Only the hash produced by get_password_hash is stored, never the plain password.
    hashed_password = get_password_hash(user.password)
    await users.insert_one({"username": user.username, "password": hashed_password})

    access_token = create_access_token(data={"sub": user.username})
    return {"access_token": access_token, "token_type": "bearer"}
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange OAuth2 password-form credentials for a bearer access token.

    NOTE(review): no route decorator is visible for this handler in this file.
    It is presumably the endpoint behind ``tokenUrl="token"``; confirm.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the username
            is unknown or the password does not match.
    """
    stored_user = await users.find_one({"username": form_data.username})
    # Short-circuits, so the password is only checked when a user was found.
    credentials_ok = bool(stored_user) and verify_password(
        form_data.password, stored_user["password"]
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(data={"sub": form_data.username})
    return {"access_token": token, "token_type": "bearer"}
async def login(request: Request, user: UserLogin):
    """Authenticate a ``UserLogin`` payload and return a bearer access token.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered. ``request`` is accepted but not used.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the username
            is unknown or the password does not match.
    """
    db_user = await users.find_one({"username": user.username})
    if not db_user or not verify_password(user.password, db_user["password"]):
        # RFC 7235 requires a WWW-Authenticate header on 401 responses. This also
        # matches the error raised by login_for_access_token and get_current_user.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(data={"sub": user.username})
    return {"access_token": access_token, "token_type": "bearer"}
# Add all the other API routes from the original main.py | |
# ... | |
# Determine static files path: "../static" relative to the directory containing this file.
static_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "static"))
if not os.path.exists(static_path):
    # Best effort: create the directory with a placeholder page.
    try:
        os.makedirs(static_path, exist_ok=True)
        with open(os.path.join(static_path, "index.html"), "w", encoding="utf-8") as f:
            f.write("<html><body><h1>PodCraft API</h1><p>Frontend not found.</p></body></html>")
    except Exception as e:
        logger.exception("Error creating static path: %s", e)

# Mount static files for frontend. StaticFiles raises at construction time when
# the directory does not exist, so mount only when it is actually present. That
# keeps the best-effort setup above from crashing the app on import.
if os.path.isdir(static_path):
    app.mount("/static", StaticFiles(directory=static_path), name="static")
else:
    logger.warning("Static directory %s is unavailable; /static will not be served", static_path)
# Add route to serve audio files | |
async def serve_audio(path: str):
    """Return a ``FileResponse`` for ``path`` inside the ``/app/temp_audio`` directory.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered.

    Args:
        path: Client-supplied file path, interpreted relative to the audio directory.

    Raises:
        HTTPException: 404 when ``path`` does not name a regular file inside the
            audio directory. This includes attempts to escape it via ``..``
            segments, absolute paths, or symlinks.
    """
    audio_dir = "/app/temp_audio"
    if not os.path.exists(audio_dir):
        try:
            os.makedirs(audio_dir, exist_ok=True)
        except Exception as e:
            logger.exception("Error creating audio directory: %s", e)

    # ``path`` is untrusted. A bare os.path.join would follow "../" out of
    # audio_dir, and would discard audio_dir entirely for an absolute ``path``.
    # Resolve both sides and require the file to stay under the audio root.
    audio_root = os.path.realpath(audio_dir)
    try:
        audio_file = os.path.realpath(os.path.join(audio_root, path))
        inside_root = os.path.commonpath([audio_root, audio_file]) == audio_root
    except ValueError:
        # e.g. an embedded NUL byte in ``path``; treat it as not found.
        inside_root = False

    if not inside_root or not os.path.isfile(audio_file):
        raise HTTPException(status_code=404, detail="Audio file not found")
    return FileResponse(audio_file)
# Root path to serve frontend or redirect to static | |
# Inline placeholder page returned when static/index.html is missing.
_FALLBACK_INDEX_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>PodCraft API</title>
<style>
body { font-family: Arial, sans-serif; padding: 20px; }
h1 { color: #8b5cf6; }
</style>
</head>
<body>
<h1>PodCraft API</h1>
<p>The frontend application is not available. Please build the frontend or check the documentation.</p>
</body>
</html>
"""


async def serve_frontend():
    """Return the frontend's ``index.html`` from ``static_path`` as an ``HTMLResponse``.

    When that file is missing, an inline placeholder page is returned instead.
    Nothing is written to disk.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered.
    """
    index_file = os.path.join(static_path, "index.html")
    if os.path.isfile(index_file):
        # Explicit UTF-8: the platform default encoding may differ and would
        # mis-decode non-ASCII content in a built frontend.
        with open(index_file, "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    return HTMLResponse(content=_FALLBACK_INDEX_HTML)
# Catch-all route for SPA routing - all non-API routes should serve the frontend | |
async def catch_all(full_path: str):
    """Serve the SPA for an unmatched path, except paths under ``api/``.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered.

    Raises:
        HTTPException: 404 when ``full_path`` starts with ``api/``.
    """
    is_api_path = full_path.startswith("api/")
    if not is_api_path:
        return await serve_frontend()
    # Unknown API endpoints must 404 rather than fall through to the SPA.
    raise HTTPException(status_code=404, detail="API endpoint not found")
if __name__ == "__main__":
    # Development entry point: serve "main:app" on all interfaces, port 8000,
    # with auto-reload. NOTE(review): assumes this module is importable as `main`.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)