import io
import logging
import os
import tempfile
import uuid  # noqa: F401  -- kept from the original import list; no longer used here
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

MODEL_ID = "Falconsai/nsfw_image_detection"


def _resolve_cache_dir() -> str:
    """Return a writable directory for the Hugging Face model cache.

    Tries ``~/.cache/huggingface/hub`` first. If it cannot be created
    (``PermissionError`` or any other ``OSError``, e.g. a read-only
    filesystem), falls back to ``huggingface_cache`` under the system temp
    directory.
    """
    preferred = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
    try:
        os.makedirs(preferred, exist_ok=True)
        return preferred
    except OSError:
        fallback = os.path.join(tempfile.gettempdir(), "huggingface_cache")
        os.makedirs(fallback, exist_ok=True)
        return fallback


default_cache_dir = _resolve_cache_dir()

# Exported for any processes/tools launched later. NOTE: ``transformers`` is
# already imported above, so these variables are probably too late to change
# where *this* process caches models -- hence ``cache_dir`` is passed
# explicitly to ``from_pretrained`` below.
os.environ["TRANSFORMERS_CACHE"] = default_cache_dir
os.environ["HF_HOME"] = default_cache_dir

# Load once at import time so individual requests don't pay the loading cost.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID, cache_dir=default_cache_dir)
processor = ViTImageProcessor.from_pretrained(MODEL_ID, cache_dir=default_cache_dir)

app = FastAPI(title="NSFW Image Detection API")


class _InvalidImageError(Exception):
    """Raised when downloaded bytes cannot be decoded as an image."""


def _predict(img: Image.Image) -> dict:
    """Classify an RGB PIL image and return the top label with its softmax confidence."""
    with torch.no_grad():
        inputs = processor(images=img, return_tensors="pt")
        logits = model(**inputs).logits
        # Softmax over the last axis; row 0 is the single image we passed in.
        probs = torch.softmax(logits, dim=-1)[0]
    best = int(probs.argmax().item())
    return {
        "label": model.config.id2label[best],
        "confidence": probs[best].item(),
    }


def _classify(data: bytes) -> dict:
    """Decode raw image bytes and classify them.

    CPU-bound: call it through ``run_in_threadpool`` so the event loop stays free.

    Raises:
        _InvalidImageError: if ``data`` is not a decodable image, or PIL
            rejects it as a decompression bomb.
    """
    try:
        with Image.open(io.BytesIO(data)) as raw:
            img = raw.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise _InvalidImageError(str(exc)) from exc
    return _predict(img)


@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(default=10, description="Timeout for downloading image in seconds"),
):
    """
    Detect NSFW content in an image from a given URL.

    - Downloads the image into memory (following redirects)
    - Decodes it and runs the pre-trained NSFW detection model in a worker thread
    - Returns the highest confidence prediction as {"label": str, "confidence": float}

    Raises HTTPException 400 if the image cannot be downloaded (network error or
    non-2xx response) or decoded, and 500 for any other failure.
    """
    # NOTE(review): ``url`` is caller-controlled, so this endpoint lets clients make
    # the server fetch arbitrary hosts (SSRF), and the full body is buffered in
    # memory. Consider a host allow-list and a max download size if exposed publicly.
    try:
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(url, timeout=timeout)
            # HTTPStatusError (raised here) is an httpx.HTTPError but NOT a RequestError.
            response.raise_for_status()
        return await run_in_threadpool(_classify, response.content)
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {e}") from e
    except _InvalidImageError as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {e}") from e
    except Exception as e:
        logger.exception("Unexpected failure while analyzing %s", url)
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e


# Startup event to warm up the model. ``on_event`` is deprecated in newer FastAPI
# releases in favour of lifespan handlers; kept here for compatibility.
@app.on_event("startup")
async def startup_event():
    """Run one dummy inference so the first real request isn't slowed by lazy init."""
    _predict(Image.new("RGB", (224, 224), color="red"))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)