from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Any, Dict, Optional
from transformers import pipeline
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS - adjust the origins based on your needs
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
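
# A locked-down production setup would pin the allowed origins instead of
# using the wildcard; the domain below is a placeholder, not a real value:
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://your-frontend.example.com"],
#     allow_credentials=True,
#     allow_methods=["POST", "GET"],
#     allow_headers=["*"],
# )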

# Model cache keyed by model type ("summarizer" / "detector")
MODEL_CACHE: Dict[str, Any] = {}

def get_model(model_type: str):
    """Get or initialize model with caching."""
    if model_type not in MODEL_CACHE:
        logger.info(f"Initializing {model_type} model...")
        try:
            if model_type == "summarizer":
                MODEL_CACHE[model_type] = pipeline(
                    "summarization",
                    model="facebook/bart-large-cnn",
                    device="cpu"
                )
            elif model_type == "detector":
                MODEL_CACHE[model_type] = pipeline(
                    "text-classification",
                    model="roberta-base-openai-detector",
                    device="cpu"
                )
            else:
                raise ValueError(f"Unknown model type: {model_type}")
            logger.info(f"Successfully initialized {model_type} model")
        except Exception as e:
            logger.error(f"Error initializing {model_type} model: {str(e)}")
            raise RuntimeError(f"Failed to initialize {model_type} model") from e
    return MODEL_CACHE[model_type]
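
# Optional warm-up sketch: preloading both pipelines at startup avoids the
# multi-second model download/load cost on the first request. Drop this hook
# if slower cold starts are acceptable.
@app.on_event("startup")
async def warm_up_models():
    """Preload models so the first request doesn't pay the load cost."""
    for model_type in ("summarizer", "detector"):
        get_model(model_type)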

class TextRequest(BaseModel):
    text: str
    max_length: Optional[int] = 130
    min_length: Optional[int] = 30

def validate_text(text: str, min_words: int = 10) -> bool:
    """Validate text input."""
    return len(text.split()) >= min_words

@app.get("/")
async def root():
    """Health check endpoint."""
    return {"status": "healthy", "message": "API is running"}

@app.post("/summarize")  # route path assumed; adjust to match your client
async def summarize_text(request: TextRequest):
    """Endpoint to summarize text."""
    try:
        if not validate_text(request.text):
            raise HTTPException(
                status_code=400,
                detail="Text is too short to summarize (minimum 10 words required)"
            )
        summarizer = get_model("summarizer")
        summary = summarizer(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
            do_sample=False
        )
        return {"summary": summary[0]["summary_text"]}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during summarization"
        )

@app.post("/detect")  # route path assumed; adjust to match your client
async def detect_ai(request: TextRequest):
    """Endpoint to detect if text is AI-generated."""
    try:
        if not validate_text(request.text, min_words=5):
            raise HTTPException(
                status_code=400,
                detail="Text is too short for AI detection (minimum 5 words required)"
            )
        detector = get_model("detector")
        result = detector(request.text)[0]
        # The pipeline score is already a softmax probability for the
        # predicted label, so no extra sigmoid is applied. Label names
        # ("Fake" / "Real") follow the roberta-base-openai-detector card.
        prob_ai = result["score"] if result["label"] == "Fake" else 1 - result["score"]
        score = prob_ai * 100
        confidence = (
            "high" if abs(score - 50) > 25
            else "medium" if abs(score - 50) > 10
            else "low"
        )
        return {
            "score": round(score, 2),
            "likely_ai": score > 70,
            "confidence": confidence
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in AI detection: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during AI detection"
        )
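
# Local-run sketch; assumes uvicorn is installed (pip install uvicorn) and
# that port 8000 is free. Example requests once the server is up:
#
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<at least ten words of input text to summarize here>"}'
#
#   curl -X POST http://localhost:8000/detect \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Some sample text to score for AI likelihood."}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)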