panda1835 commited on
Commit
ca3b812
·
verified ·
1 Parent(s): 6db9d59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -15,27 +15,28 @@ import models
15
  print(f"Is CUDA available: {torch.cuda.is_available()}")
16
  # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
- with open("index_to_species.json", "r") as file:
19
- index_to_species_data = file.read()
20
- index_to_species = json.loads(index_to_species_data)
21
 
22
  num_classes = len(list(index_to_species.keys()))
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  # Load the model
27
- classify_model = models.DinoVisionTransformerClassifier(num_classes)
28
- classify_model = classify_model.to(device)
29
- classify_model.load_state_dict(torch.load("best_dinov2_both_2023-11-21_07-44-35.pth", map_location=torch.device(device)))
30
- classify_model.eval()
31
 
32
  k = 5
33
 
34
  def classify(image):
35
- output = classify_model(image)[0]
 
 
36
  tops = torch.topk(output, k=k).indices
37
  scores = torch.softmax(output, dim=0)[tops]
38
-
39
  result = {index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2) for i in range(len(tops))}
40
  sorted_result = {k: v for k, v in sorted(result.items(), key=lambda item: item[1], reverse=True) if v > 0}
41
  # Get the current time
@@ -56,4 +57,5 @@ gr.Interface(
56
  inputs=gr.Image(type="pil", label="Input Image"),
57
  outputs=[gr.JSON()],
58
  title=title,
 
59
  ).launch()
 
15
  print(f"Is CUDA available: {torch.cuda.is_available()}")
16
  # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
 
18
+ # with open("index_to_species.json", "r") as file:
19
+ # index_to_species_data = file.read()
20
+ # index_to_species = json.loads(index_to_species_data)
21
 
22
  num_classes = len(list(index_to_species.keys()))
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
  # Load the model
27
+ linear_model_name = 'linear_2025-07-08.pt'
28
+ classify_model = models.LinearClassifier(input_dim=768, output_dim=num_classes)
29
+ classify_model.load_state_dict(torch.load(os.path.join('models', linear_model_name)))
 
30
 
31
  k = 5
32
 
33
  def classify(image):
34
+ embedding = extract_embedding(image)
35
+ embedding = embedding['embedding']
36
+ output = classify_model(torch.Tensor(embedding).to(device))
37
  tops = torch.topk(output, k=k).indices
38
  scores = torch.softmax(output, dim=0)[tops]
39
+
40
  result = {index_to_species[str(tops[i].item())].replace("_", " "): round(scores[i].item(), 2) for i in range(len(tops))}
41
  sorted_result = {k: v for k, v in sorted(result.items(), key=lambda item: item[1], reverse=True) if v > 0}
42
  # Get the current time
 
57
  inputs=gr.Image(type="pil", label="Input Image"),
58
  outputs=[gr.JSON()],
59
  title=title,
60
+ debug=True
61
  ).launch()