import os
import tempfile
import uuid
from typing import Optional

import httpx
import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

# Choose a writable cache directory for the model download; fall back to /tmp
# if the default Hugging Face cache location is not writable.
default_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')

try:
    os.makedirs(default_cache_dir, exist_ok=True)
except PermissionError:
    default_cache_dir = os.path.join('/tmp', 'huggingface_cache')
    os.makedirs(default_cache_dir, exist_ok=True)

os.environ['TRANSFORMERS_CACHE'] = default_cache_dir
os.environ['HF_HOME'] = default_cache_dir

# Load the pre-trained NSFW detection model and its image processor once at
# module import time. cache_dir is passed explicitly because the environment
# variables above are set after transformers has already been imported, so
# they may not affect the default cache location for this process.
model = AutoModelForImageClassification.from_pretrained(
    "Falconsai/nsfw_image_detection", cache_dir=default_cache_dir
)
processor = ViTImageProcessor.from_pretrained(
    "Falconsai/nsfw_image_detection", cache_dir=default_cache_dir
)

app = FastAPI(title="NSFW Image Detection API")


@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(default=10, description="Timeout for downloading the image, in seconds"),
):
    """
    Detect NSFW content in an image at the given URL.

    - Downloads the image to a temporary file
    - Runs it through the pre-trained NSFW detection model
    - Returns the highest-confidence label and its score
    """
    temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.jpg")

    try:
        # Download the image from the given URL.
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=timeout)
            response.raise_for_status()

        # Persist the raw bytes to the temporary file.
        with open(temp_filename, 'wb') as f:
            f.write(response.content)

        img = Image.open(temp_filename).convert("RGB")

        # Run the classifier without tracking gradients.
        with torch.no_grad():
            inputs = processor(images=img, return_tensors="pt")
            outputs = model(**inputs)
            logits = outputs.logits

        # Convert logits to probabilities and pick the top prediction.
        confidences = torch.softmax(logits, dim=-1)
        max_confidence_id = confidences[0].argmax().item()
        max_label = model.config.id2label[max_confidence_id]
        max_confidence = confidences[0][max_confidence_id].item()
        return {
            "label": max_label,
            "confidence": max_confidence
        }
    except httpx.RequestError as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    finally:
        # Remove the temporary file even if downloading or inference failed.
        if os.path.exists(temp_filename):
            os.unlink(temp_filename)
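
# Example client call for /detect (a sketch, assuming the service is running
# locally on port 7860; the image URL below is a placeholder):
#
#   import httpx
#   resp = httpx.get(
#       "http://localhost:7860/detect",
#       params={"url": "https://example.com/image.jpg", "timeout": 10},
#   )
#   resp.json()  # -> {"label": ..., "confidence": ...}, where the label comes
#                #    from model.config.id2label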

@app.on_event("startup")
async def startup_event():
    # Warm up the model with a dummy image so the first real request does not
    # pay the one-time initialization cost.
    dummy_img = Image.new('RGB', (224, 224), color='red')
    with torch.no_grad():
        inputs = processor(images=dummy_img, return_tensors="pt")
        model(**inputs)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)