"""FastAPI service exposing text summarization and AI-text detection endpoints.

Both endpoints are backed by Hugging Face ``transformers`` pipelines that are
created lazily on first use and then kept in a module-level cache.
"""

import logging
from functools import lru_cache
from typing import Any, Dict, Optional

import torch  # kept at file level; transformers uses it as the pipeline backend
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS - adjust the origins based on your needs.
# NOTE(review): a wildcard origin combined with allow_credentials=True is not
# a valid credentialed-CORS configuration; list explicit origins (or disable
# credentials) before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model cache: model type name -> initialized pipeline object.
MODEL_CACHE: Dict[str, Any] = {}

# Supported model types mapped to (pipeline task, Hugging Face model id).
_MODEL_SPECS = {
    "summarizer": ("summarization", "facebook/bart-large-cnn"),
    "detector": ("text-classification", "roberta-base-openai-detector"),
}

# Label names emitted by the detector pipeline, compared case-insensitively.
# NOTE(review): "Fake"/"Real" are the names on the detector's model card and
# LABEL_0/LABEL_1 the generic fallback (index 0 = fake) - confirm against the
# deployed model's config.
_AI_LABELS = frozenset({"fake", "label_0"})
_HUMAN_LABELS = frozenset({"real", "label_1"})


@lru_cache()
def get_model(model_type: str):
    """Get or initialize model with caching.

    Args:
        model_type: Either ``"summarizer"`` or ``"detector"``. Any key already
            present in ``MODEL_CACHE`` is returned as-is.

    Returns:
        The cached pipeline object for ``model_type``.

    Raises:
        ValueError: If ``model_type`` is neither cached nor a supported type.
        RuntimeError: If the underlying pipeline fails to initialize.
    """
    if model_type in MODEL_CACHE:
        return MODEL_CACHE[model_type]

    spec = _MODEL_SPECS.get(model_type)
    if spec is None:
        # Previously this fell through, logged a bogus "Successfully
        # initialized" message and died with a bare KeyError.
        raise ValueError(f"Unknown model type: {model_type!r}")

    task, model_name = spec
    logger.info(f"Initializing {model_type} model...")
    try:
        MODEL_CACHE[model_type] = pipeline(task, model=model_name, device="cpu")
    except Exception as e:
        logger.exception("Error initializing %s model: %s", model_type, e)
        raise RuntimeError(f"Failed to initialize {model_type} model") from e
    logger.info(f"Successfully initialized {model_type} model")
    return MODEL_CACHE[model_type]


class TextRequest(BaseModel):
    """Request body shared by the summarize and detect endpoints."""

    # Text to process.
    text: str
    # Length bounds forwarded to the summarizer; not read by the detector.
    max_length: Optional[int] = 130
    min_length: Optional[int] = 30


def validate_text(text: str, min_words: int = 10) -> bool:
    """Return True if ``text`` has at least ``min_words`` whitespace-separated words."""
    return len(text.split()) >= min_words


def _ai_probability(result: Dict[str, Any]) -> float:
    """Convert a classifier result into the probability that the text is AI-written.

    The pipeline reports only the top label and its probability, so when the
    top label is the human class the AI probability is the complement.

    Raises:
        ValueError: If the label is not one of the known detector labels.
    """
    label = str(result["label"]).lower()
    top_score = float(result["score"])
    if label in _AI_LABELS:
        return top_score
    if label in _HUMAN_LABELS:
        return 1.0 - top_score
    raise ValueError(f"Unexpected detector label: {result['label']!r}")


def _confidence_band(score: float) -> str:
    """Bucket a 0-100 score by its distance from the undecided midpoint (50)."""
    distance = abs(score - 50)
    if distance > 25:
        return "high"
    if distance > 10:
        return "medium"
    return "low"


@app.get("/")
async def root():
    """Health check endpoint."""
    return {"status": "healthy", "message": "API is running"}


@app.post("/api/summarize")
async def summarize_text(request: TextRequest):
    """Endpoint to summarize text.

    Returns ``{"summary": <str>}``. Responds with 400 when the text has fewer
    than 10 words and with 500 on any unexpected failure.
    """
    try:
        if not validate_text(request.text):
            raise HTTPException(
                status_code=400,
                detail="Text is too short to summarize (minimum 10 words required)"
            )

        summarizer = get_model("summarizer")
        summary = summarizer(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
            do_sample=False
        )
        return {"summary": summary[0]["summary_text"]}
    except HTTPException:
        # Deliberate client errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception keeps the traceback; the client sees a generic 500.
        logger.exception("Error in summarization: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred during summarization"
        )


@app.post("/api/detect")
async def detect_ai(request: TextRequest):
    """Endpoint to detect if text is AI-generated.

    Returns ``{"score", "likely_ai", "confidence"}`` where ``score`` is the
    AI probability as a percentage (0-100). Responds with 400 when the text
    has fewer than 5 words and with 500 on any unexpected failure.
    """
    try:
        if not validate_text(request.text, min_words=5):
            raise HTTPException(
                status_code=400,
                detail="Text is too short for AI detection (minimum 5 words required)"
            )

        detector = get_model("detector")
        result = detector(request.text)[0]

        # The pipeline score is already a probability for the predicted
        # label; passing it through a sigmoid (as before) squeezed every
        # result into ~62-73% and ignored which label was predicted.
        score = _ai_probability(result) * 100

        return {
            "score": round(score, 2),
            "likely_ai": score > 70,
            "confidence": _confidence_band(score)
        }
    except HTTPException:
        # Deliberate client errors pass through untouched.
        raise
    except Exception as e:
        logger.exception("Error in AI detection: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred during AI detection"
        )