Upload app.py
Browse files
@@ -0,0 +1,281 @@
1 |
import VolumeMaker
2 |
import utils
3 |
import numpy as np
4 |
import random
5 |
import torch
6 |
import torch.nn as nn
7 |
import pandas as pd
8 |
import shutil
9 |
import subprocess
10 |
from transformers import AutoModelForSequenceClassification
11 |
from torch.utils.data import Dataset,DataLoader
12 |
import pandas as pd
13 |
device = torch.device("cpu")
14 |
import os
15 |
16 |
from transformers import AutoTokenizer
17 |
import torch.nn.functional as F
18 |
from rdkit import Chem
19 |
from rdkit.Chem import AllChem
20 |
from collections import OrderedDict
21 |
from tqdm import tqdm
22 |
import time
23 |
24 |
# Hugging Face checkpoint name; loaded for both the tokenizer and the
# sequence-classification encoder used by the model.
model_checkpoint: str = "facebook/esm2_t6_8M_UR50D"
# Directory where predicted structures are stored as "<structure index>.pdb".
pdb_path: str = "structure"
# seq_path = "test3.csv"
# Scratch directory; one sub-folder per structure index is created inside it.
temp_path: str = "temp"
28 |
29 |
def setup_seed(seed):
30 |
31 |
32 |
33 |
34 |
torch.backends.cudnn.deterministic = True
35 |
36 |
37 |
38 |
# Samples per batch (presumably handed to the DataLoader - TODO confirm).
batch_size: int = 1
# Width of the final classification layer (class 0 -> "low", class 1 -> "high").
num_labels: int = 2
# Morgan fingerprint radius used when encoding peptides.
radius: int = 2
# Morgan fingerprint length in bits.
n_features: int = 1024
# Feature width shared by the attention block and the sequence-encoder head.
hid_dim: int = 300
# Number of heads in the attention block.
n_heads: int = 1
# NOTE(review): not referenced by the visible model code, which hard-codes its
# own dropout rates - confirm whether this constant is still needed.
dropout: float = 0
45 |
46 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
47 |
48 |
class MyDataset(Dataset):
49 |
def __init__(self,dict_data) -> None:
50 |
51 |
52 |
53 |
def __getitem__(self, index):
54 |
return self.data['text'][index], self.structure[index]
55 |
def __len__(self):
56 |
return len(self.data['text'])
57 |
58 |
def collate_fn(batch):
    """Assemble one model-ready batch from ``(sequence, structure)`` samples.

    Args:
        batch: Sequence of samples; element 0 of each sample is the peptide
            string and element 1 is its structure tensor.

    Returns:
        A 3-tuple ``(encoded, structure, fingerprint)``:
        ``encoded`` is a dict with the tokenizer's ``input_ids`` and
        ``attention_mask`` tensors, ``structure`` is the per-sample structure
        data gathered into a single tensor, and ``fingerprint`` is the
        float64 Morgan fingerprint matrix of the sequences. All tensors are
        moved to the module-level ``device``.
    """
    sequences = [sample[0] for sample in batch]

    # Round-trip through nested lists so the samples end up in one new tensor.
    structure_rows = [sample[1].tolist() for sample in batch]
    structure = torch.tensor(structure_rows).to(device)

    # Longest sequence plus two extra positions for the tokenizer's special tokens.
    longest = max(len(sample[0]) for sample in batch) + 2

    fingerprint_matrix = peptides_to_fingerprint_matrix(sequences, radius, n_features)
    fingerprint = torch.tensor(fingerprint_matrix, dtype=float).to(device)

    encoded = tokenizer(
        sequences,
        padding=True,
        truncation=True,
        max_length=longest,
        return_tensors='pt',
    )
    model_inputs = {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }
    return model_inputs, structure, fingerprint
