hiyata committed on
Commit
6a3b036
·
verified ·
1 Parent(s): 0eb9745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -32
app.py CHANGED
@@ -7,6 +7,7 @@ import torch.nn as nn
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
 
10
 
11
  class VirusClassifier(nn.Module):
12
  def __init__(self, input_shape: int):
@@ -44,20 +45,15 @@ class VirusClassifier(nn.Module):
44
 
45
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
46
  """Convert sequence to k-mer frequency vector"""
47
- # Generate all possible k-mers
48
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
49
  kmer_dict = {km: i for i, km in enumerate(kmers)}
50
-
51
- # Initialize vector
52
  vec = np.zeros(len(kmers), dtype=np.float32)
53
 
54
- # Count k-mers
55
  for i in range(len(sequence) - k + 1):
56
  kmer = sequence[i:i+k]
57
  if kmer in kmer_dict:
58
  vec[kmer_dict[kmer]] += 1
59
 
60
- # Convert to frequencies
61
  total_kmers = len(sequence) - k + 1
62
  if total_kmers > 0:
63
  vec = vec / total_kmers
@@ -88,7 +84,6 @@ def predict(file_obj):
88
  if file_obj is None:
89
  return "Please upload a FASTA file", None
90
 
91
- # Read the file content
92
  try:
93
  if isinstance(file_obj, str):
94
  text = file_obj
@@ -97,43 +92,31 @@ def predict(file_obj):
97
  except Exception as e:
98
  return f"Error reading file: {str(e)}", None
99
 
100
- # Generate k-mer dictionary
101
- k = 4 # k-mer size
102
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
103
  kmer_dict = {km: i for i, km in enumerate(kmers)}
104
 
105
- # Load model and scaler
106
  try:
107
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
108
- model = VirusClassifier(256).to(device) # k=4 -> 4^4 = 256 features
109
-
110
- # Load model with explicit map_location
111
  state_dict = torch.load('model.pt', map_location=device)
112
  model.load_state_dict(state_dict)
113
-
114
- # Load scaler
115
  scaler = joblib.load('scaler.pkl')
116
-
117
- # Set model to evaluation mode
118
  model.eval()
119
  except Exception as e:
120
  return f"Error loading model: {str(e)}", None
121
 
122
- # Initialize variables to store results and plot
123
  results_text = ""
124
  plot_image = None
125
 
126
  try:
127
  sequences = parse_fasta(text)
128
- # For simplicity, process only the first sequence for plotting
129
  header, seq = sequences[0]
130
 
131
- # Get raw frequency vector and scaled vector
132
  raw_freq_vector = sequence_to_kmer_vector(seq)
133
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
134
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
135
 
136
- # Get predictions and feature importance
137
  with torch.no_grad():
138
  output = model(X_tensor)
139
  probs = torch.softmax(output, dim=1)
@@ -141,11 +124,9 @@ def predict(file_obj):
141
  importance = model.get_feature_importance(X_tensor)
142
  kmer_importance = importance[0].cpu().numpy()
143
 
144
- # Normalize importance scores to original scale
145
  if np.max(np.abs(kmer_importance)) != 0:
146
  kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
147
 
148
- # Get top 10 k-mers based on absolute importance
149
  top_k = 10
150
  top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
151
  important_kmers = [
@@ -158,11 +139,9 @@ def predict(file_obj):
158
  for i in top_indices
159
  ]
160
 
161
- # Prepare SHAP-like values for waterfall plot
162
  top_features = [item['kmer'] for item in important_kmers]
163
  top_values = [item['importance'] for item in important_kmers]
164
 
165
- # Combine the rest of the features into an "Others" category
166
  others_mask = np.ones_like(kmer_importance, dtype=bool)
167
  others_mask[top_indices] = False
168
  others_sum = np.sum(kmer_importance[others_mask])
@@ -176,19 +155,15 @@ def predict(file_obj):
176
  data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
177
  feature_names=top_features
178
  )
179
- # Manually set expected_value to satisfy waterfall_legacy requirements
180
  explanation.expected_value = 0
181
 
182
- # Generate waterfall plot using SHAP's legacy function
183
  fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
184
 
185
- # Save plot to a bytes buffer
186
  buf = io.BytesIO()
187
  fig.savefig(buf, format='png')
188
  buf.seek(0)
189
- plot_image = buf
190
 
