skanderovitch's picture
Update app.py
a99a830 verified
raw
history blame
1.53 kB
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
import torch
import io
# Pick the compute device once; the pipeline below is placed on this same
# device so the two can never disagree.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the feature extraction pipeline once at module import time; the single
# instance is reused by every request handled by extract_features.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    # Reuse DEVICE rather than re-deriving a CUDA ordinal (0 / -1) from a
    # second torch.cuda.is_available() call.
    device=DEVICE,
)

# Initialize FastAPI
app = FastAPI()
@app.get("/")
def read_root():
    """Return a static status message confirming the app is up and serving requests."""
    status_message = "App is running successfully!"
    return dict(message=status_message)
# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...)):
    """Run the image feature-extraction pipeline on an uploaded image.

    Args:
        file: Multipart upload; its declared content type must be
            ``image/jpeg``, ``image/jpg`` or ``image/png``.

    Returns:
        JSONResponse whose body is ``{"features": <pipeline output>}``.

    Raises:
        HTTPException(400): the declared content type is not JPEG/PNG, or the
            uploaded bytes cannot be decoded as an image.
        HTTPException(500): any other unexpected failure (e.g. during
            inference); the exception text is returned as the detail.
    """
    try:
        # Cheap pre-check on the client-declared MIME type. The real
        # validation is whether PIL can decode the bytes below.
        if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
            )

        contents = await file.read()

        # Decode the upload. Failures here mean the payload is not a usable
        # image (OSError covers PIL's UnidentifiedImageError and truncated
        # files; DecompressionBombError guards oversized images), which is a
        # client error, not a server error.
        try:
            # `with` closes the source image; convert() returns an
            # independent, fully loaded copy.
            with Image.open(io.BytesIO(contents)) as source:
                image = source.convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            raise HTTPException(
                status_code=400, detail=f"Could not decode image: {e}"
            ) from e

        # NOTE(review): this call is synchronous and blocks the event loop for
        # the duration of inference; consider offloading it (e.g.
        # fastapi.concurrency.run_in_threadpool) if concurrency matters.
        features = feature_extraction_pipeline(image)

        # Assumes the pipeline returns JSON-serializable data (presumably
        # nested lists of floats) — confirm for the installed transformers.
        return JSONResponse(content={"features": features})
    except HTTPException:
        # Deliberate HTTP errors (the 400s above) must reach the client with
        # their own status code rather than being rewrapped as a 500.
        raise
    except Exception as e:
        # Top-level boundary: surface anything unexpected as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is executed
    # directly as a script.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)