66 |
67 |
class AttentionBlock(nn.Module):
    """Multi-head attention over flat feature vectors.

    Every input is a ``(batch, hid_dim)`` tensor (one vector per sample, no
    sequence axis). Each head's slice of the query and key is expanded into an
    outer product, so the attention weights relate feature positions inside a
    head to each other rather than positions in a sequence.

    Args:
        hid_dim: Width of the query/key/value vectors; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the projected output.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_heads = n_heads

        assert hid_dim % n_heads == 0

        self.f_q = nn.Linear(hid_dim, hid_dim)
        self.f_k = nn.Linear(hid_dim, hid_dim)
        self.f_v = nn.Linear(hid_dim, hid_dim)

        self.fc = nn.Linear(hid_dim, hid_dim)

        self.do = nn.Dropout(dropout)

        # Registered as a buffer instead of calling .cuda(): the scale then
        # follows the module through .to(device) and works on CPU-only hosts.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # load exactly as before.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([hid_dim // n_heads])),
            persistent=False,
        )

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` to ``key``/``value``.

        Args:
            query, key, value: Tensors of shape ``(batch, hid_dim)``.
            mask: Optional tensor broadcastable to the energy tensor
                ``(batch, n_heads, head_dim, head_dim)``; positions where it
                equals 0 are suppressed before the softmax.

        Returns:
            Tensor of shape ``(batch, hid_dim)``.
        """
        batch_size = query.shape[0]
        head_dim = self.hid_dim // self.n_heads

        Q = self.f_q(query)
        K = self.f_k(key)
        V = self.f_v(value)

        # (batch, heads, head_dim, 1) column vectors; K becomes a row vector.
        Q = Q.view(batch_size, self.n_heads, head_dim).unsqueeze(3)
        K_T = K.view(batch_size, self.n_heads, head_dim).unsqueeze(3).transpose(2, 3)
        V = V.view(batch_size, self.n_heads, head_dim).unsqueeze(3)

        # Outer product per head: (batch, heads, head_dim, head_dim).
        energy = torch.matmul(Q, K_T) / self.scale

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = self.do(F.softmax(energy, dim=-1))

        weighter_matrix = torch.matmul(attention, V)

        weighter_matrix = weighter_matrix.permute(0, 2, 1, 3).contiguous()

        weighter_matrix = weighter_matrix.view(batch_size, self.n_heads * head_dim)

        weighter_matrix = self.do(self.fc(weighter_matrix))

        return weighter_matrix
113 |
114 |
class CrossAttentionBlock(nn.Module):
    """Fuse structure, fingerprint and sequence features with one shared attention layer.

    A single ``AttentionBlock`` (built from the module-level ``hid_dim`` and
    ``n_heads``, dropout 0.1) is applied three times in a row, so all three
    steps share the same weights.
    """

    def __init__(self):
        super().__init__()
        self.att = AttentionBlock(hid_dim=hid_dim, n_heads=n_heads, dropout=0.1)

    def forward(self, structure_feature, fingerprint_feature, sequence_feature):
        """Return the fingerprint feature after the three attention steps.

        Args:
            structure_feature: Structure embedding, used as key/value in step 1.
            fingerprint_feature: Fingerprint embedding, the query throughout.
            sequence_feature: Sequence embedding, used as key/value in step 3.
        """
        attend = self.att

        # Step 1: cross-attention with a residual connection enriches the
        # fingerprint with structure information.
        enriched = fingerprint_feature + attend(
            fingerprint_feature, structure_feature, structure_feature
        )

        # Step 2: self-attention over the enriched fingerprint.
        refined = attend(enriched, enriched, enriched)

        # Step 3: cross-attention against the sequence for the interaction.
        return attend(refined, sequence_feature, sequence_feature)
