File size: 4,428 Bytes
f6c0846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74b55b3
f6c0846
 
 
 
 
 
 
 
 
 
 
74b55b3
f6c0846
 
 
74b55b3
f6c0846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74b55b3
f6c0846
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pandas as pd
import plotly.express as px 

# Sequence splitting function
def split_sequence(sequence, max_len=1024, overlap=512):
    chunks = []
    for i in range(0, len(sequence), max_len - overlap):
        chunk = sequence[i:i + max_len]
        if len(chunk) > 0:
            chunks.append(chunk)
    return chunks

# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer(model_name):
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, ignore_mismatched_sizes=True, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

def predict_chunk(model, tokenizer, chunk):
    tokens = tokenizer(chunk, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**tokens)
        return outputs.logits

def nucArg_app():
    # Class mappings
    long_read_classes = {
        0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
        4: 'fosfomycin', 5: 'fosmidomycin', 6: 'fusidic_acid', 7: 'glycopeptide',
        8: 'kasugamycin', 9: 'macrolide-lincosamide-streptogramin', 10: 'multidrug',
        11: 'mupirocin', 12: 'non_resistant', 13: 'peptide', 14: 'polymyxin',
        15: 'qa_compound', 16: 'quinolone', 17: 'rifampin', 18: 'sulfonamide',
        19: 'tetracenomycin', 20: 'tetracycline', 21: 'trimethoprim', 22: 'tunicamycin'
    }
    short_read_classes = {
        0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
        4: 'fosfomycin', 5: 'fosmidomycin', 6: 'glycopeptide', 7: 'macrolide-lincosamide-streptogramin',
        8: 'multidrug', 9: 'mupirocin', 10: 'polymyxin', 11: 'quinolone',
        12: 'sulfonamide', 13: 'tetracycline', 14: 'trimethoprim'
    }

    # Streamlit UI
    st.title("Detecting Antimicrobial Resistance Genes")
    # st.write("This app predicts antibiotic resistance based on DNA sequences.")

    # Input sequence
    sequence = st.text_area("Enter a DNA sequence:", height=200)

    # Initialize models
    model_long, tokenizer_long = load_model_and_tokenizer("vedantM/NucArg_LongRead")
    model_short, tokenizer_short = load_model_and_tokenizer("vedantM/NucArg_ShortRead")

    if sequence:
        if len(sequence) <= 128:
            st.write("Using Short Reads Model.")
            chunks = [sequence]  # No splitting needed
            model, tokenizer, class_mapping = model_short, tokenizer_short, short_read_classes
        else:
            st.write("Using Long Reads Model.")
            chunks = split_sequence(sequence)
            model, tokenizer, class_mapping = model_long, tokenizer_long, long_read_classes

        # Predict for all chunks and aggregate logits
        all_logits = []
        with st.spinner("Predicting..."):
            for chunk in chunks:
                try:
                    logits = predict_chunk(model, tokenizer, chunk)
                    all_logits.append(logits)
                except Exception as e:
                    st.error(f"Error processing chunk: {e}")
                    return

        # Aggregate logits
        aggregated_logits = torch.mean(torch.stack(all_logits), dim=0)
        probabilities = torch.softmax(aggregated_logits, dim=-1).tolist()
        predicted_class = torch.argmax(aggregated_logits).item()

        # Display results
        # st.success("Prediction complete!")
        st.write("### Prediction complete!")
        st.success(f"Predicted Class: **{class_mapping[predicted_class]}**")
        st.write("### Class Probabilities")
        type_probabilities = []
        for idx, prob in enumerate(probabilities[0]):
            # Append to the new dataset list
            type_probabilities.append({
                'Type': str(class_mapping[idx]),
                'Probability': float(prob)
            })
        
        type_probabilities = pd.DataFrame(type_probabilities).sort_values(by='Probability')
        # type_probabilities = type_probabilities.set_index('Type')
        tp = type_probabilities.convert_dtypes()

        # st.bar_chart(data=tp, horizontal=True, x='Probability', y='Type')
        # df=px.data.tips()
        fig=px.bar(tp,x='Probability',y='Type', orientation='h')
        st.write(fig)


if __name__ == "__main__":
    nucArg_app()