import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd
import plotly.express as px


def split_sequence(sequence, max_len=1024, overlap=512):
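    """Split a sequence into fixed-size overlapping chunks.

    Steps through the sequence in strides of `max_len - overlap`, so
    consecutive chunks share `overlap` characters and a gene spanning a
    chunk boundary is still seen intact in at least one chunk.
    """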
    chunks = []
    for i in range(0, len(sequence), max_len - overlap):
        chunk = sequence[i:i + max_len]
        if chunk:
            chunks.append(chunk)
    return chunks


@st.cache_resource
def load_model_and_tokenizer(model_name):
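    """Load a model/tokenizer pair once and cache it across Streamlit reruns.

    `st.cache_resource` keeps the loaded weights in memory, so the Hub
    download and deserialization only happen on the first call.
    """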
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, ignore_mismatched_sizes=True, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def predict_chunk(model, tokenizer, chunk):
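    """Tokenize one chunk and return the model's raw logits.

    Inference runs under `torch.no_grad()`; the tokenizer truncates any
    chunk longer than the model's maximum input length.
    """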
    tokens = tokenizer(chunk, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**tokens)
    return outputs.logits


def nucArg_app():
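    """Streamlit entry point: classify a DNA sequence into ARG classes."""

    # Index-to-label mappings for each model head: the long-read model
    # predicts 23 classes, the short-read model a 15-class subset.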
    long_read_classes = {
        0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
        4: 'fosfomycin', 5: 'fosmidomycin', 6: 'fusidic_acid', 7: 'glycopeptide',
        8: 'kasugamycin', 9: 'macrolide-lincosamide-streptogramin', 10: 'multidrug',
        11: 'mupirocin', 12: 'non_resistant', 13: 'peptide', 14: 'polymyxin',
        15: 'qa_compound', 16: 'quinolone', 17: 'rifampin', 18: 'sulfonamide',
        19: 'tetracenomycin', 20: 'tetracycline', 21: 'trimethoprim', 22: 'tunicamycin'
    }
    short_read_classes = {
        0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
        4: 'fosfomycin', 5: 'fosmidomycin', 6: 'glycopeptide', 7: 'macrolide-lincosamide-streptogramin',
        8: 'multidrug', 9: 'mupirocin', 10: 'polymyxin', 11: 'quinolone',
        12: 'sulfonamide', 13: 'tetracycline', 14: 'trimethoprim'
    }

    st.title("Detecting Antimicrobial Resistance Genes")

    sequence = st.text_area("Enter a DNA sequence:", height=200)
    # Drop whitespace and newlines (e.g. FASTA line wraps) so the
    # length-based model choice and the tokenizer see only the bases.
    sequence = "".join(sequence.split())
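
    # Both model variants are loaded up front; @st.cache_resource makes
    # this a one-time cost rather than a per-rerun reload.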
    model_long, tokenizer_long = load_model_and_tokenizer("vedantM/NucArg_LongRead")
    model_short, tokenizer_short = load_model_and_tokenizer("vedantM/NucArg_ShortRead")
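
    # Route by input length: sequences of up to 128 characters go to the
    # short-read model as a single chunk; longer inputs are split into
    # overlapping 1024-character windows for the long-read model.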
    if sequence:
        if len(sequence) <= 128:
            st.write("Using Short Reads Model.")
            chunks = [sequence]
            model, tokenizer, class_mapping = model_short, tokenizer_short, short_read_classes
        else:
            st.write("Using Long Reads Model.")
            chunks = split_sequence(sequence)
            model, tokenizer, class_mapping = model_long, tokenizer_long, long_read_classes

        all_logits = []
        with st.spinner("Predicting..."):
            for chunk in chunks:
                try:
                    logits = predict_chunk(model, tokenizer, chunk)
                    all_logits.append(logits)
                except Exception as e:
                    st.error(f"Error processing chunk: {e}")
                    return
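
        # Average the per-chunk logits, then softmax the mean to get a
        # single sequence-level class distribution.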
        aggregated_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.softmax(aggregated_logits, dim=-1).tolist()
        predicted_class = torch.argmax(aggregated_logits).item()

        st.write("### Prediction complete!")
        st.success(f"Predicted Class: **{class_mapping[predicted_class]}**")
        st.write("### Class Probabilities")
        type_probabilities = []
        for idx, prob in enumerate(probabilities[0]):
            type_probabilities.append({
                'Type': str(class_mapping[idx]),
                'Probability': float(prob)
            })
        type_probabilities = pd.DataFrame(type_probabilities).sort_values(by='Probability')
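
        # Ascending sort puts the most probable class at the top of the
        # horizontal bar chart (Plotly places categories bottom-up).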
        tp = type_probabilities.convert_dtypes()
        fig = px.bar(tp, x='Probability', y='Type', orientation='h')
        st.plotly_chart(fig)


if __name__ == "__main__":
    nucArg_app()
|