# Spaces: No application file
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
# Run on the default CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face model id used for image feature extraction.
MODEL_NAME = "google/vit-base-patch16-384"

# Build the pipeline once at import time so every request reuses the loaded
# weights. ``pipeline`` accepts a ``torch.device``, so DEVICE is passed
# directly instead of re-deriving a 0 / -1 device index from the CUDA check.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model=MODEL_NAME,
    device=DEVICE,
)
# ASGI application instance served by uvicorn.
app = FastAPI()


# NOTE(review): root path assumed for the health check; confirm it matches
# what clients / the hosting platform probe.
@app.get("/")
def read_root():
    """Health check: report that the service is up and serving requests."""
    return {"message": "App is running successfully!"}
# MIME types accepted for upload. "image/jpg" is not a registered MIME type;
# it is accepted as a lenient alias of "image/jpeg".
ALLOWED_IMAGE_CONTENT_TYPES = frozenset({"image/jpeg", "image/png", "image/jpg"})


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if PIL cannot identify or fully decode the data
            (unknown format, truncated file) or refuses it as a
            decompression bomb.
    """
    try:
        # ``with`` closes the lazily opened source image; ``convert`` forces a
        # full decode and returns an independent image that outlives it.
        with Image.open(io.BytesIO(contents)) as source:
            return source.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unrecognised format) subclasses OSError, as
        # do "image file is truncated" errors raised while decoding.
        raise HTTPException(
            status_code=400, detail=f"Could not decode image: {exc}"
        ) from exc


# NOTE(review): route path assumed; confirm it matches what clients call.
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...)):
    """Return the feature embedding of an uploaded JPEG or PNG image.

    Args:
        file: multipart upload whose declared content type must be one of
            ``ALLOWED_IMAGE_CONTENT_TYPES``.

    Returns:
        JSONResponse with body ``{"features": ...}`` holding the raw output of
        ``feature_extraction_pipeline`` for the decoded image.

    Raises:
        HTTPException: 400 for an unsupported content type or undecodable
            image data; 500 if the model pipeline fails.
    """
    # Reject by declared type before reading the body. Client errors are raised
    # outside any broad ``except`` so they keep their 400 status.
    if file.content_type not in ALLOWED_IMAGE_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
        )

    image = _decode_image(await file.read())

    try:
        # Inference is synchronous and compute-bound; run it in the threadpool
        # so it does not stall the event loop for other requests.
        features = await run_in_threadpool(feature_extraction_pipeline, image)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return JSONResponse(content={"features": features})
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)