Vishu26 commited on
Commit
435a5b8
·
1 Parent(s): 3c43ea2
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn as nn
5
  from einops import rearrange
6
  import matplotlib.pyplot as plt
7
 
 
8
 
9
  class Attn(nn.Module):
10
  def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
@@ -79,6 +80,7 @@ def update_fn(val):
79
  return gr.Dropdown(label="Name", choices=species_list, interactive=True)
80
 
81
  def text_fn(taxon, name):
 
82
  if taxon=="Class":
83
  text_embeds = clas[()][name]
84
  elif taxon=="Order":
@@ -92,6 +94,7 @@ def text_fn(taxon, name):
92
 
93
  text_embeds = torch.tensor(text_embeds)
94
  preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
 
95
  cmap = plt.get_cmap('Greens')
96
 
97
  rgba_img = cmap(preds)
@@ -99,22 +102,14 @@ def text_fn(taxon, name):
99
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
100
  return rgb_img
101
 
102
- def pred_fn(taxon, name):
103
- if taxon=="Class":
104
- text_embeds = clas[()][name]
105
- elif taxon=="Order":
106
- text_embeds = order[()][name]
107
- elif taxon=="Family":
108
- text_embeds = family[()][name]
109
- elif taxon=="Genus":
110
- text_embeds = genus[()][name]
111
- elif taxon=="Species":
112
- text_embeds = species[()][name]
113
-
114
- text_embeds = torch.tensor(text_embeds)
115
- preds = model(text_embeds).sigmoid().unsqueeze(0).unsqueeze(0).detach().numpy()
116
- return gr.Image(preds, label="Predicted Heatmap", visible=True)
117
 
 
 
 
118
 
119
  with gr.Blocks() as demo:
120
  gr.Markdown(
@@ -137,5 +132,6 @@ with gr.Blocks() as demo:
137
  pred = gr.Image(label="Predicted Heatmap", visible=True)
138
 
139
  check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
 
140
 
141
  demo.launch()
 
5
  from einops import rearrange
6
  import matplotlib.pyplot as plt
7
 
8
+ pred_global = None
9
 
10
  class Attn(nn.Module):
11
  def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
 
80
  return gr.Dropdown(label="Name", choices=species_list, interactive=True)
81
 
82
  def text_fn(taxon, name):
83
+ global pred_global
84
  if taxon=="Class":
85
  text_embeds = clas[()][name]
86
  elif taxon=="Order":
 
94
 
95
  text_embeds = torch.tensor(text_embeds)
96
  preds = model(text_embeds).sigmoid().squeeze(0).squeeze(0).detach().numpy()
97
+ pred_global = preds
98
  cmap = plt.get_cmap('Greens')
99
 
100
  rgba_img = cmap(preds)
 
102
  #return gr.Image(preds, label="Predicted Heatmap", visible=True)
103
  return rgb_img
104
 
105
+ def thresh_fn(val):
106
+ global pred_global
107
+ pred_global = pred_global > val
108
+ cmap = plt.get_cmap('Greens')
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ rgba_img = cmap(pred_global)
111
+ rgb_img = np.delete(rgba_img, 3, 2)
112
+ return rgb_img
113
 
114
  with gr.Blocks() as demo:
115
  gr.Markdown(
 
132
  pred = gr.Image(label="Predicted Heatmap", visible=True)
133
 
134
  check_button.click(text_fn, inputs=[inp, out], outputs=[pred])
135
+ slider.change(thresh_fn, slider, pred)
136
 
137
  demo.launch()