benstaf commited on
Commit
0017a72
·
verified ·
1 Parent(s): 96d1185

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -1,18 +1,18 @@
1
  import os
2
- # MUST be before importing transformers
3
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
4
 
5
  from fastapi import FastAPI, UploadFile, File
6
- from transformers import AutoModelForImageClassification, AutoImageProcessor
7
  from PIL import Image
8
- import torch.nn.functional as F
9
  import torch
 
10
  import io
11
 
12
  app = FastAPI()
13
 
14
- model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Gender-Classifier-Mini")
15
- processor = AutoImageProcessor.from_pretrained("prithivMLmods/Gender-Classifier-Mini")
 
16
 
17
  @app.post("/classify/")
18
  async def classify_gender(image: UploadFile = File(...)):
@@ -21,10 +21,11 @@ async def classify_gender(image: UploadFile = File(...)):
21
  inputs = processor(images=img, return_tensors="pt")
22
 
23
  with torch.no_grad():
24
- logits = model(**inputs).logits
25
- probs = F.softmax(logits, dim=1)
26
- pred = torch.argmax(probs).item()
27
- confidence = probs[0][pred].item()
28
- label = model.config.id2label[pred]
 
29
 
30
- return {"label": label, "confidence": confidence}
 
1
  import os
2
+ os.environ["HF_HOME"] = "/tmp/huggingface"
 
3
 
4
  from fastapi import FastAPI, UploadFile, File
5
+ from transformers import SiglipForImageClassification, AutoImageProcessor
6
  from PIL import Image
 
7
  import torch
8
+ import torch.nn.functional as F
9
  import io
10
 
11
  app = FastAPI()
12
 
13
+ model_name = "prithivMLmods/Gender-Classifier-Mini"
14
+ model = SiglipForImageClassification.from_pretrained(model_name)
15
+ processor = AutoImageProcessor.from_pretrained(model_name)
16
 
17
  @app.post("/classify/")
18
  async def classify_gender(image: UploadFile = File(...)):
 
21
  inputs = processor(images=img, return_tensors="pt")
22
 
23
  with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ logits = outputs.logits
26
+ probs = F.softmax(logits, dim=1).squeeze().tolist()
27
+
28
+ labels = {"0": "Female ♀", "1": "Male ♂"}
29
+ predictions = {labels[str(i)]: round(probs[i], 3) for i in range(len(probs))}
30
 
31
+ return predictions