126 |
127 |
def peptides_to_fingerprint_matrix(peptides, radius=radius, n_features=n_features):
    """Encode peptide sequences as Morgan fingerprint bit vectors.

    Args:
        peptides: Sequence of peptide strings accepted by
            ``Chem.MolFromSequence``.
        radius: Morgan fingerprint radius.
        n_features: Fingerprint length in bits (number of output columns).

    Returns:
        ``numpy.ndarray`` of shape ``(len(peptides), n_features)`` with one
        fingerprint per row.

    Raises:
        ValueError: If RDKit cannot build a molecule from a sequence.
    """
    features = np.zeros((len(peptides), n_features))
    for i, peptide in enumerate(peptides):
        mol = Chem.MolFromSequence(peptide)
        if mol is None:
            # MolFromSequence signals failure with None; fail here with the
            # offending input instead of an opaque RDKit argument error below.
            raise ValueError(
                f"Could not build a molecule from peptide sequence {peptide!r}"
            )
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_features)
        # ConvertToNumpyArray resizes the destination to the fingerprint length.
        fp_array = np.zeros((1,))
        AllChem.DataStructs.ConvertToNumpyArray(fp, fp_array)
        features[i, :] = fp_array
    return features
137 |
138 |
class MyModel(nn.Module):
    """Classifier combining sequence, fingerprint and structure features.

    The pretrained checkpoint named by ``model_checkpoint`` is loaded with a
    ``hid_dim``-wide classification head whose logits serve as the sequence
    feature. Fingerprint and structure inputs each pass through a two-layer
    bidirectional LSTM and a linear projection to ``hid_dim``. The three
    features are fused by ``CrossAttentionBlock`` and classified by a
    300 -> 256 -> 128 -> 64 -> ``num_labels`` MLP.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_checkpoint, num_labels=hid_dim
        )
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(300, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc_fingerprint = nn.Linear(1024, hid_dim)
        self.fc_structure = nn.Linear(1500, hid_dim)
        self.fingerprint_lstm = nn.LSTM(
            input_size=1024,
            hidden_size=1024 // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.structure_lstm = nn.LSTM(
            input_size=500,
            hidden_size=500 // 2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.output_layer = nn.Linear(64, num_labels)
        self.dropout = nn.Dropout(0)
        self.CAB = CrossAttentionBlock()

    def forward(self, structure, x, fingerprint):
        """Return class probabilities of shape ``(batch, num_labels)``.

        Args:
            structure: Structure tensor; the layer sizes imply
                ``(batch, 500, 3)`` - TODO confirm against the caller.
            x: Dict with the tokenizer's ``input_ids`` and ``attention_mask``.
            fingerprint: Fingerprint matrix; the layer sizes imply
                ``(batch, 1024)``.
        """
        # Add a trailing axis, then swap it forward so each fingerprint is a
        # length-1 sequence for the LSTM.
        fingerprint = torch.unsqueeze(fingerprint, 2).float()
        structure = structure.permute(0, 2, 1)
        fingerprint = fingerprint.permute(0, 2, 1)

        # The pretrained encoder runs without gradient tracking.
        with torch.no_grad():
            encoder_out = self.bert(
                input_ids=x['input_ids'].to(device),
                attention_mask=x['attention_mask'].to(device),
            )
        sequence_feature = self.dropout(encoder_out["logits"])

        structure = structure.to(device)
        fingerprint_seq, _ = self.fingerprint_lstm(fingerprint)
        structure_seq, _ = self.structure_lstm(structure)

        # Collapse the LSTM outputs per sample and project both to hid_dim.
        fingerprint_feature = self.fc_fingerprint(fingerprint_seq.flatten(start_dim=1))
        structure_feature = self.fc_structure(structure_seq.flatten(start_dim=1))

        hidden = self.CAB(structure_feature, fingerprint_feature, sequence_feature)

        # Three identical linear -> batch-norm -> ReLU -> dropout stages.
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in stages:
            hidden = self.dropout(self.relu(norm(linear(hidden))))

        logits = self.dropout(self.output_layer(hidden))
        return torch.softmax(logits, dim=1)
176 |
177 |
178 |
def pdb_structure(Structure_index):
179 |
created_folders = []
180 |
SurfacePoitCloud_all = []
181 |
for index in Structure_index:
182 |
structure_folder = join(temp_path, str(index))
183 |
os.makedirs(structure_folder, exist_ok=True)
184 |
185 |
pdb_file = join(pdb_path, f"{index}.pdb")
186 |
if os.path.exists(pdb_file):
187 |
shutil.copy2(pdb_file, structure_folder)
188 |
189 |
print(f"PDB file not found for structure {index}")
190 |
coords, atname, pdbname, pdb_num = utils.parsePDB(structure_folder)
191 |
atoms_channel = utils.atomlistToChannels(atname)
192 |
radius = utils.atomlistToRadius(atname)
193 |
PointCloudSurfaceObject = VolumeMaker.PointCloudSurface(device=device)
194 |
coords = coords.to(device)
195 |
radius = radius.to(device)
196 |
atoms_channel = atoms_channel.to(device)
197 |
SurfacePoitCloud = PointCloudSurfaceObject(coords, radius)
198 |
feature = SurfacePoitCloud.view(pdb_num,-1,3).cpu()
199 |
200 |
SurfacePoitCloud_all_tensor = torch.squeeze(torch.stack(SurfacePoitCloud_all),dim=1)
201 |
for folder in created_folders:
202 |
203 |
return SurfacePoitCloud_all_tensor
204 |
205 |
def ACE(file):
206 |
if not os.path.exists(pdb_path):
207 |
208 |
209 |
210 |
211 |
# df = pd.read_csv(seq_path)
212 |
# test_sequences = df["Seq"].tolist()
213 |
# test_Structure_index = df["Structure_index"].tolist()
214 |
215 |
test_sequences = [file]
216 |
test_Structure_index = [f"structure_{i}" for i in range(len(test_sequences))]
217 |
218 |
219 |
test_dict = {"text":test_sequences, 'structure':test_Structure_index}
220 |
print("=================================Structure prediction========================")
221 |
for i in tqdm(range(0, len(test_sequences))):
222 |
while True:
223 |
command = ["curl", "-X", "POST", "-k", "--data", f"{test_sequences[i]}", "https://api.esmatlas.com/foldSequence/v1/pdb/"]
224 |
result = subprocess.run(command, capture_output=True, text=True)
225 |
with open(os.path.join(pdb_path, f'{test_Structure_index[i]}.pdb'), 'w') as file:
226 |
227 |
stats = os.stat(os.path.join(pdb_path, f'{test_Structure_index[i]}.pdb'))
228 |
if stats.st_size < 1024:
229 |
print(f"Download for {test_Structure_index[i]} failed due to empty file. Retrying...")
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
# Load the model
238 |
model = MyModel()
239 |
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')), strict=False)
240 |
model = model.to(device)
241 |
242 |
# Predict
243 |
244 |
with torch.no_grad():
245 |
probability_all = []
246 |
Target_all = []
247 |
print("=================================Start prediction========================")
248 |
for index, (batch, structure_fea, fingerprint) in enumerate(test_dataloader):
249 |
batchs = {k: v for k, v in batch.items()}
250 |
outputs = model(structure_fea, batchs, fingerprint)
251 |
probability = outputs[0].tolist()
252 |
train_argmax = np.argmax(outputs.cpu().detach().numpy(), axis=1)
253 |
for j in range(0,len(train_argmax)):
254 |
output = train_argmax[j]
255 |
if output == 0:
256 |
Target = "low"
257 |
probability = probability[0]
258 |
elif output == 1:
259 |
Target = "high"
260 |
probability = probability[1]
261 |
print(Target, probability)
262 |
263 |
264 |
summary = OrderedDict()
265 |
summary['Seq'] = test_sequences
266 |
summary['Target'] = Target_all
267 |
summary['Probability'] = probability_all
268 |
summary_df = pd.DataFrame(summary)
269 |
summary_df.to_csv('output.csv', index=False)
270 |
if len(test_sequences) > 1:
271 |
out_text = "Please download csv"
272 |
out_prob = "Please download csv"
273 |
274 |
out_text = output
275 |
out_prob = probability
276 |
return 'outputs.csv', out_text, out_prob
277 |
278 |
iface = gr.Interface(fn=ACE,
279 |
280 |
outputs= ["file","text","text"])
281 |