191
- # Format textual results for the first sequence
192
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
193
  pred_label = 'human' if pred_class == 1 else 'non-human'
194
 
@@ -213,7 +188,6 @@ Most influential k-mers (ranked by importance):"""
213
 
214
  return results_text, plot_image
215
 
216
- # Create the interface with two outputs: Textbox and Image
217
  iface = gr.Interface(
218
  fn=predict,
219
  inputs=gr.File(label="Upload FASTA file", type="binary"),
@@ -221,6 +195,5 @@ iface = gr.Interface(
221
  title="Virus Host Classifier"
222
  )
223
 
224
- # Launch the interface
225
  if __name__ == "__main__":
226
- iface.launch()
 
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
10
+ from PIL import Image # Import PIL for image handling
11
 
12
  class VirusClassifier(nn.Module):
13
  def __init__(self, input_shape: int):
 
45
 
46
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
47
  """Convert sequence to k-mer frequency vector"""
 
48
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
49
  kmer_dict = {km: i for i, km in enumerate(kmers)}
 
 
50
  vec = np.zeros(len(kmers), dtype=np.float32)
51
 
 
52
  for i in range(len(sequence) - k + 1):
53
  kmer = sequence[i:i+k]
54
  if kmer in kmer_dict:
55
  vec[kmer_dict[kmer]] += 1
56
 
 
57
  total_kmers = len(sequence) - k + 1
58
  if total_kmers > 0:
59
  vec = vec / total_kmers
 
84
  if file_obj is None:
85
  return "Please upload a FASTA file", None
86
 
 
87
  try:
88
  if isinstance(file_obj, str):
89
  text = file_obj
 
92
  except Exception as e:
93
  return f"Error reading file: {str(e)}", None
94
 
95
+ k = 4
 
96
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
97
  kmer_dict = {km: i for i, km in enumerate(kmers)}
98
 
 
99
  try:
100
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
101
+ model = VirusClassifier(256).to(device)
 
 
102
  state_dict = torch.load('model.pt', map_location=device)
103
  model.load_state_dict(state_dict)
 
 
104
  scaler = joblib.load('scaler.pkl')
 
 
105
  model.eval()
106
  except Exception as e:
107
  return f"Error loading model: {str(e)}", None
108
 
 
109
  results_text = ""
110
  plot_image = None
111
 
112
  try:
113
  sequences = parse_fasta(text)
 
114
  header, seq = sequences[0]
115
 
 
116
  raw_freq_vector = sequence_to_kmer_vector(seq)
117
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
118
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
119
 
 
120
  with torch.no_grad():
121
  output = model(X_tensor)
122
  probs = torch.softmax(output, dim=1)
 
124
  importance = model.get_feature_importance(X_tensor)
125
  kmer_importance = importance[0].cpu().numpy()
126
 
 
127
  if np.max(np.abs(kmer_importance)) != 0:
128
  kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
129
 
 
130
  top_k = 10
131
  top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
132
  important_kmers = [
 
139
  for i in top_indices
140
  ]
141
 
 
142
  top_features = [item['kmer'] for item in important_kmers]
143
  top_values = [item['importance'] for item in important_kmers]
144
 
 
145
  others_mask = np.ones_like(kmer_importance, dtype=bool)
146
  others_mask[top_indices] = False
147
  others_sum = np.sum(kmer_importance[others_mask])
 
155
  data=np.array([raw_freq_vector[kmer_dict[feat]] if feat != "Others" else np.sum(raw_freq_vector[others_mask]) for feat in top_features]),
156
  feature_names=top_features
157
  )
 
158
  explanation.expected_value = 0
159
 
 
160
  fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
161
 
 
162
  buf = io.BytesIO()
163
  fig.savefig(buf, format='png')
164
  buf.seek(0)
165
+ plot_image = Image.open(buf) # Convert BytesIO to PIL Image
166
 
 
167
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
168
  pred_label = 'human' if pred_class == 1 else 'non-human'
169
 
 
188
 
189
  return results_text, plot_image
190
 
 
191
  iface = gr.Interface(
192
  fn=predict,
193
  inputs=gr.File(label="Upload FASTA file", type="binary"),
 
195
  title="Virus Host Classifier"
196
  )
197
 
 
198
  if __name__ == "__main__":
199
+ iface.launch(share=True) # Set share=True to create a public link