File size: 3,576 Bytes
84e7abf
a439dad
 
 
 
 
 
 
 
 
e8a949b
 
 
 
 
 
 
 
 
 
 
 
 
 
a439dad
 
 
 
 
 
 
 
 
 
 
 
 
6df42e3
a439dad
 
 
 
fd1fd71
a439dad
 
 
 
 
 
 
 
 
 
 
 
 
 
2d907c4
a439dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6df42e3
 
a439dad
 
6df42e3
d885278
a439dad
 
 
 
14051b6
a439dad
 
 
 
 
 
 
 
14051b6
a439dad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import uuid
import tempfile
import httpx
import torch
from PIL import Image
from fastapi import FastAPI, Query, HTTPException
from transformers import AutoModelForImageClassification, ViTImageProcessor
from typing import Optional

# Determine a writable cache directory for Hugging Face model downloads.
default_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')

# Ensure the directory exists. An unwritable home can surface as PermissionError
# or as a more general OSError (e.g. a read-only filesystem), so catch OSError
# (PermissionError's base class) and fall back to the platform temp directory.
try:
    os.makedirs(default_cache_dir, exist_ok=True)
except OSError:
    default_cache_dir = os.path.join(tempfile.gettempdir(), 'huggingface_cache')
    os.makedirs(default_cache_dir, exist_ok=True)

# Export the chosen directory for anything that reads these variables later.
# NOTE: transformers (and huggingface_hub) are imported before this point and
# may already have resolved their cache locations from the environment, so the
# directory is also passed explicitly via `cache_dir` below.
os.environ['TRANSFORMERS_CACHE'] = default_cache_dir
os.environ['HF_HOME'] = default_cache_dir

# Hugging Face Hub id of the classifier used by this service.
MODEL_ID = "Falconsai/nsfw_image_detection"

# Initialize the model and processor globally to avoid reloading for each request.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID, cache_dir=default_cache_dir)
processor = ViTImageProcessor.from_pretrained(MODEL_ID, cache_dir=default_cache_dir)

app = FastAPI(title="NSFW Image Detection API")

@app.get("/detect")
async def detect_nsfw(
    url: str = Query(..., description="URL of the image to analyze"),
    timeout: Optional[int] = Query(default=10, description="Timeout for downloading image in seconds")
):
    """
    Detect NSFW content in an image from a given URL.

    - Saves the image to a temporary file
    - Processes the image using a pre-trained NSFW detection model
    - Returns the highest confidence prediction
    """
    # Returns {"label": <model label>, "confidence": <softmax probability>}.
    # Download failures (network errors and non-2xx responses) map to HTTP 400;
    # any other failure (decoding, inference) maps to HTTP 500.
    temp_filename = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.jpg")
    try:
        # httpx does not follow redirects by default, and raise_for_status()
        # rejects 3xx responses, so enable redirects explicitly.
        # NOTE(review): `url` is caller-controlled; consider restricting
        # schemes/hosts and capping the response size (SSRF / memory exhaustion).
        async with httpx.AsyncClient(follow_redirects=True) as client:
            response = await client.get(url, timeout=timeout)
            response.raise_for_status()

        with open(temp_filename, 'wb') as f:
            f.write(response.content)

        # convert() returns an independent in-memory image, so the source file
        # handle can be closed right away (this also lets unlink succeed on Windows).
        with Image.open(temp_filename) as raw_img:
            img = raw_img.convert("RGB")

        with torch.no_grad():
            inputs = processor(images=img, return_tensors="pt")
            logits = model(**inputs).logits

        # Probabilities for the single image in the batch.
        confidences = torch.softmax(logits, dim=-1)[0]
        max_confidence_id = confidences.argmax().item()
        max_label = model.config.id2label[max_confidence_id]
        max_confidence = confidences[max_confidence_id].item()

        return {
            "label": max_label,
            "confidence": max_confidence
        }

    except httpx.HTTPError as e:
        # HTTPError covers both transport failures (RequestError) and
        # non-2xx responses raised by raise_for_status() (HTTPStatusError).
        raise HTTPException(status_code=400, detail=f"Error downloading image: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e
    finally:
        # Always remove the temp file, including on failure paths; it may not
        # exist if the download failed before anything was written.
        try:
            os.unlink(temp_filename)
        except FileNotFoundError:
            pass

# Warm-up hook: executed once when the application starts.
@app.on_event("startup")
async def startup_event():
    """Warm up the model with one throwaway inference on a solid-red 224x224 image.

    The prediction is discarded; only the side effect of having run the model matters.
    """
    warmup_image = Image.new('RGB', (224, 224), color='red')
    with torch.no_grad():
        model(**processor(images=warmup_image, return_tensors="pt"))

if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)