# LofiAmazonSpace / app.py
# jennzhuge
# hi
# 2c039db
# raw
# history blame
# 3.88 kB
import json
import pandas as pd
import gradio as gr
# from transformers import PreTrainedTokenizerFast, BertForMaskedLM
from datasets import load_dataset
import infer
# Load the example inputs once at import time; set_default_inputs() reads
# them from DEFAULT_INPUTS.
# The encoding is pinned to UTF-8 so the file is decoded the same way on
# every platform instead of depending on the locale's default encoding.
with open("default_inputs.json", "r", encoding="utf-8") as default_inputs_file:
    DEFAULT_INPUTS = json.load(default_inputs_file)
def set_default_inputs():
    """Return the default DNA sequence, latitude and longitude, in that order.

    The three values are read from ``DEFAULT_INPUTS`` under the keys
    "dna_sequence", "latitude" and "longitude" and returned as a tuple.
    """
    field_names = ("dna_sequence", "latitude", "longitude")
    return tuple(DEFAULT_INPUTS[field] for field in field_names)
def format_dna_sequence(dna_sequence, kmer_size=4):
    """Normalise a raw nucleotide string into space-separated k-mers.

    Every character other than ``A``, ``C``, ``G`` or ``T`` is replaced by
    ``N``, trailing ``N`` characters are removed, and what remains is cut
    into consecutive chunks of ``kmer_size`` characters joined by single
    spaces (the last chunk may be shorter).

    Args:
        dna_sequence: Raw sequence text.
        kmer_size: Chunk length; 4 is the value the inline code used before
            this helper was extracted.

    Returns:
        The formatted string; empty when nothing but ``N`` is left.
    """
    masked = "".join(base if base in "ACGT" else "N" for base in dna_sequence)
    trimmed = masked.rstrip("N")
    return " ".join(
        trimmed[start:start + kmer_size]
        for start in range(0, len(trimmed), kmer_size)
    )


def preprocess(inp_dna, inp_lat, inp_lng):
    """Prepare app input for the genus prediction model.

    The inputs are now explicit parameters. Previously the function took
    none, and because it assigned to ``inp_dna`` before reading it, every
    call raised UnboundLocalError. The DNA text is also handled as a plain
    string rather than through the pandas ``.str`` accessor.

    Args:
        inp_dna: DNA sequence text.
        inp_lat: Latitude of the sample, passed through unchanged.
        inp_lng: Longitude of the sample, passed through unchanged.

    Returns:
        The result of left-merging a one-row frame (columns "coord" and
        "embeddings") with the eco-layer data on "coord".
    """
    # Preprocess the DNA sequence into space-separated 4-mers.
    kmer_text = format_dna_sequence(inp_dna)

    # Load the model used to calculate the embedding.
    # NOTE(review): `model`, `PreTrainedTokenizerFast` and `BertForMaskedLM`
    # are not defined in this file (the transformers import at the top is
    # commented out), and `predic` does not look like a real model method.
    # `tokenizer` is also never applied to the sequence. The checkpoint name
    # and the intended embedding call must be confirmed before this can run.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model, force_download=True)
    tokenizer.add_special_tokens({"pad_token": "<UNK>"})
    bert_model = BertForMaskedLM.from_pretrained(model, force_download=True)
    embed = bert_model.predic(kmer_text)

    # Format lat and lon into coords.
    coords = (inp_lat, inp_lng)

    # Grab the eco-layer data.
    # NOTE(review): load_dataset() without a split presumably returns a
    # DatasetDict, which pd.merge() cannot consume directly - confirm the
    # split to use and that it exposes a matching "coord" column.
    eco_layers = load_dataset("LofiAmazon/Global-Ecolayers")

    # One row holding the coordinate pair and the embedding. The previous
    # list-of-rows form treated `coords` and `embed` as two separate rows.
    sample = pd.DataFrame({"coord": [coords], "embeddings": [embed]})
    return pd.merge(sample, eco_layers, on='coord', how='left')


def predict_genus(inp_dna, inp_lat, inp_lng):
    """Run preprocessing and inference for one DNA sample.

    The three parameters mirror the three input components this function is
    registered with on the "Run" button; Gradio passes one positional
    argument per input component, so the old zero-argument signature could
    not be called by the UI.

    Args:
        inp_dna: DNA sequence text.
        inp_lat: Latitude of the sample.
        inp_lng: Longitude of the sample.

    Returns:
        A one-element list holding a dict with the keys "sequence" and
        "predictions".
    """
    data = preprocess(inp_dna, inp_lat, inp_lng)

    # NOTE(review): the result of infer_dna() was never used and infer() is
    # called without arguments; both calls are kept in their original order
    # because the `infer` module is not visible here. An earlier draft
    # concatenated DNA-only and DNA+environment predictions - confirm the
    # intended output against the four-column results table.
    infer.infer_dna(data)
    genuses = infer.infer()

    # The original read an undefined `dna_df['nucraw']`; the submitted
    # sequence is assumed to be what was meant - TODO confirm.
    return [{"sequence": inp_dna, 'predictions': genuses}]
def tsne():
    """Placeholder for the embedding-space plots; not implemented yet.

    The previous body returned an undefined name, so any call failed with
    an accidental NameError. The failure is now explicit.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("tsne() is not implemented yet")
# Build the two-tab Gradio interface and start the app.
# NOTE(review): the nesting below is reconstructed from the order of the
# statements; confirm the layout matches what was intended.
with gr.Blocks() as demo:
    # Header section
    gr.Markdown("# DNA Identifier Tool")
    gr.Markdown("Welcome to Lofi Amazon Beats' DNA Identifier Tool")

    with gr.Tab("Genus Prediction"):
        gr.Markdown("Enter a DNA sequence and the coordinates at which its sample was taken to get a genus prediction. Click 'I'm feeling lucky' to see a prediction for a random sequence.")

        # Collect inputs for app (DNA and location).
        # `gr.Row` is the layout class; the lowercase `gr.row` used before
        # is not a Gradio attribute and failed while building the UI.
        with gr.Row():
            with gr.Column():
                inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
            with gr.Column():
                with gr.Row():
                    inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
                with gr.Row():
                    inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")

        with gr.Row():
            btn_run = gr.Button("Run")
            btn_defaults = gr.Button("I'm feeling lucky")
            # Fill the three input boxes with the stored example values.
            btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])

        with gr.Row():
            gr.Markdown('Make plot or table for Top 5 species')

        with gr.Row():
            genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
            # The three textbox values are passed to predict_genus as
            # positional arguments; its return value fills the table.
            btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)

    with gr.Tab('DNA Embedding Space Similarity Visualizer'):
        gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")

        with gr.Row():
            with gr.Column():
                gr.Markdown("Plot of your DNA sequence among other known species clusters.")
            with gr.Column():
                gr.Markdown("Plot of the five most common species at your sample coordinate.")

demo.launch()