skanderovitch commited on
Commit
792d4e1
·
verified ·
1 Parent(s): d740344

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from transformers import pipeline
4
+ from PIL import Image
5
+ import torch
6
+ import io
7
+
8
+ # Set up the device for the model
9
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ # Load the feature extraction pipeline
12
+ feature_extraction_pipeline = pipeline(
13
+ task="image-feature-extraction",
14
+ model="google/vit-base-patch16-384",
15
+ device=0 if torch.cuda.is_available() else -1
16
+ )
17
+
18
+ # Initialize FastAPI
19
+ app = FastAPI()
20
+
21
+ # Endpoint to extract image features
22
+ @app.post("/extract-features/")
23
+ async def extract_features(file: UploadFile = File(...)):
24
+ try:
25
+ # Validate file format
26
+ if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
27
+ raise HTTPException(status_code=400, detail="Unsupported file format. Upload a JPEG or PNG image.")
28
+
29
+ # Read image
30
+ contents = await file.read()
31
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
32
+
33
+ # Get feature embeddings
34
+ features = feature_extraction_pipeline(image)
35
+
36
+ # Return the embedding vector
37
+ return JSONResponse(content={"features": features})
38
+
39
+ except Exception as e:
40
+ raise HTTPException(status_code=500, detail=str(e))
41
+