# SUMAIB / app.py
# Author: Agamrampal · commit 35d30d1
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict
from transformers import pipeline
import torch
import logging
from functools import lru_cache
# Root logging configuration: INFO and above, each record prefixed with a
# timestamp, the logger name and the level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by the model loader and the endpoints below.
logger = logging.getLogger(__name__)
# The FastAPI application; every route below registers on this instance.
app = FastAPI()

# Cross-origin policy: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins to the known
# front-end origin(s) before production use.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Process-wide store of initialized pipelines, keyed by model_type.
MODEL_CACHE: Dict[str, object] = {}

# (pipeline task, Hugging Face checkpoint) for each supported model_type.
_MODEL_SPECS: Dict[str, tuple] = {
    "summarizer": ("summarization", "facebook/bart-large-cnn"),
    "detector": ("text-classification", "roberta-base-openai-detector"),
}


@lru_cache()
def get_model(model_type: str):
    """Return the pipeline for ``model_type``, creating it on first use.

    Created pipelines are stored in ``MODEL_CACHE``. The return value is also
    memoized by ``lru_cache``, so the body runs only until a model has loaded
    successfully (exceptions are not cached, so a failed load is retried on
    the next call).

    Args:
        model_type: ``"summarizer"`` or ``"detector"``.

    Returns:
        The ``transformers`` pipeline object, placed on CPU.

    Raises:
        ValueError: if ``model_type`` is not a supported model type.
        RuntimeError: if the pipeline fails to initialize (the original error
            is chained as ``__cause__``).
    """
    if model_type in MODEL_CACHE:
        return MODEL_CACHE[model_type]
    # Reject unknown types up front. Otherwise the code would log a false
    # success and then fail with a bare KeyError.
    if model_type not in _MODEL_SPECS:
        raise ValueError(f"Unknown model type: {model_type!r}")

    task, checkpoint = _MODEL_SPECS[model_type]
    logger.info(f"Initializing {model_type} model...")
    try:
        MODEL_CACHE[model_type] = pipeline(task, model=checkpoint, device="cpu")
    except Exception as e:
        # logger.exception keeps the traceback for diagnosing load failures.
        logger.exception("Error initializing %s model: %s", model_type, e)
        raise RuntimeError(f"Failed to initialize {model_type} model") from e
    logger.info(f"Successfully initialized {model_type} model")
    return MODEL_CACHE[model_type]
class TextRequest(BaseModel):
    """Request body shared by the /api/summarize and /api/detect endpoints."""

    # Input text to summarize or classify.
    text: str
    # Length bounds forwarded to the summarization pipeline as max_length /
    # min_length; the detection endpoint does not read them.
    # NOTE(review): units are presumably tokens (HF generation lengths) — confirm.
    max_length: Optional[int] = 130
    min_length: Optional[int] = 30
def validate_text(text: str, min_words: int = 10) -> bool:
    """Check whether ``text`` is long enough to process.

    Words are counted by splitting on runs of whitespace.

    Args:
        text: Input text to check.
        min_words: Smallest acceptable word count (inclusive).

    Returns:
        True if ``text`` contains at least ``min_words`` words, else False.
    """
    word_count = len(text.split())
    return not (word_count < min_words)
@app.get("/")
async def root():
    """Liveness endpoint; always answers with the same fixed status payload.

    It does not check whether the models are loaded.
    """
    return dict(status="healthy", message="API is running")
@app.post("/api/summarize")
async def summarize_text(request: TextRequest):
    """Summarize ``request.text`` with the BART summarization pipeline.

    ``request.max_length`` / ``request.min_length`` are forwarded to the
    pipeline, and decoding is deterministic (``do_sample=False``).

    Returns:
        ``{"summary": <summary text>}``.

    Raises:
        HTTPException: 400 if the text has fewer than 10 words; 500 if the
            model fails to load or run.
    """
    try:
        if not validate_text(request.text):
            raise HTTPException(
                status_code=400,
                detail="Text is too short to summarize (minimum 10 words required)",
            )
        summarizer = get_model("summarizer")
        # NOTE(review): inference is synchronous and runs on the event loop,
        # so it blocks every other request (including "/") while it executes.
        # Consider offloading it (e.g. starlette's run_in_threadpool) once
        # pipeline thread-safety has been confirmed.
        summary = summarizer(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
            do_sample=False,
            # Truncate inputs that exceed the model's maximum input length
            # instead of failing on them (BART errors on over-long inputs).
            truncation=True,
        )
        return {"summary": summary[0]["summary_text"]}
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error in summarization: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred during summarization",
        ) from e
def _ai_probability(result: dict) -> float:
    """Convert the detector's top prediction into P(AI-generated).

    Args:
        result: One prediction from the text-classification pipeline, i.e. a
            mapping with ``"label"`` and ``"score"`` keys, where ``score`` is
            the probability of that (top-ranked) label.

    Returns:
        Probability in ``[0, 1]`` that the text is machine-generated.

    Raises:
        ValueError: if the label is neither "Fake" nor "Real".
    """
    # This is a binary detector ranking "Fake" (machine-generated) against
    # "Real", so P(AI) is the score when "Fake" wins and its complement when
    # "Real" wins.
    # NOTE(review): label names assumed from the checkpoint's config — confirm.
    top_score = float(result["score"])
    label = str(result["label"]).lower()
    if label == "fake":
        return top_score
    if label == "real":
        return 1.0 - top_score
    raise ValueError(f"Unexpected detector label: {result['label']!r}")


@app.post("/api/detect")
async def detect_ai(request: TextRequest):
    """Estimate how likely ``request.text`` is to be AI-generated.

    Returns:
        ``{"score", "likely_ai", "confidence"}``:

        * ``score`` is P(AI-generated) as a percentage, rounded to 2 decimals.
        * ``likely_ai`` is ``score > 70``.
        * ``confidence`` is "high" if ``score`` is more than 25 points from 50,
          "medium" if it is more than 10 points away, and "low" otherwise.

    Raises:
        HTTPException: 400 if the text has fewer than 5 words; 500 if the
            model fails to load or run.
    """
    try:
        if not validate_text(request.text, min_words=5):
            raise HTTPException(
                status_code=400,
                detail="Text is too short for AI detection (minimum 5 words required)",
            )
        detector = get_model("detector")
        # Truncate to the model's maximum input length instead of erroring on
        # long texts.
        # NOTE(review): inference runs synchronously on the event loop and
        # blocks other requests while it executes.
        result = detector(request.text, truncation=True)[0]
        # The score is already a probability. Squashing it again with a
        # sigmoid confined every result to about 50-73 and ignored the label.
        score = _ai_probability(result) * 100
        margin = abs(score - 50)
        confidence = "high" if margin > 25 else "medium" if margin > 10 else "low"
        return {
            "score": round(score, 2),
            "likely_ai": score > 70,
            "confidence": confidence,
        }
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("Error in AI detection: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred during AI detection",
        ) from e