Vishu26 commited on
Commit
c7a6586
·
1 Parent(s): 31833b2
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -49,6 +49,9 @@ class RangeModel(nn.Module):
49
 
50
  model = RangeModel()
51
  model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
 
 
 
52
 
53
  species = np.load("data/species_70b.npy", allow_pickle=True)
54
  clas = np.load("data/class_70b.npy", allow_pickle=True)
@@ -62,8 +65,6 @@ order_list = list(order[()].keys())
62
  #genus_list = list(genus[()].keys())
63
  #family_list = list(family[()].keys())
64
 
65
- #pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
66
-
67
  def update_fn(val):
68
  if val=="Class":
69
  return gr.Dropdown(label="Name", choices=class_list, interactive=True)
 
49
 
50
  model = RangeModel()
51
  model.load_state_dict(torch.load("model/demo_model.pt", map_location=torch.device('cpu')))
52
+ pos_embed = np.load("data/pos_embed.npy", allow_pickle=True)
53
+ model.x = torch.tensor(pos_embed).unsqueeze(0).float()
54
+ model.eval()
55
 
56
  species = np.load("data/species_70b.npy", allow_pickle=True)
57
  clas = np.load("data/class_70b.npy", allow_pickle=True)
 
65
  #genus_list = list(genus[()].keys())
66
  #family_list = list(family[()].keys())
67
 
 
 
68
  def update_fn(val):
69
  if val=="Class":
70
  return gr.Dropdown(label="Name", choices=class_list, interactive=True)