import io
import os

import numpy as np
import torch
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline

# Load the feature-extraction pipeline on GPU when available, otherwise on CPU.
# The pipeline handles device placement itself, so no separate torch.device is needed.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    device=0 if torch.cuda.is_available() else -1,
)

# Read SECRET_KEY from the environment
SECRET_KEY = os.getenv("SECRET_KEY", "default-secret")

# Initialize FastAPI
app = FastAPI()


@app.get("/")
def read_root():
    return {"message": "App is running successfully!"}


# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    try:
        # Verify the SECRET_KEY; do not echo either value back to the client
        if secret_key != SECRET_KEY:
            raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")

        # Validate file format
        if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
            )

        # Read the image and normalize it to RGB
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Get feature embeddings; ViT-Base/16 at 384x384 produces 576 patch tokens
        # plus one CLS token, so the output has shape (1, 577, 768)
        features = feature_extraction_pipeline(image)

        # Extract the CLS token embedding, shape (768,)
        cls_embedding = np.array(features)[0, 0, :]

        # Return the embedding vector
        return JSONResponse(content={"features": cls_embedding.tolist()})
    except HTTPException:
        raise  # Re-raise HTTPExceptions so their status codes are preserved
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
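
# A minimal usage sketch, assuming the server is running locally on port 7860,
# SECRET_KEY was exported as "my-secret", and example.jpg exists; the key value
# and filename are placeholders, not part of the app itself. FastAPI maps the
# secret_key parameter to the "secret-key" HTTP header:
#
#   curl -X POST "http://localhost:7860/extract-features/" \
#        -H "secret-key: my-secret" \
#        -F "file=@example.jpg"
#
# On success the response body is {"features": [...]}, a 768-element list.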