File size: 2,120 Bytes
78a9e80
 
792d4e1
 
 
 
78a9e80
792d4e1
 
 
 
 
 
 
 
 
 
 
 
78a9e80
 
 
792d4e1
 
 
4f12c55
 
 
 
792d4e1
 
78a9e80
792d4e1
78a9e80
 
1c61dbb
78a9e80
792d4e1
a99a830
d739aa1
792d4e1
 
 
 
 
 
201c97f
 
 
 
792d4e1
 
201c97f
792d4e1
78a9e80
 
792d4e1
 
 
4f12c55
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
import torch
import numpy as np
import io

# Set up the device for the model
# NOTE(review): DEVICE is not referenced anywhere else in the visible code; the
# pipeline below repeats the CUDA availability check on its own.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the feature extraction pipeline once, at import time, so that the
# request handler reuses this single module-level instance.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction", 
    model="google/vit-base-patch16-384", 
    device=0 if torch.cuda.is_available() else -1  # 0 when CUDA is available; -1 presumably selects the CPU — confirm in the transformers docs
)

# Read SECRET_KEY from the environment
# Falls back to the literal "default-secret" when the variable is unset.
# NOTE(review): with that fallback an unconfigured deployment accepts a
# publicly guessable key; consider failing fast when SECRET_KEY is missing.
SECRET_KEY = os.getenv("SECRET_KEY", "default-secret")

# Initialize FastAPI
app = FastAPI()

@app.get("/")
def read_root():
    """Report that the application is up.

    Returns:
        A one-entry mapping whose "message" value is a fixed status string.
    """
    status_payload = dict(message="App is running successfully!")
    return status_payload

# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    """Return the first-token embedding of an uploaded image.

    The caller authenticates with a ``secret_key`` request header that must
    equal the module-level ``SECRET_KEY``. The uploaded file must declare a
    JPEG or PNG content type and must be decodable by Pillow.

    Args:
        file: The uploaded image (multipart form field ``file``).
        secret_key: Value of the ``secret_key`` header; ``None`` when absent.

    Returns:
        A ``JSONResponse`` of the form ``{"features": [...]}`` holding the
        embedding at index ``[0, 0, :]`` of the pipeline output as a list.

    Raises:
        HTTPException: 403 when the key is missing or wrong; 400 when the
            content type is unsupported or the bytes are not a readable
            image; 500 for any other failure while reading or running the
            model.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import hmac

    allowed_content_types = ("image/jpeg", "image/png", "image/jpg")

    # Authenticate before doing any work. The comparison is constant-time and
    # done on bytes because compare_digest rejects non-ASCII str arguments.
    # The error detail deliberately contains neither the expected key nor the
    # supplied one: echoing the server secret would hand it to any caller.
    if secret_key is None or not hmac.compare_digest(
        secret_key.encode("utf-8"), SECRET_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")

    # Validate the declared file format.
    if file.content_type not in allowed_content_types:
        raise HTTPException(status_code=400, detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}")

    try:
        contents = await file.read()

        # An empty, truncated or non-image payload is a client error, not a
        # server fault. Pillow's decode errors derive from OSError.
        try:
            image = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"Could not decode the uploaded image: {exc}",
            ) from exc

        # Get feature embeddings; expected shape (1, 577, 768) for this model
        # according to the original author's note — TODO confirm.
        features = feature_extraction_pipeline(image)

        # Keep only the first token's vector of the first (only) batch item,
        # which the original code treats as the CLS embedding.
        cls_embedding = np.array(features)[0, 0, :]

        return JSONResponse(content={"features": cls_embedding.tolist()})

    except HTTPException:
        raise  # Preserve the status code chosen above.
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)