# Enkacard / main.py
# Last commit: 2d907c4 "Update main.py" (author: cINAWGD, verified)
import io
import os
import tempfile
import uuid
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
def _is_writable_dir(path):
    """Create *path* if missing; return True if it is then a writable directory.

    Any OSError from ``os.makedirs`` (PermissionError, read-only filesystem,
    a regular file already at that path, ...) means "not usable" and is
    reported as False rather than raised.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        return False
    # makedirs(exist_ok=True) succeeds on an existing directory even when we
    # cannot write to it, so check writability explicitly.
    return os.access(path, os.W_OK)


# Determine a writable cache directory: prefer the user's standard Hugging
# Face hub cache and fall back to a directory under the system temp dir when
# the home location cannot be created or written (e.g. a read-only $HOME).
default_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')
if not _is_writable_dir(default_cache_dir):
    default_cache_dir = os.path.join(tempfile.gettempdir(), 'huggingface_cache')
    os.makedirs(default_cache_dir, exist_ok=True)

# Export the chosen directory for the Hugging Face libraries.
# NOTE(review): huggingface_hub/transformers read these variables when they
# are imported, and transformers is imported at the top of this file, so these
# assignments presumably only affect child processes, not this one — verify.
# NOTE(review): HF_HOME is documented as the *root* HF directory (the hub cache
# lives in $HF_HOME/hub), so pointing it at a hub directory nests caches —
# confirm intent. TRANSFORMERS_CACHE is deprecated in recent transformers.
os.environ['TRANSFORMERS_CACHE'] = default_cache_dir
os.environ['HF_HOME'] = default_cache_dir
# Hugging Face model id of the NSFW classifier served by this API.
MODEL_NAME = "Falconsai/nsfw_image_detection"

# Initialize the model and processor globally to avoid reloading for each request.
# cache_dir is passed explicitly because transformers/huggingface_hub resolve
# their cache location from the environment at import time. transformers is
# imported before default_cache_dir is exported to the environment, so the env
# vars alone would not redirect these downloads.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME, cache_dir=default_cache_dir)
processor = ViTImageProcessor.from_pretrained(MODEL_NAME, cache_dir=default_cache_dir)

app = FastAPI(title="NSFW Image Detection API")
@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(default=10, description="Timeout for downloading image in seconds")
):
    """
    Detect NSFW content in an image from a given URL.

    - Downloads the image and decodes it in memory (nothing is written to disk)
    - Processes the image using a pre-trained NSFW detection model
    - Returns the highest confidence prediction

    Args:
        url: URL of the image to fetch.
        timeout: Download timeout in seconds, forwarded to ``httpx``.

    Returns:
        ``{"label": <model label>, "confidence": <softmax probability>}``
        for the top-scoring class.

    Raises:
        HTTPException: 400 if the download fails (network error or non-2xx
            response) or the payload cannot be decoded as an image; 500 for
            any other error.
    """
    # NOTE(review): this fetches an arbitrary caller-supplied URL with no host
    # allow-list and no size cap (server-side request forgery / memory
    # exhaustion risk) — restrict both if the API is publicly reachable.
    # NOTE(review): model inference below is synchronous inside an async
    # handler, so it blocks the event loop while it runs.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=timeout)
            # Any non-2xx status (including 3xx, since httpx does not follow
            # redirects by default) raises httpx.HTTPStatusError.
            response.raise_for_status()
            image_bytes = response.content

        # Decode from memory rather than a temp file, so nothing is left on
        # disk when a later step fails.
        try:
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as e:
            # Pillow's UnidentifiedImageError (unrecognised data) and
            # truncated-image errors are OSError subclasses: bad input, not a
            # server fault.
            raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e

        with torch.no_grad():
            inputs = processor(images=img, return_tensors="pt")
            logits = model(**inputs).logits
            # Softmax over the class dimension; [0] selects the single image.
            confidences = torch.softmax(logits, dim=-1)[0]

        max_confidence_id = confidences.argmax().item()
        return {
            "label": model.config.id2label[max_confidence_id],
            "confidence": confidences[max_confidence_id].item(),
        }
    except HTTPException:
        # Already mapped to a status code above; don't re-wrap it as a 500.
        raise
    except httpx.HTTPError as e:
        # HTTPError is the common base of RequestError (connection, timeout,
        # protocol failures) and HTTPStatusError (bad response status). The
        # latter is not a RequestError, so catching only RequestError reported
        # upstream 3xx/4xx/5xx responses as a 500.
        raise HTTPException(status_code=400, detail=f"Error downloading image: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# Application startup hook: warm up the model before the first real request.
@app.on_event("startup")
async def startup_event():
    """Run one throwaway forward pass on a solid red 224x224 RGB image."""
    warmup_image = Image.new('RGB', (224, 224), color='red')
    with torch.no_grad():
        model(**processor(images=warmup_image, return_tensors="pt"))
if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces, port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")