vshulev commited on
Commit
9fe11bc
·
1 Parent(s): 973829c

Implement tSNE

Browse files
Files changed (2) hide show
  1. app.py +61 -25
  2. config.py +1 -0
app.py CHANGED
@@ -58,6 +58,7 @@ classification_model.eval()
58
 
59
  # Load datasets
60
  ecolayers_ds = load_dataset(DATASETS["ecolayers"])
 
61
 
62
 
63
  def set_default_inputs():
@@ -133,7 +134,6 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
133
  top_k.values.detach().numpy(),
134
  index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
135
  )
136
- # top_k = pd.Series(top_k.values.detach().numpy(), index=top_k.indices.detach().numpy())
137
 
138
  fig, ax = plt.subplots()
139
  ax.bar(top_k.index.astype(str), top_k.values)
@@ -148,6 +148,34 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
148
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  with gr.Blocks() as demo:
152
  # Header section
153
  gr.Markdown("# DNA Identifier Tool")
@@ -169,16 +197,24 @@ with gr.Blocks() as demo:
169
  inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
170
 
171
  with gr.Row():
172
- btn_run = gr.Button("Predict")
173
- btn_run.click(
174
- fn=preprocess,
175
- inputs=[inp_dna, inp_lat, inp_lng],
176
- )
177
-
178
  btn_defaults = gr.Button("I'm feeling lucky")
179
  btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
180
 
181
  with gr.Tab("Genus Prediction"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  gr.Interface(
183
  fn=predict_genus,
184
  inputs=[
@@ -188,26 +224,26 @@ with gr.Blocks() as demo:
188
  inp_lng,
189
  ],
190
  outputs=["image"],
 
191
  )
192
 
193
- # with gr.Row():
 
 
 
 
 
 
 
194
 
195
- # gr.Markdown("Make plot or table for Top 5 species")
196
-
197
- # with gr.Row():
198
- # genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
199
- # # btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)
200
-
201
- with gr.Tab('DNA Embedding Space Visualizer'):
202
- gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
203
-
204
- with gr.Row() as row:
205
- with gr.Column():
206
- gr.Markdown("Plot of your DNA sequence among other known species clusters.")
207
- # plot = gr.Plot("")
208
- # btn_run.click(fn=tsne_DNA, inputs=[inp_dna, genus_out])
209
-
210
- with gr.Column():
211
- gr.Markdown("Plot of the five most common species at your sample coordinate.")
212
 
213
  demo.launch()
 
58
 
59
  # Load datasets
60
  ecolayers_ds = load_dataset(DATASETS["ecolayers"])
61
+ amazon_ds = load_dataset(DATASETS["amazon"])
62
 
63
 
64
  def set_default_inputs():
 
134
  top_k.values.detach().numpy(),
135
  index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
136
  )
 
137
 
138
  fig, ax = plt.subplots()
139
  ax.bar(top_k.index.astype(str), top_k.values)
 
148
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
149
 
150
 
151
+ def cluster_dna(top_k: float):
152
+ df = amazon_ds["train"].to_pandas()
153
+ df = df[df["genus"].notna()]
154
+ top_k = int(top_k)
155
+ genus_counts = df["genus"].value_counts()
156
+ top_genuses = genus_counts.head(top_k).index
157
+ df = df[df["genus"].isin(top_genuses)]
158
+ tsne = TSNE(
159
+ n_components=2, perplexity=30, learning_rate=200,
160
+ n_iter=1000, random_state=0,
161
+ )
162
+ X = np.stack(df["embeddings"].tolist())
163
+ y = df["genus"].tolist()
164
+
165
+ X_tsne = tsne.fit_transform(X)
166
+
167
+ label_encoder = LabelEncoder()
168
+ y_encoded = label_encoder.fit_transform(y)
169
+
170
+ fig, ax = plt.subplots()
171
+ ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7)
172
+ ax.set_title(f"DNA Embedding Space (of {str(top_k)} most common genera)")
173
+ # Reduce unnecessary whitespace
174
+ ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
175
+ fig.canvas.draw()
176
+
177
+ return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
178
+
179
  with gr.Blocks() as demo:
180
  # Header section
181
  gr.Markdown("# DNA Identifier Tool")
 
197
  inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
198
 
199
  with gr.Row():
 
 
 
 
 
 
200
  btn_defaults = gr.Button("I'm feeling lucky")
201
  btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
202
 
203
  with gr.Tab("Genus Prediction"):
204
+ gr.Markdown("""
205
+ # Genus prediction
206
+
207
+ A demo of predicting the genus of a DNA sequence using multiple
208
+ approaches (method dropdown):
209
+
210
+ - **fine_tuned_model**: using our
211
+ `LofiAmazon/BarcodeBERT-Finetuned-Amazon` which predicts the genus
212
+ based on the DNA sequence and environmental data.
213
+ - **cosine**: computes a cosine similarity between the DNA sequence
214
+ embedding generated by our model and the embeddings of known samples
215
+ that we precomputed and stored in a Pinecone index. Thie method
216
+ DOES NOT examine ecological layer data.
217
+ """)
218
  gr.Interface(
219
  fn=predict_genus,
220
  inputs=[
 
224
  inp_lng,
225
  ],
226
  outputs=["image"],
227
+ allow_flagging="never",
228
  )
229
 
230
+ with gr.Tab("DNA Embedding Space Visualizer"):
231
+ gr.Markdown("""
232
+ # DNA Embedding Space Visualizer
233
+
234
+ We show a 2D t-SNE plot of the DNA embeddings of the five most common
235
+ genera in our dataset. This shows that the DNA Transformer model is
236
+ learning to cluster similar DNA sequences together.
237
+ """)
238
 
239
+ gr.Interface(
240
+ fn=cluster_dna,
241
+ inputs=[
242
+ gr.Slider(minimum=1, maximum=10, step=1, value=5,
243
+ label="Number of top genera to visualize")
244
+ ],
245
+ outputs=["image"],
246
+ allow_flagging="never",
247
+ )
 
 
 
 
 
 
 
 
248
 
249
  demo.launch()
config.py CHANGED
@@ -25,4 +25,5 @@ MODELS = {
25
 
26
  DATASETS = {
27
  "ecolayers": "LofiAmazon/Global-Ecolayers",
 
28
  }
 
25
 
26
  DATASETS = {
27
  "ecolayers": "LofiAmazon/Global-Ecolayers",
28
+ "amazon": "LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon",
29
  }