vedantM committed
Commit f6c0846 · verified · 1 Parent(s): e039983

added app file

Files changed (1)
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ import pandas as pd
+ import plotly.express as px
+
+ # Sequence splitting function
+ def split_sequence(sequence, max_len=1024, overlap=512):
+     chunks = []
+     for i in range(0, len(sequence), max_len - overlap):
+         chunk = sequence[i:i + max_len]
+         if len(chunk) > 0:
+             chunks.append(chunk)
+     return chunks
+
+ # Load model and tokenizer
+ @st.cache_resource
+ def load_model_and_tokenizer(model_name):
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_name, ignore_mismatched_sizes=True, trust_remote_code=True
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     return model, tokenizer
+
+ def predict_chunk(model, tokenizer, chunk):
+     tokens = tokenizer(chunk, return_tensors="pt", truncation=True, padding=True)
+     with torch.no_grad():
+         outputs = model(**tokens)
+     return outputs.logits
+
+ def nucArg_app():
+     # Class mappings
+     long_read_classes = {
+         0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
+         4: 'fosfomycin', 5: 'fosmidomycin', 6: 'fusidic_acid', 7: 'glycopeptide',
+         8: 'kasugamycin', 9: 'macrolide-lincosamide-streptogramin', 10: 'multidrug',
+         11: 'mupirocin', 12: 'non_resistant', 13: 'peptide', 14: 'polymyxin',
+         15: 'qa_compound', 16: 'quinolone', 17: 'rifampin', 18: 'sulfonamide',
+         19: 'tetracenomycin', 20: 'tetracycline', 21: 'trimethoprim', 22: 'tunicamycin'
+     }
+     short_read_classes = {
+         0: 'aminoglycoside', 1: 'bacitracin', 2: 'beta_lactam', 3: 'chloramphenicol',
+         4: 'fosfomycin', 5: 'fosmidomycin', 6: 'glycopeptide', 7: 'macrolide-lincosamide-streptogramin',
+         8: 'multidrug', 9: 'mupirocin', 10: 'polymyxin', 11: 'quinolone',
+         12: 'sulfonamide', 13: 'tetracycline', 14: 'trimethoprim'
+     }
+
+     # Streamlit UI
+     st.title("Antibiotic Resistance Predictor")
+     # st.write("This app predicts antibiotic resistance based on DNA sequences.")
+
+     # Input sequence
+     sequence = st.text_area("Enter a DNA sequence:", height=200)
+
+     # Initialize models
+     model_long, tokenizer_long = load_model_and_tokenizer("vedantM/NucArg_LongRead")
+     model_short, tokenizer_short = load_model_and_tokenizer("vedantM/NucArg_ShortRead")
+
+     if sequence:
+         if len(sequence) <= 128:
+             chunks = [sequence]  # No splitting needed
+             model, tokenizer, class_mapping = model_short, tokenizer_short, short_read_classes
+         else:
+             st.write("Input sequence is too large. Splitting into smaller chunks for processing.")
+             chunks = split_sequence(sequence)
+             model, tokenizer, class_mapping = model_long, tokenizer_long, long_read_classes
+
+         # Predict for all chunks and aggregate logits
+         all_logits = []
+         with st.spinner("Predicting..."):
+             for chunk in chunks:
+                 try:
+                     logits = predict_chunk(model, tokenizer, chunk)
+                     all_logits.append(logits)
+                 except Exception as e:
+                     st.error(f"Error processing chunk: {e}")
+                     return
+
+         # Aggregate logits
+         aggregated_logits = torch.mean(torch.stack(all_logits), dim=0)
+         probabilities = torch.softmax(aggregated_logits, dim=-1).tolist()
+         predicted_class = torch.argmax(aggregated_logits).item()
+
+         # Display results
+         # st.success("Prediction complete!")
+         st.write("### Prediction complete!")
+         st.success(f"Predicted Class: **{class_mapping[predicted_class]}**")
+         st.write("### Class Probabilities")
+         type_probabilities = []
+         for idx, prob in enumerate(probabilities[0]):
+             # Append to the new dataset list
+             type_probabilities.append({
+                 'Type': str(class_mapping[idx]),
+                 'Probability': float(prob)
+             })
+
+         type_probabilities = pd.DataFrame(type_probabilities).sort_values(by='Probability')  # ,ascending=False)
+         # type_probabilities = type_probabilities.set_index('Type')
+         tp = type_probabilities.convert_dtypes()
+
+         # st.bar_chart(data=tp, horizontal=True, x='Probability', y='Type')
+         # df=px.data.tips()
+         fig = px.bar(tp, x='Probability', y='Type', orientation='h')
+         st.write(fig)
+
+
+ if __name__ == "__main__":
+     nucArg_app()
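
A quick, self-contained sanity check of the chunking logic above (not part of the commit): with the default max_len=1024 and overlap=512, split_sequence steps through the input in strides of 512 characters, so a toy 2,000-character sequence should produce four overlapping chunks starting at offsets 0, 512, 1024 and 1536. The snippet below reproduces only that function to illustrate the expected boundaries.

```python
# Standalone sketch of the chunking behaviour used in app.py (same defaults assumed).
def split_sequence(sequence, max_len=1024, overlap=512):
    chunks = []
    for i in range(0, len(sequence), max_len - overlap):
        chunk = sequence[i:i + max_len]
        if len(chunk) > 0:
            chunks.append(chunk)
    return chunks

seq = "A" * 2000                      # toy 2,000-character input
chunks = split_sequence(seq)
print([len(c) for c in chunks])       # [1024, 1024, 976, 464] -> chunks start at 0, 512, 1024, 1536
```

To try the app locally, the standard Streamlit entry point should work, assuming streamlit, transformers, torch, pandas and plotly are installed: `streamlit run app.py`. The vedantM/NucArg_LongRead and vedantM/NucArg_ShortRead checkpoints are downloaded from the Hub and cached on first run.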