import os
import tempfile
import uuid
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# Load the model and processor once at import time so every request
# reuses the same weights instead of reloading them per call.
MODEL_NAME = "Falconsai/nsfw_image_detection"
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

app = FastAPI(title="NSFW Image Detection API")


@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(
        default=10, description="Timeout for downloading image in seconds"
    ),
):
    """
    Detect NSFW content in an image from a given URL.

    - Downloads the image to a temporary file
    - Processes the image using a pre-trained NSFW detection model
    - Returns the highest confidence prediction

    Returns:
        dict with "label" (predicted class name) and "confidence" (softmax
        probability of that class).

    Raises:
        HTTPException 400: the image could not be downloaded (network error
            or non-2xx response from the remote server).
        HTTPException 500: any failure while decoding or classifying the image.
    """
    # Unique name in the system temp dir so concurrent requests never collide.
    temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.jpg")
    try:
        # Download the image.
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=timeout)
            response.raise_for_status()

        # Save the image to the temporary file.
        with open(temp_filename, "wb") as f:
            f.write(response.content)

        # Open the image as a context manager so the file handle is closed
        # even if inference fails (the original leaked the handle).
        with Image.open(temp_filename) as img:
            # Perform inference without tracking gradients.
            with torch.no_grad():
                inputs = processor(images=img, return_tensors="pt")
                outputs = model(**inputs)

        # Softmax over the logits, then pick the highest-confidence label.
        confidences = torch.softmax(outputs.logits, dim=-1)
        max_confidence_id = confidences[0].argmax().item()

        return {
            "label": model.config.id2label[max_confidence_id],
            "confidence": confidences[0][max_confidence_id].item(),
        }
    except httpx.HTTPStatusError as e:
        # raise_for_status() raises HTTPStatusError, which is NOT a subclass
        # of RequestError — without this clause a 404/403 on the image URL
        # would surface as a misleading 500 "Error processing image".
        raise HTTPException(status_code=400,
                            detail=f"Error downloading image: {str(e)}")
    except httpx.RequestError as e:
        # Connection/timeout/DNS failures while fetching the image.
        raise HTTPException(status_code=400,
                            detail=f"Error downloading image: {str(e)}")
    except HTTPException:
        # Never re-wrap an already-constructed HTTP error as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500,
                            detail=f"Error processing image: {str(e)}")
    finally:
        # Always remove the temp file — the original deleted it only on the
        # success path and leaked one file per failed request.
        if os.path.exists(temp_filename):
            os.unlink(temp_filename)


# Startup event to warm up the model.
@app.on_event("startup")
async def startup_event():
    # Perform a dummy inference so the first real request does not pay the
    # one-time lazy-initialization cost.
    dummy_img = Image.new("RGB", (224, 224), color="red")
    with torch.no_grad():
        inputs = processor(images=dummy_img, return_tensors="pt")
        model(**inputs)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)