Update app.py
Browse files
app.py
CHANGED
@@ -51,7 +51,7 @@ class MyDataset(Dataset):
|
|
51 |
def __init__(self,dict_data) -> None:
|
52 |
super(MyDataset,self).__init__()
|
53 |
self.data=dict_data
|
54 |
-
self.structure=pdb_structure(dict_data['
|
55 |
def __getitem__(self, index):
|
56 |
return self.data['text'][index], self.structure[index]
|
57 |
def __len__(self):
|
@@ -174,51 +174,57 @@ class MyModel(nn.Module):
|
|
174 |
output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
|
175 |
output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
|
176 |
output_feature = self.dropout(self.output_layer(output_feature))
|
177 |
-
print(output_feature)
|
178 |
return torch.softmax(output_feature,dim=1)
|
179 |
|
180 |
|
181 |
-
def pdb_structure(
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
SurfacePoitCloud_all_tensor = torch.squeeze(torch.stack(
|
|
|
|
|
204 |
return SurfacePoitCloud_all_tensor
|
205 |
|
206 |
def ACE(file):
|
|
|
|
|
|
|
|
|
|
|
207 |
# df = pd.read_csv(seq_path)
|
208 |
# test_sequences = df["Seq"].tolist()
|
209 |
# test_Structure_index = df["Structure_index"].tolist()
|
210 |
|
211 |
test_sequences = [file]
|
212 |
-
test_Structure_index = ["
|
213 |
|
214 |
|
215 |
test_dict = {"text":test_sequences, 'structure':test_Structure_index}
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
test_data=MyDataset(test_dict)
|
223 |
test_dataloader=DataLoader(test_data,batch_size=batch_size,collate_fn=collate_fn,shuffle=False)
|
224 |
|
@@ -235,6 +241,7 @@ def ACE(file):
|
|
235 |
print("=================================Start prediction========================")
|
236 |
for index, (batch, structure_fea, fingerprint) in enumerate(test_dataloader):
|
237 |
batchs = {k: v for k, v in batch.items()}
|
|
|
238 |
outputs = model(structure_fea, batchs, fingerprint)
|
239 |
probability = outputs[0].tolist()
|
240 |
print(outputs)
|
@@ -257,8 +264,12 @@ def ACE(file):
|
|
257 |
summary['Probability'] = probability_all
|
258 |
summary_df = pd.DataFrame(summary)
|
259 |
summary_df.to_csv('output.csv', index=False)
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
262 |
return 'output.csv', out_text, out_prob
|
263 |
|
264 |
iface = gr.Interface(fn=ACE,
|
|
|
51 |
def __init__(self,dict_data) -> None:
|
52 |
super(MyDataset,self).__init__()
|
53 |
self.data=dict_data
|
54 |
+
self.structure=pdb_structure(dict_data['structure'])
|
55 |
def __getitem__(self, index):
|
56 |
return self.data['text'][index], self.structure[index]
|
57 |
def __len__(self):
|
|
|
174 |
output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
|
175 |
output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
|
176 |
output_feature = self.dropout(self.output_layer(output_feature))
|
|
|
177 |
return torch.softmax(output_feature,dim=1)
|
178 |
|
179 |
|
180 |
+
def pdb_structure(Structure_index):
|
181 |
+
created_folders = []
|
182 |
+
SurfacePoitCloud_all = []
|
183 |
+
for index in Structure_index:
|
184 |
+
structure_folder = join(temp_path, str(index))
|
185 |
+
os.makedirs(structure_folder, exist_ok=True)
|
186 |
+
created_folders.append(structure_folder)
|
187 |
+
pdb_file = join(pdb_path, f"{index}.pdb")
|
188 |
+
if os.path.exists(pdb_file):
|
189 |
+
shutil.copy2(pdb_file, structure_folder)
|
190 |
+
else:
|
191 |
+
print(f"PDB file not found for structure {index}")
|
192 |
+
coords, atname, pdbname, pdb_num = utils.parsePDB(structure_folder)
|
193 |
+
atoms_channel = utils.atomlistToChannels(atname)
|
194 |
+
radius = utils.atomlistToRadius(atname)
|
195 |
+
PointCloudSurfaceObject = VolumeMaker.PointCloudSurface(device=device)
|
196 |
+
coords = coords.to(device)
|
197 |
+
radius = radius.to(device)
|
198 |
+
atoms_channel = atoms_channel.to(device)
|
199 |
+
SurfacePoitCloud = PointCloudSurfaceObject(coords, radius)
|
200 |
+
feature = SurfacePoitCloud.view(pdb_num,-1,3).cpu()
|
201 |
+
SurfacePoitCloud_all.append(feature)
|
202 |
+
SurfacePoitCloud_all_tensor = torch.squeeze(torch.stack(SurfacePoitCloud_all),dim=1)
|
203 |
+
for folder in created_folders:
|
204 |
+
shutil.rmtree(folder)
|
205 |
return SurfacePoitCloud_all_tensor
|
206 |
|
207 |
def ACE(file):
|
208 |
+
if not os.path.exists(pdb_path):
|
209 |
+
os.makedirs(pdb_path)
|
210 |
+
else:
|
211 |
+
shutil.rmtree(pdb_path)
|
212 |
+
os.makedirs(pdb_path)
|
213 |
# df = pd.read_csv(seq_path)
|
214 |
# test_sequences = df["Seq"].tolist()
|
215 |
# test_Structure_index = df["Structure_index"].tolist()
|
216 |
|
217 |
test_sequences = [file]
|
218 |
+
test_Structure_index = [f"structure_{i}" for i in range(len(test_sequences))]
|
219 |
|
220 |
|
221 |
test_dict = {"text":test_sequences, 'structure':test_Structure_index}
|
222 |
+
print("=================================Structure prediction========================")
|
223 |
+
for i in tqdm(range(0, len(test_sequences))):
|
224 |
+
command = ["curl", "-X", "POST", "-k", "--data", f"{test_sequences[i]}", "https://api.esmatlas.com/foldSequence/v1/pdb/"]
|
225 |
+
result = subprocess.run(command, capture_output=True, text=True)
|
226 |
+
with open(os.path.join(pdb_path, f'{test_Structure_index[i]}.pdb'), 'w') as file:
|
227 |
+
file.write(result.stdout)
|
228 |
test_data=MyDataset(test_dict)
|
229 |
test_dataloader=DataLoader(test_data,batch_size=batch_size,collate_fn=collate_fn,shuffle=False)
|
230 |
|
|
|
241 |
print("=================================Start prediction========================")
|
242 |
for index, (batch, structure_fea, fingerprint) in enumerate(test_dataloader):
|
243 |
batchs = {k: v for k, v in batch.items()}
|
244 |
+
print(structure_fea)
|
245 |
outputs = model(structure_fea, batchs, fingerprint)
|
246 |
probability = outputs[0].tolist()
|
247 |
print(outputs)
|
|
|
264 |
summary['Probability'] = probability_all
|
265 |
summary_df = pd.DataFrame(summary)
|
266 |
summary_df.to_csv('output.csv', index=False)
|
267 |
+
if len(test_sequences) > 1:
|
268 |
+
out_text = "Please download csv"
|
269 |
+
out_prob = "Please download csv"
|
270 |
+
else:
|
271 |
+
out_text = output
|
272 |
+
out_prob = probability
|
273 |
return 'output.csv', out_text, out_prob
|
274 |
|
275 |
iface = gr.Interface(fn=ACE,
|