File size: 2,120 Bytes
78a9e80 792d4e1 78a9e80 792d4e1 78a9e80 792d4e1 4f12c55 792d4e1 78a9e80 792d4e1 78a9e80 1c61dbb 78a9e80 792d4e1 a99a830 d739aa1 792d4e1 201c97f 792d4e1 201c97f 792d4e1 78a9e80 792d4e1 4f12c55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import hmac
import io
import os

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
# Select the compute device once; everything below derives from it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the feature extraction pipeline. The pipeline's integer device
# convention is 0 for the first CUDA device and -1 for CPU; derive it from
# DEVICE instead of re-querying CUDA so the two settings cannot drift apart.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    device=0 if DEVICE.type == "cuda" else -1,
)
# Shared secret that callers of the feature endpoint must present. When the
# SECRET_KEY environment variable is unset, a fixed placeholder is used.
# NOTE(review): that placeholder is visible in source; consider requiring the
# variable to be set in any real deployment.
SECRET_KEY = os.environ.get("SECRET_KEY", "default-secret")

# ASGI application object that the route decorators in this module attach to.
app = FastAPI()
@app.get("/")
def read_root():
    """Liveness endpoint: always answer with a fixed JSON status message."""
    status = {"message": "App is running successfully!"}
    return status
# Content types accepted by /extract-features/ (as declared by the client).
_ALLOWED_CONTENT_TYPES = frozenset({"image/jpeg", "image/png", "image/jpg"})


def _load_rgb_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot identify or fully decode the data,
            or rejects it as a decompression bomb.
    """
    try:
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Image.open raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data, and convert() forces the actual decode, which
        # surfaces truncated files as OSError. Both are client errors.
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from e


# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    """Return the CLS-token embedding of an uploaded JPEG or PNG image.

    Args:
        file: The uploaded image. Only the ``image/jpeg``, ``image/png`` and
            ``image/jpg`` content types are accepted.
        secret_key: Shared secret taken from the request headers. It must
            equal the module-level ``SECRET_KEY``. FastAPI's ``Header`` maps
            the underscore to a hyphen by default, so clients send
            ``secret-key``.

    Returns:
        JSONResponse with body ``{"features": [...]}``, holding the first
        token's vector from the first row of the pipeline output.

    Raises:
        HTTPException: 403 if the secret is missing or wrong; 400 if the
            content type is unsupported or the bytes are not a decodable
            image; 500 for any other failure during extraction.
    """
    # Compare in constant time and never echo either value back. The previous
    # error detail disclosed the server's SECRET_KEY to any caller who sent a
    # wrong key.
    if secret_key is None or not hmac.compare_digest(
        secret_key.encode("utf-8"), SECRET_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")

    # Validate the client-declared content type before touching the body.
    if file.content_type not in _ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
        )

    try:
        # NOTE(review): the whole upload is read into memory with no size cap;
        # consider enforcing a maximum if this endpoint is publicly reachable.
        contents = await file.read()
        image = _load_rgb_image(contents)
        # NOTE(review): this synchronous call blocks the event loop for the
        # duration of inference.
        features = feature_extraction_pipeline(image)  # Shape: (1, 577, 768)
        # First token of the first (only) sequence: the CLS embedding.
        cls_embedding = np.array(features)[0, 0, :]  # Shape: (768,)
        return JSONResponse(content={"features": cls_embedding.tolist()})
    except HTTPException:
        raise  # Keep deliberate 4xx responses (e.g. undecodable image) intact.
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
|