import os
import uuid
import tempfile
from typing import Optional

import httpx
import torch
from PIL import Image
from fastapi import FastAPI, Query, HTTPException

# Determine a writable cache directory for model downloads
default_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')
try:
    os.makedirs(default_cache_dir, exist_ok=True)
except PermissionError:
    # Fall back to the system temp directory if the home directory is not writable
    default_cache_dir = os.path.join(tempfile.gettempdir(), 'huggingface_cache')
    os.makedirs(default_cache_dir, exist_ok=True)

# These variables must be set before transformers is imported, because the
# library resolves its cache location at import time
os.environ['TRANSFORMERS_CACHE'] = default_cache_dir
os.environ['HF_HOME'] = default_cache_dir

from transformers import AutoModelForImageClassification, ViTImageProcessor

# Initialize the model and processor globally to avoid reloading for each request
model = AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")
app = FastAPI(title="NSFW Image Detection API")
@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(default=10, description="Timeout for downloading the image, in seconds")
):
    """
    Detect NSFW content in an image from a given URL.

    - Saves the image to a temporary file
    - Processes the image using a pre-trained NSFW detection model
    - Returns the highest-confidence prediction
    """
    temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.jpg")
    try:
        # Download the image
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=timeout)
            response.raise_for_status()

        # Save the image to the temporary file
        with open(temp_filename, 'wb') as f:
            f.write(response.content)

        # Open and process the image
        img = Image.open(temp_filename).convert("RGB")

        # Perform inference
        with torch.no_grad():
            inputs = processor(images=img, return_tensors="pt")
            outputs = model(**inputs)
            logits = outputs.logits

        # Convert logits to softmax probabilities and pick the top label
        confidences = torch.softmax(logits, dim=-1)
        max_confidence_id = confidences[0].argmax().item()
        max_label = model.config.id2label[max_confidence_id]
        max_confidence = confidences[0][max_confidence_id].item()

        return {
            "label": max_label,
            "confidence": max_confidence
        }
    except httpx.HTTPError as e:
        # Covers both transport failures (RequestError) and non-2xx responses
        # (HTTPStatusError raised by raise_for_status)
        raise HTTPException(status_code=400, detail=f"Error downloading image: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        # Clean up the temporary file whether or not inference succeeded
        if os.path.exists(temp_filename):
            os.unlink(temp_filename)
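# Example request against a running instance (illustrative values; the label
# set comes from the model's id2label mapping, which for
# Falconsai/nsfw_image_detection is "normal" and "nsfw"):
#
#   curl "http://localhost:7860/detect?url=https://example.com/photo.jpg"
#   # -> {"label": "normal", "confidence": 0.9983}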
# Startup event to warm up the model
@app.on_event("startup")
async def startup_event():
    # Perform a dummy inference so the first real request does not pay the
    # one-time model initialization cost
    dummy_img = Image.new('RGB', (224, 224), color='red')
    with torch.no_grad():
        inputs = processor(images=dummy_img, return_tensors="pt")
        model(**inputs)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
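# During development, the server can also be started with the uvicorn CLI
# (assuming this file is saved as app.py):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860 --reload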