skanderovitch committed on
Commit
4cce556
·
verified ·
1 Parent(s): 1c61dbb

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -62
  2. main.py +27 -0
app.py DELETED
@@ -1,62 +0,0 @@
1
- import os
2
- from fastapi import FastAPI, UploadFile, File, HTTPException, Header
3
- from fastapi.responses import JSONResponse
4
- from transformers import pipeline
5
- from PIL import Image
6
- import torch
7
- import numpy as np
8
- import io
9
-
10
# Select the compute device: CUDA when it is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the image-feature-extraction pipeline once, at import time, so all
# requests share it. The pipeline is given device index 0 when CUDA is
# available and -1 otherwise.
feature_extraction_pipeline = pipeline(
    "image-feature-extraction",
    model="google/vit-base-patch16-384",
    device=0 if DEVICE.type == "cuda" else -1,
)

# Shared secret that clients must present; falls back to "default-secret"
# when the SECRET_KEY environment variable is not set.
SECRET_KEY = os.environ.get("SECRET_KEY", "default-secret")

# FastAPI application instance that the route decorators register against.
app = FastAPI()
25
-
26
@app.get("/")
def read_root():
    """Return a fixed status message showing that the app is running."""
    status = {"message": "App is running successfully!"}
    return status
29
-
30
- # Endpoint to extract image features
31
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    """Return the CLS-token embedding of an uploaded image.

    Args:
        file: Uploaded image; its content type must be JPEG or PNG.
        secret_key: Shared secret supplied as a request header; it must equal
            the module-level ``SECRET_KEY``.

    Returns:
        A JSON response of the form ``{"features": [...]}`` holding the
        embedding vector as a list of floats.

    Raises:
        HTTPException: 403 when the secret is missing or wrong, 400 when the
            content type is unsupported or the bytes cannot be decoded as an
            image, 500 when feature extraction itself fails.
    """
    import hmac  # stdlib; local import keeps this block self-contained

    # Verify the secret. The comparison is constant-time, and the error detail
    # deliberately echoes neither the supplied nor the expected value so the
    # server's secret is never sent back to a caller.
    supplied = (secret_key or "").encode("utf-8")
    expected = SECRET_KEY.encode("utf-8")
    if secret_key is None or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")

    # Validate the declared file format before reading the payload.
    if file.content_type not in ("image/jpeg", "image/png", "image/jpg"):
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}",
        )

    try:
        contents = await file.read()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # A payload that is not a decodable image is a client error, not a server
    # fault. Pillow's decode failures (including UnidentifiedImageError) are
    # OSError subclasses.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image",
        ) from exc

    try:
        # Per the original author's note the pipeline output has shape
        # (1, 577, 768); index [0, 0, :] selects the CLS token, shape (768,).
        features = feature_extraction_pipeline(image)
        cls_embedding = np.array(features)[0, 0, :]
    except Exception as exc:
        # Top-level boundary: report unexpected model failures as a 500.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return JSONResponse(content={"features": cls_embedding.tolist()})
59
-
60
if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ import torchaudio
3
+ from speechbrain.inference.speaker import EncoderClassifier
4
+ import torch
5
+ import io
6
+
7
# FastAPI application instance that the route decorators register against.
app = FastAPI()

# Instantiate the speaker-embedding model a single time at import, so every
# request reuses the same loaded instance.
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
)
11
+
12
@app.post("/embed")
async def get_voice_embedding(file: UploadFile = File(...)):
    """Compute a speaker embedding for an uploaded audio file.

    Args:
        file: Uploaded audio whose filename ends in ``.wav``, ``.mp3`` or
            ``.flac`` (matched case-insensitively).

    Returns:
        A dict ``{"embedding": [...]}`` holding the embedding produced by
        ``classifier.encode_batch`` for the (mono, 16 kHz) waveform.

    Raises:
        HTTPException: 400 when the filename is missing or has an unsupported
            extension, when the bytes cannot be decoded as audio, or when the
            decoded audio contains no samples.
    """
    # The model card for this checkpoint states it expects 16 kHz
    # single-channel input when encode_batch is called directly.
    target_sample_rate = 16000

    # The filename may be absent on an upload, and extensions are often
    # upper-case; normalize before checking.
    filename = (file.filename or "").lower()
    if not filename.endswith((".wav", ".mp3", ".flac")):
        raise HTTPException(status_code=400, detail="Invalid file format")

    # Read audio bytes and load into a tensor. An undecodable payload is a
    # client error rather than an unhandled server fault.
    audio_bytes = await file.read()
    try:
        audio_tensor, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    except (RuntimeError, OSError, ValueError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Could not decode audio file",
        ) from exc

    if audio_tensor.numel() == 0:
        raise HTTPException(status_code=400, detail="Audio file contains no samples")

    # torchaudio.load yields (channels, frames), while encode_batch treats the
    # first dimension as the batch. Downmix to mono so a stereo file produces
    # one embedding instead of one per channel.
    if audio_tensor.size(0) > 1:
        audio_tensor = audio_tensor.mean(dim=0, keepdim=True)

    # Bring the waveform to the rate the model expects.
    if sample_rate != target_sample_rate:
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, sample_rate, target_sample_rate
        )

    # Compute embedding without tracking gradients (inference only).
    with torch.no_grad():
        embeddings = classifier.encode_batch(audio_tensor)

    return {"embedding": embeddings.squeeze().tolist()}