panda1835 commited on
Commit
80bb40d
·
1 Parent(s): c91f3be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -11,15 +11,22 @@ import gradio as gr
11
 
12
  import models
13
 
 
 
 
 
 
14
  with open("index_to_species.json", "r") as file:
15
  index_to_species_data = file.read()
16
  index_to_species = json.loads(index_to_species_data)
17
 
18
  num_classes = len(list(index_to_species.keys()))
19
 
 
 
20
  # Load the model
21
  classify_model = models.DinoVisionTransformerClassifier(num_classes)
22
- classify_model.load_state_dict(torch.load("best_dinov2_both_2023-11-21_07-44-35.pth", map_location=torch.device('cpu')))
23
  classify_model.eval()
24
 
25
  k = 5
 
11
 
12
  import models
13
 
14
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
15
+ # True
16
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
17
+ # Tesla T4
18
+
19
  with open("index_to_species.json", "r") as file:
20
  index_to_species_data = file.read()
21
  index_to_species = json.loads(index_to_species_data)
22
 
23
  num_classes = len(list(index_to_species.keys()))
24
 
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
  # Load the model
28
  classify_model = models.DinoVisionTransformerClassifier(num_classes)
29
+ classify_model.load_state_dict(torch.load("best_dinov2_both_2023-11-21_07-44-35.pth", map_location=torch.device(device)))
30
  classify_model.eval()
31
 
32
  k = 5