from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging
from functools import lru_cache

from transformers import pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS - adjust the origins based on your needs
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
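
# A locked-down production setup would list explicit origins instead, e.g.
#   allow_origins=["https://app.example.com"]  # hypothetical frontend domain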

# lru_cache memoizes by model_type, so each pipeline is loaded exactly once;
# a separate cache dict on top of it would be redundant.
@lru_cache()
def get_model(model_type: str):
    """Initialize a model on first use; cached calls return the same instance."""
    logger.info(f"Initializing {model_type} model...")
    try:
        if model_type == "summarizer":
            model = pipeline(
                "summarization",
                model="facebook/bart-large-cnn",
                device="cpu"
            )
        elif model_type == "detector":
            model = pipeline(
                "text-classification",
                model="roberta-base-openai-detector",
                device="cpu"
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        logger.info(f"Successfully initialized {model_type} model")
        return model
    except Exception as e:
        logger.error(f"Error initializing {model_type} model: {str(e)}")
        raise RuntimeError(f"Failed to initialize {model_type} model")
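
# Optional warm-up sketch: eagerly load both pipelines at startup so the first
# request doesn't pay the multi-second model-load cost. This assumes the host
# has enough memory to hold both models from boot; drop this block to keep
# lazy loading instead.
@app.on_event("startup")
async def warm_model_cache():
    for model_type in ("summarizer", "detector"):
        get_model(model_type)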

class TextRequest(BaseModel):
    text: str
    max_length: Optional[int] = 130  # Upper bound on summary length, in tokens
    min_length: Optional[int] = 30   # Lower bound on summary length, in tokens


def validate_text(text: str, min_words: int = 10) -> bool:
    """Check that the text contains at least min_words whitespace-separated words."""
    return len(text.split()) >= min_words
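# e.g. validate_text("too short") -> False under the default min_words=10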

@app.get("/")
async def root():
    """Health check endpoint."""
    return {"status": "healthy", "message": "API is running"}
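
# Quick check (assuming a local server on port 8000):
#   curl http://localhost:8000/
#   -> {"status": "healthy", "message": "API is running"}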

@app.post("/api/summarize")
async def summarize_text(request: TextRequest):
    """Endpoint to summarize text."""
    try:
        if not validate_text(request.text):
            raise HTTPException(
                status_code=400,
                detail="Text is too short to summarize (minimum 10 words required)"
            )
        summarizer = get_model("summarizer")
        summary = summarizer(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
            do_sample=False  # Deterministic decoding for reproducible summaries
        )
        return {"summary": summary[0]["summary_text"]}
    except HTTPException:
        raise  # Re-raise client errors (400s) unchanged
    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during summarization"
        )
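
# Example request (assuming a local server on port 8000):
#   curl -X POST http://localhost:8000/api/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"text": "<at least 10 words of input>", "max_length": 60}'
# -> {"summary": "..."}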

@app.post("/api/detect")
async def detect_ai(request: TextRequest):
    """Endpoint to detect if text is AI-generated."""
    try:
        if not validate_text(request.text, min_words=5):
            raise HTTPException(
                status_code=400,
                detail="Text is too short for AI detection (minimum 5 words required)"
            )
        detector = get_model("detector")
        result = detector(request.text)[0]
        # The text-classification pipeline already returns a softmax probability
        # in [0, 1], so use the score directly; passing it through a second
        # sigmoid would compress the range and ignore the predicted label.
        # The detector is assumed to label its classes "Fake" (AI) and "Real"
        # (human), per the model card.
        prob_ai = result["score"] if result["label"] == "Fake" else 1 - result["score"]
        score = prob_ai * 100
        confidence = (
            "high" if abs(score - 50) > 25
            else "medium" if abs(score - 50) > 10
            else "low"
        )
        return {
            "score": round(score, 2),
            "likely_ai": score > 70,
            "confidence": confidence
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in AI detection: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during AI detection"
        )
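
# Example request (same local-server assumption as above):
#   curl -X POST http://localhost:8000/api/detect \
#     -H "Content-Type: application/json" \
#     -d '{"text": "<at least 5 words of input>"}'
# -> {"score": 87.5, "likely_ai": true, "confidence": "high"}  (illustrative values)

if __name__ == "__main__":
    # Local development entry point; assumes uvicorn is installed. Most
    # deployments (including Hugging Face Spaces) launch the app externally
    # and may expect a different port.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)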