skanderovitch's picture
Update app.py
201c97f verified
raw
history blame
1.68 kB
import io

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
# Device used for inference: CUDA when available, otherwise CPU. This is the
# single place the choice is made; the pipeline below is placed on it directly
# rather than re-checking torch.cuda.is_available().
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature-extraction pipeline, built once at module import so every request
# reuses the same loaded model.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    device=DEVICE,
)

# FastAPI application; the routes in this module are registered on it.
app = FastAPI()
@app.get("/")
def read_root():
    """Liveness probe: return a fixed message confirming the service is up."""
    status = dict(message="App is running successfully!")
    return status
# Content types accepted by /extract-features/. "image/jpg" is not a registered
# MIME type, but it is accepted alongside "image/jpeg" for lenient clients.
_ALLOWED_CONTENT_TYPES = frozenset({"image/jpeg", "image/png", "image/jpg"})


def _decode_image(contents):
    """Decode raw upload bytes into an RGB PIL image, raising HTTP 400 on failure."""
    try:
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # PIL signals undecodable or truncated data with OSError (including its
        # UnidentifiedImageError subclass). That is a client error, not a 500.
        raise HTTPException(
            status_code=400, detail=f"Could not decode image: {exc}"
        ) from exc


@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...)):
    """Return the CLS-token embedding of an uploaded JPEG or PNG image.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        JSONResponse with body ``{"features": [...]}``, the first token (CLS)
        of the pipeline's output for the single input image.

    Raises:
        HTTPException: 400 if the content type is not JPEG/PNG or the bytes
            cannot be decoded as an image; 500 for any other failure while
            reading the upload or running the model.
    """
    # Guard clause: reject unsupported formats before reading the body.
    if file.content_type not in _ALLOWED_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
        )

    try:
        contents = await file.read()
        image = _decode_image(contents)
        features = feature_extraction_pipeline(image)  # Shape: (1, 577, 768)
        # Index [image 0, token 0] selects the CLS token embedding.
        cls_embedding = np.asarray(features)[0, 0, :]  # Shape: (768,)
    except HTTPException:
        # Deliberate client errors (e.g. the 400 from _decode_image) must pass
        # through unchanged instead of being re-wrapped as a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return JSONResponse(content={"features": cls_embedding.tolist()})
if __name__ == "__main__":
    # The ASGI server is only needed when this module is executed directly,
    # so it is imported lazily here rather than at the top of the file.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)