import logging
from functools import lru_cache
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Configure CORS. Note that browsers refuse credentialed requests when
# allow_origins is the wildcard "*"; list explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict to known origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model cache
MODEL_CACHE: Dict[str, Any] = {}

@lru_cache()
def get_model(model_type: str):
    """Get or initialize a model, caching it after the first load."""
    if model_type not in MODEL_CACHE:
        logger.info(f"Initializing {model_type} model...")
        try:
            if model_type == "summarizer":
                MODEL_CACHE[model_type] = pipeline(
                    "summarization",
                    model="facebook/bart-large-cnn",
                    device="cpu"
                )
            elif model_type == "detector":
                MODEL_CACHE[model_type] = pipeline(
                    "text-classification",
                    model="roberta-base-openai-detector",
                    device="cpu"
                )
            else:
                # Fail loudly instead of raising a KeyError on the return below
                raise ValueError(f"Unknown model type: {model_type}")
            logger.info(f"Successfully initialized {model_type} model")
        except Exception as e:
            logger.error(f"Error initializing {model_type} model: {str(e)}")
            raise RuntimeError(f"Failed to initialize {model_type} model") from e
    return MODEL_CACHE[model_type]
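
# Optional warm-up: eagerly initialize both models at startup so the first
# request does not pay the load cost. A minimal sketch using FastAPI's
# startup hook (newer FastAPI versions prefer a lifespan handler):
@app.on_event("startup")
async def warm_models():
    for model_type in ("summarizer", "detector"):
        get_model(model_type)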

class TextRequest(BaseModel):
    text: str
    max_length: Optional[int] = 130
    min_length: Optional[int] = 30
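
# Example request body for both endpoints (illustrative values; the two
# length fields are optional and only used by /api/summarize):
#   {"text": "Long article text ...", "max_length": 130, "min_length": 30}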

def validate_text(text: str, min_words: int = 10) -> bool:
    """Validate text input."""
    return len(text.split()) >= min_words

@app.get("/")
async def root():
    """Health check endpoint."""
    return {"status": "healthy", "message": "API is running"}

@app.post("/api/summarize")
async def summarize_text(request: TextRequest):
    """Endpoint to summarize text."""
    try:
        if not validate_text(request.text):
            raise HTTPException(
                status_code=400,
                detail="Text is too short to summarize (minimum 10 words required)"
            )

        summarizer = get_model("summarizer")
        summary = summarizer(
            request.text,
            max_length=request.max_length,
            min_length=request.min_length,
            do_sample=False
        )

        return {"summary": summary[0]["summary_text"]}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during summarization"
        )
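
# Example call (assumes the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/api/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"text": "<at least 10 words of input text ...>"}'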

@app.post("/api/detect")
async def detect_ai(request: TextRequest):
    """Endpoint to detect if text is AI-generated."""
    try:
        if not validate_text(request.text, min_words=5):
            raise HTTPException(
                status_code=400,
                detail="Text is too short for AI detection (minimum 5 words required)"
            )

        detector = get_model("detector")
        result = detector(request.text)[0]

        # The pipeline already returns a softmax probability in [0, 1], so no
        # extra sigmoid is needed. Map the score to P(AI-generated) using the
        # predicted label; this checkpoint labels outputs "Real" / "Fake".
        prob_ai = result["score"] if result["label"] == "Fake" else 1.0 - result["score"]
        score = prob_ai * 100
        
        confidence = (
            "high" if abs(score - 50) > 25 
            else "medium" if abs(score - 50) > 10 
            else "low"
        )
        
        return {
            "score": round(score, 2),
            "likely_ai": score > 70,
            "confidence": confidence
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in AI detection: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred during AI detection"
        )
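
# Example call (assumes the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/api/detect \
#     -H "Content-Type: application/json" \
#     -d '{"text": "<at least 5 words of input text ...>"}'

# Minimal entry point for running the app directly, assuming uvicorn is
# installed (pip install uvicorn). In a deployment you would more likely run
# `uvicorn main:app`; the module name "main" is an assumption here.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)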