panda1835 commited on
Commit
c06ec80
·
1 Parent(s): 4edb516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -39
app.py CHANGED
@@ -1,48 +1,37 @@
1
- from keras.models import load_model
2
- from PIL import Image, ImageOps
3
- import numpy as np
 
 
 
 
 
 
4
  import gradio as gr
5
 
6
- # Load the model
7
- classify_model = load_model('keras_model.h5')
 
 
 
8
 
9
- def format_label(label):
10
- """
11
- From '0 class 1\n' to 'class 1'
12
- """
13
- return label[:-1]
 
14
 
15
 
16
  def classify(image):
17
- # Create the array of the right shape to feed into the keras model
18
- # The 'length' or number of images you can put into the array is
19
- # determined by the first position in the shape tuple, in this case 1.
20
- data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
21
-
22
- #resize the image to a 224x224 with the same strategy as in TM2:
23
- #resizing the image to be at least 224x224 and then cropping from the center
24
- size = (224, 224)
25
- image = ImageOps.fit(image, size, Image.LANCZOS)
26
-
27
- #turn the image into a numpy array
28
- image_array = np.asarray(image)
29
- # Normalize the image
30
- normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
31
- # Load the image into the array
32
- data[0] = normalized_image_array
33
-
34
- # run the inference
35
- pred = classify_model.predict(data)
36
- pred = pred.tolist()
37
-
38
- with open('labels.txt','r') as f:
39
- labels = f.readlines()
40
-
41
- result = {format_label(labels[i]): round(pred[0][i],2) for i in range(len(pred[0]))}
42
- sorted_result = {k: v for k, v in sorted(result.items(), key=lambda item: item[1], reverse=True) if v > 0}
43
-
44
-
45
- return sorted_result
46
 
47
 
48
  title = "🐢"
 
1
+ import os
2
+ import glob
3
+ import json
4
+ import warnings
5
+
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ from PIL import Image
10
  import gradio as gr
11
 
12
+ import models
13
+
14
+ with open("../prepare_data/index_to_species.json", "r") as file:
15
+ index_to_species_data = file.read()
16
+ index_to_species = json.loads(index_to_species_data)
17
 
18
+ num_classes = len(list(index_to_species.keys()))
19
+
20
+ # Load the model
21
+ classify_model = models.DinoVisionTransformerClassifier(num_classes)
22
+ classify_model.load_state_dict(torch.load("best_dinov2_both_2023-11-21_07-44-35.pth"))
23
+ classify_model.eval()
24
 
25
 
26
  def classify(image):
27
+ output = classify_model(image)[0]
28
+ tops = torch.topk(output, k=k).indices
29
+ scores = torch.softmax(output, dim=0)[tops]
30
+
31
+ result = {index_to_species[str(tops[i].item())]: round(scores[i].item(), 2) for i in range(len(tops))}
32
+ sorted_result = {k: v for k, v in sorted(result.items(), key=lambda item: item[1], reverse=True) if v > 0}
33
+
34
+ return sorted_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  title = "🐢"