# ChatNT_demo / app.py
# Hugging Face Space page residue kept for provenance:
#   last commit "Update app.py" (6cb0211, verified) by Yanisadel; file size 3.23 kB.
# --- Imports ---
import spaces
import gradio as gr
from transformers import pipeline
import pandas as pd
import os
# --- Load Model ---
# The pipeline is created once at import time and shared by every request.
# NOTE(review): trust_remote_code=True permits code shipped with the model
# repository to run locally — confirm this is intended for the deployment.
pipe = pipeline(
    model="InstaDeepAI/ChatNT",
    trust_remote_code=True,
)

# --- Logs ---
# Path of the plain-text file that request start/finish markers are appended to.
log_file = "logs.txt"
class Log:
    """Callable that returns the current contents of a log file.

    An instance can be handed to anything that expects a zero-argument
    function producing text; each call re-reads the file from disk.
    """

    def __init__(self, log_file):
        """Store the path of the log file to read.

        Args:
            log_file: Path of the text file whose contents ``__call__`` returns.
        """
        self.log_file = log_file

    def __call__(self):
        """Return the whole log file as a string.

        Returns:
            The file contents, or ``""`` when the file does not exist.
        """
        # Open directly instead of checking for existence first: the file may
        # disappear between a check and the open, which would raise instead of
        # yielding the documented empty string.
        try:
            with open(self.log_file, "r") as f:
                return f.read()
        except FileNotFoundError:
            return ""
# --- Main Function ---
def _read_dna_sequences(path):
    """Read DNA sequences from a FASTA or plain-text file.

    Lines starting with ``>`` are treated as FASTA headers. Sequence lines
    that follow a header are concatenated into one sequence, so line-wrapped
    FASTA records are not split apart. Lines that appear before any header
    (including every line of a header-less text file) are each taken as a
    separate sequence. Blank lines are ignored.

    Args:
        path: Path of the file to read.

    Returns:
        A list of non-empty sequence strings in file order.
    """
    sequences = []
    record_parts = None  # None until the first FASTA header has been seen
    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines would otherwise become empty sequences.
                continue
            if line.startswith(">"):
                # A new header closes the record collected so far.
                if record_parts:
                    sequences.append("".join(record_parts))
                record_parts = []
            elif record_parts is None:
                sequences.append(line)
            else:
                record_parts.append(line)
    if record_parts:
        sequences.append("".join(record_parts))
    return sequences


@spaces.GPU
def run_chatnt(input_file, custom_question):
    """Run the ChatNT pipeline on an uploaded sequence file and a question.

    Appends "Request started" to the log file on entry and "Request finished"
    once results have been written; the early returns for missing input do
    not write the "finished" marker.

    Args:
        input_file: The uploaded file, given either as a path string or as an
            object exposing the path through a ``name`` attribute, or ``None``
            when nothing was uploaded.
        custom_question: The English question to ask about the sequences.

    Returns:
        A ``(DataFrame, path)`` tuple. The DataFrame has one ``"Result"`` row
        per pipeline output item and ``path`` is the CSV file it was saved to.
        When the question is blank or no sequence could be read, an empty
        DataFrame and ``None`` are returned.
    """
    with open(log_file, "a") as log:
        log.write("Request started\n")

    if not custom_question or not custom_question.strip():
        return pd.DataFrame(), None

    # Read DNA sequences
    dna_sequences = []
    if input_file is not None:
        # Accept both a bare path string and a file-like wrapper with .name.
        path = getattr(input_file, "name", input_file)
        dna_sequences = _read_dna_sequences(path)
    if not dna_sequences:
        return pd.DataFrame(), None

    # Build prompt
    english_sequence = custom_question + " <DNA>"

    # Call model
    output = pipe(
        inputs={
            "english_sequence": english_sequence,
            "dna_sequences": dna_sequences,
        }
    )

    # Wrap output: a list yields one row per item, anything else a single row.
    items = output if isinstance(output, list) else [output]
    df = pd.DataFrame([{"Result": item} for item in items])

    output_file = "output.csv"
    df.to_csv(output_file, index=False)

    with open(log_file, "a") as log:
        log.write("Request finished\n")
    return df, output_file
# --- Gradio Interface ---
# Page-level CSS: sans-serif font, black buttons, hidden footer.
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""

# Usage hint rendered below the input/output row.
_usage_note = """
**Note:** Your question **must** include the `<DNA>` token if needed for multiple sequences.
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧬 ChatNT — DNA Sequence Query Assistant")

    with gr.Row():
        # Narrow column: the upload, the question and the trigger button.
        with gr.Column(scale=1):
            input_file = gr.File(
                file_types=[".fasta", ".fa", ".txt"],
                label="Upload DNA Sequence File (.fasta or .txt)",
            )
            custom_question = gr.Textbox(
                placeholder="e.g., Does this sequence contain a donor splice site?",
                label="English Question (required)",
            )
            submit_btn = gr.Button("Run Query", variant="primary")

        # Wide column: the result table and the downloadable CSV.
        with gr.Column(scale=2):
            output_df = gr.DataFrame(
                headers=["Result"],
                label="Results",
            )
            output_file = gr.File(label="Download Results (CSV)")

    # Clicking the button feeds both inputs to run_chatnt and routes its
    # (DataFrame, file path) result to the two output components.
    submit_btn.click(
        run_chatnt,
        outputs=[output_df, output_file],
        inputs=[input_file, custom_question],
    )

    gr.Markdown(_usage_note)

    with gr.Accordion("Logs", open=True):
        # A callable is passed as the Markdown value so the log text is read
        # from disk by Gradio rather than fixed at build time
        # (presumably evaluated when the page loads — confirm against Gradio docs).
        log_display = Log(log_file)
        gr.Markdown(log_display)
# --- Launch ---
# Only start the server when this file is executed directly.
if __name__ == "__main__":
    # queue() hands back the same Blocks object, so the launch call is chained.
    demo.queue().launch(show_error=True, debug=True)