oucgc1996 committed on
Commit
75548e5
1 Parent(s): f235a18

Upload app.py

Files changed (1)
  1. app.py +281 -0
app.py ADDED
@@ -0,0 +1,281 @@
+ import os
+ import random
+ import shutil
+ import subprocess
+ import time
+ from collections import OrderedDict
+
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import AllChem
+ from torch.utils.data import Dataset, DataLoader
+ from tqdm import tqdm
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ import VolumeMaker
+ import utils
+
+ join = os.path.join
+ device = torch.device("cpu")
+
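+ # Configuration: ESM-2 checkpoint for sequence embeddings, plus the folders
+ # used for predicted PDB structures and temporary per-structure workspaces.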
+ model_checkpoint = "facebook/esm2_t6_8M_UR50D"
+ pdb_path = "structure"
+ # seq_path = "test3.csv"
+ temp_path = "temp"
+
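+ # Seed every RNG (torch, CUDA, numpy, random) so predictions are reproducible.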
+ def setup_seed(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+ setup_seed(4)
+
+
+ batch_size = 1
+ num_labels = 2        # low / high activity classes
+ radius = 2            # Morgan fingerprint radius
+ n_features = 1024     # fingerprint length in bits
+ hid_dim = 300
+ n_heads = 1
+ dropout = 0
+
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+
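+ # Dataset pairing each peptide sequence with the surface point cloud derived
+ # from its predicted structure (computed once, up front, by pdb_structure).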
+ class MyDataset(Dataset):
+     def __init__(self, dict_data) -> None:
+         super(MyDataset, self).__init__()
+         self.data = dict_data
+         self.structure = pdb_structure(dict_data['structure'])
+     def __getitem__(self, index):
+         return self.data['text'][index], self.structure[index]
+     def __len__(self):
+         return len(self.data['text'])
+
+ def collate_fn(batch):
+     data = [item[0] for item in batch]
+     structure = torch.tensor([item[1].tolist() for item in batch]).to(device)
+     max_len = max([len(b[0]) for b in batch]) + 2  # +2 for the CLS/EOS special tokens
+     fingerprint = torch.tensor(peptides_to_fingerprint_matrix(data, radius, n_features), dtype=float).to(device)
+     pt_batch = tokenizer(data, padding=True, truncation=True, max_length=max_len, return_tensors='pt')
+     return {'input_ids': pt_batch['input_ids'].to(device),
+             'attention_mask': pt_batch['attention_mask'].to(device)}, structure, fingerprint
+
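+ # Multi-head scaled dot-product attention over per-sample feature vectors.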
+ class AttentionBlock(nn.Module):
+     def __init__(self, hid_dim, n_heads, dropout):
+         super().__init__()
+
+         self.hid_dim = hid_dim
+         self.n_heads = n_heads
+
+         assert hid_dim % n_heads == 0
+
+         self.f_q = nn.Linear(hid_dim, hid_dim)
+         self.f_k = nn.Linear(hid_dim, hid_dim)
+         self.f_v = nn.Linear(hid_dim, hid_dim)
+
+         self.fc = nn.Linear(hid_dim, hid_dim)
+
+         self.do = nn.Dropout(dropout)
+
+         # keep the scale factor on the configured device (the app runs on CPU)
+         self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
+
+     def forward(self, query, key, value, mask=None):
+         batch_size = query.shape[0]
+
+         Q = self.f_q(query)
+         K = self.f_k(key)
+         V = self.f_v(value)
+
+         Q = Q.view(batch_size, self.n_heads, self.hid_dim // self.n_heads).unsqueeze(3)
+         K_T = K.view(batch_size, self.n_heads, self.hid_dim // self.n_heads).unsqueeze(3).transpose(2, 3)
+         V = V.view(batch_size, self.n_heads, self.hid_dim // self.n_heads).unsqueeze(3)
+
+         energy = torch.matmul(Q, K_T) / self.scale
+
+         if mask is not None:
+             energy = energy.masked_fill(mask == 0, -1e10)
+
+         attention = self.do(F.softmax(energy, dim=-1))
+
+         weighter_matrix = torch.matmul(attention, V)
+
+         weighter_matrix = weighter_matrix.permute(0, 2, 1, 3).contiguous()
+
+         weighter_matrix = weighter_matrix.view(batch_size, self.n_heads * (self.hid_dim // self.n_heads))
+
+         weighter_matrix = self.do(self.fc(weighter_matrix))
+
+         return weighter_matrix
+
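+ # Modality fusion: structure features enrich the fingerprint via cross-attention,
+ # the result attends over itself, then attends over the ESM-2 sequence features.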
+ class CrossAttentionBlock(nn.Module):
+     def __init__(self):
+         super(CrossAttentionBlock, self).__init__()
+         self.att = AttentionBlock(hid_dim=hid_dim, n_heads=n_heads, dropout=0.1)
+     def forward(self, structure_feature, fingerprint_feature, sequence_feature):
+         # cross-attention for compound information enrichment
+         fingerprint_feature = fingerprint_feature + self.att(fingerprint_feature, structure_feature, structure_feature)
+         # self-attention
+         fingerprint_feature = self.att(fingerprint_feature, fingerprint_feature, fingerprint_feature)
+         # cross-attention for interaction
+         output = self.att(fingerprint_feature, sequence_feature, sequence_feature)
+         return output
+
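+ # Encode each peptide as a Morgan (ECFP-like) fingerprint bit vector via RDKit.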
+ def peptides_to_fingerprint_matrix(peptides, radius=radius, n_features=n_features):
+     n_peptides = len(peptides)
+     features = np.zeros((n_peptides, n_features))
+     for i, peptide in enumerate(peptides):
+         mol = Chem.MolFromSequence(peptide)
+         if mol is None:
+             raise ValueError(f"RDKit could not parse peptide sequence: {peptide}")
+         fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_features)
+         fp_array = np.zeros((1,))
+         DataStructs.ConvertToNumpyArray(fp, fp_array)  # resizes fp_array in place
+         features[i, :] = fp_array
+     return features
+
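+ # Three-branch classifier: ESM-2 sequence embedding, BiLSTM over the Morgan
+ # fingerprint, and BiLSTM over the structure point cloud, fused by
+ # cross-attention and mapped to low/high activity probabilities.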
+ class MyModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # with num_labels=hid_dim the classification head emits a hid_dim-sized
+         # embedding rather than class logits
+         self.bert = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=hid_dim)
+         self.bn1 = nn.BatchNorm1d(256)
+         self.bn2 = nn.BatchNorm1d(128)
+         self.bn3 = nn.BatchNorm1d(64)
+         self.relu = nn.ReLU()
+         self.fc1 = nn.Linear(300, 256)
+         self.fc2 = nn.Linear(256, 128)
+         self.fc3 = nn.Linear(128, 64)
+         self.fc_fingerprint = nn.Linear(1024, hid_dim)
+         self.fc_structure = nn.Linear(1500, hid_dim)
+         self.fingerprint_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=1024, hidden_size=1024 // 2, batch_first=True)
+         self.structure_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=500, hidden_size=500 // 2, batch_first=True)
+         self.output_layer = nn.Linear(64, num_labels)
+         self.dropout = nn.Dropout(0)
+         self.CAB = CrossAttentionBlock()
+
+     def forward(self, structure, x, fingerprint):
+         fingerprint = torch.unsqueeze(fingerprint, 2).float()
+         structure = structure.permute(0, 2, 1)
+         fingerprint = fingerprint.permute(0, 2, 1)
+         # run the ESM-2 encoder without tracking gradients
+         with torch.no_grad():
+             bert_output = self.bert(input_ids=x['input_ids'].to(device), attention_mask=x['attention_mask'].to(device))
+         sequence_feature = self.dropout(bert_output["logits"])
+         structure = structure.to(device)
+         fingerprint_feature, _ = self.fingerprint_lstm(fingerprint)
+         structure_feature, _ = self.structure_lstm(structure)
+         fingerprint_feature = fingerprint_feature.flatten(start_dim=1)
+         structure_feature = structure_feature.flatten(start_dim=1)
+         fingerprint_feature = self.fc_fingerprint(fingerprint_feature)
+         structure_feature = self.fc_structure(structure_feature)
+         output_feature = self.CAB(structure_feature, fingerprint_feature, sequence_feature)
+         output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
+         output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
+         output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
+         output_feature = self.dropout(self.output_layer(output_feature))
+         return torch.softmax(output_feature, dim=1)
+
+
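+ # Turn each predicted PDB file into a surface point-cloud tensor; the utils /
+ # VolumeMaker helpers appear to follow the PyUUL API (parsePDB, PointCloudSurface).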
+ def pdb_structure(Structure_index):
+     created_folders = []
+     SurfacePointCloud_all = []
+     for index in Structure_index:
+         structure_folder = join(temp_path, str(index))
+         os.makedirs(structure_folder, exist_ok=True)
+         created_folders.append(structure_folder)
+         pdb_file = join(pdb_path, f"{index}.pdb")
+         if os.path.exists(pdb_file):
+             shutil.copy2(pdb_file, structure_folder)
+         else:
+             print(f"PDB file not found for structure {index}")
+         coords, atname, pdbname, pdb_num = utils.parsePDB(structure_folder)
+         atoms_channel = utils.atomlistToChannels(atname)
+         radius = utils.atomlistToRadius(atname)
+         PointCloudSurfaceObject = VolumeMaker.PointCloudSurface(device=device)
+         coords = coords.to(device)
+         radius = radius.to(device)
+         atoms_channel = atoms_channel.to(device)
+         SurfacePointCloud = PointCloudSurfaceObject(coords, radius)
+         feature = SurfacePointCloud.view(pdb_num, -1, 3).cpu()
+         SurfacePointCloud_all.append(feature)
+     SurfacePointCloud_all_tensor = torch.squeeze(torch.stack(SurfacePointCloud_all), dim=1)
+     for folder in created_folders:
+         shutil.rmtree(folder)
+     return SurfacePointCloud_all_tensor
+
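+ # End-to-end pipeline: fold the input peptide with the public ESMFold API
+ # (api.esmatlas.com), featurize it, run the classifier, and write output.csv.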
+ def ACE(sequence):
+     # start from a clean structure folder for each request
+     if not os.path.exists(pdb_path):
+         os.makedirs(pdb_path)
+     else:
+         shutil.rmtree(pdb_path)
+         os.makedirs(pdb_path)
+     # df = pd.read_csv(seq_path)
+     # test_sequences = df["Seq"].tolist()
+     # test_Structure_index = df["Structure_index"].tolist()
+
+     test_sequences = [sequence]
+     test_Structure_index = [f"structure_{i}" for i in range(len(test_sequences))]
+
+     test_dict = {"text": test_sequences, 'structure': test_Structure_index}
+     print("=================================Structure prediction========================")
+     for i in tqdm(range(0, len(test_sequences))):
+         while True:
+             # fold the sequence with the ESMFold API; retry until the returned
+             # PDB file is non-trivial (at least 1 KB)
+             command = ["curl", "-X", "POST", "-k", "--data", f"{test_sequences[i]}", "https://api.esmatlas.com/foldSequence/v1/pdb/"]
+             result = subprocess.run(command, capture_output=True, text=True)
+             with open(os.path.join(pdb_path, f'{test_Structure_index[i]}.pdb'), 'w') as f:
+                 f.write(result.stdout)
+             stats = os.stat(os.path.join(pdb_path, f'{test_Structure_index[i]}.pdb'))
+             if stats.st_size < 1024:
+                 print(f"Download for {test_Structure_index[i]} failed due to empty file. Retrying...")
+                 time.sleep(20)
+                 continue
+             else:
+                 break
+     test_data = MyDataset(test_dict)
+     test_dataloader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
+
+     # Load the model
+     model = MyModel()
+     model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')), strict=False)
+     model = model.to(device)
+
+     # Prediction
+     model.eval()
+     with torch.no_grad():
+         probability_all = []
+         Target_all = []
+         print("=================================Start prediction========================")
+         for index, (batch, structure_fea, fingerprint) in enumerate(test_dataloader):
+             batchs = {k: v for k, v in batch.items()}
+             outputs = model(structure_fea, batchs, fingerprint)
+             probability = outputs[0].tolist()
+             train_argmax = np.argmax(outputs.cpu().detach().numpy(), axis=1)
+             for j in range(0, len(train_argmax)):
+                 output = train_argmax[j]
+                 if output == 0:
+                     Target = "low"
+                     probability = probability[0]
+                 elif output == 1:
+                     Target = "high"
+                     probability = probability[1]
+                 print(Target, probability)
+                 probability_all.append(probability)
+                 Target_all.append(Target)
+     summary = OrderedDict()
+     summary['Seq'] = test_sequences
+     summary['Target'] = Target_all
+     summary['Probability'] = probability_all
+     summary_df = pd.DataFrame(summary)
+     summary_df.to_csv('output.csv', index=False)
+     if len(test_sequences) > 1:
+         out_text = "Please download csv"
+         out_prob = "Please download csv"
+     else:
+         out_text = Target
+         out_prob = probability
+     return 'output.csv', out_text, out_prob
+
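+ # Gradio UI: a single text box for the peptide sequence; returns the CSV file,
+ # the predicted class, and its probability.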
+ iface = gr.Interface(fn=ACE,
+                      inputs="text",
+                      outputs=["file", "text", "text"])
+ iface.launch()