File size: 1,526 Bytes
792d4e1 4f12c55 792d4e1 a99a830 d739aa1 792d4e1 4f12c55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
import torch
import io
# Select the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the ViT feature-extraction pipeline once at import time so every
# request reuses the same model instance instead of reloading weights.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    # Pass the torch.device directly (transformers accepts it); the original
    # recomputed the CUDA check with 0/-1 and left DEVICE unused.
    device=DEVICE,
)
# Create the FastAPI application instance; route handlers below attach to it.
app = FastAPI()
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the service is up and responding."""
    status_payload = {"message": "App is running successfully!"}
    return status_payload
# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...)):
    """Extract a ViT feature embedding from an uploaded JPEG/PNG image.

    Args:
        file: Multipart upload; must be ``image/jpeg``, ``image/png`` or
            ``image/jpg``.

    Returns:
        JSONResponse with ``{"features": <embedding>}`` as produced by the
        feature-extraction pipeline.

    Raises:
        HTTPException 400: unsupported content type.
        HTTPException 500: any unexpected failure while decoding or
            embedding the image.
    """
    # Validate the content type BEFORE the try block. In the original, the
    # 400 raised here was swallowed by the blanket `except Exception` and
    # re-surfaced as a 500, hiding the real error from the client.
    if file.content_type not in ("image/jpeg", "image/png", "image/jpg"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
        )
    try:
        # Read the upload fully into memory and decode it; convert to RGB so
        # grayscale/RGBA inputs match the model's expected 3-channel input.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
        # Run the model and return the embedding vector.
        features = feature_extraction_pipeline(image)
        return JSONResponse(content={"features": features})
    except HTTPException:
        # Never re-wrap deliberate HTTP errors into a 500.
        raise
    except Exception as e:
        # Boundary handler: surface unexpected failures as a 500, chaining
        # the original cause for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Allow running the development server directly: `python <this file>`.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)
|