import os
import pathlib
import random
import shutil
import subprocess
import time
from collections import OrderedDict

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Local project modules (surface point-cloud generation and PDB parsing).
import utils
import VolumeMaker

device = torch.device("cpu")
# Hugging Face ESM-2 checkpoint used for both the tokenizer and the sequence encoder.
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
# Predicted PDB files are written here by ACE() and read back by pdb_structure().
pdb_path = pathlib.Path(__file__).parent / "structure"
# Scratch area: pdb_structure() creates one sub-folder per structure id under it.
temp_path = pathlib.Path(__file__).parent / "temp"
def setup_seed(seed):
    """Seed every random number generator the app uses and make cuDNN deterministic.

    The previous version accepted ``seed`` but never used it, so nothing was
    actually seeded.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and PyTorch.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators on all devices (CPU and CUDA).
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different kernels run to run; disable it for reproducibility.
    torch.backends.cudnn.benchmark = False
# --- classifier / attention sizes ---
num_labels = 2   # width of MyModel.output_layer; ACE maps class 0 -> "low", 1 -> "high"
hid_dim = 300    # shared feature width: ESM head output, fingerprint/structure projections, attention
n_heads = 1      # heads per AttentionBlock (hid_dim must be divisible by it)
dropout = 0      # NOTE(review): not read by the visible code; CrossAttentionBlock passes dropout=0.1 itself

# --- Morgan fingerprint settings ---
radius = 2          # Morgan radius passed to peptides_to_fingerprint_matrix
n_features = 1024   # bit-vector length; must equal MyModel.fc_fingerprint's input size (1024)

# --- data loading ---
batch_size = 1   # presumably consumed by the DataLoader built in ACE — confirm

# Tokenizer matching the ESM-2 checkpoint (from_pretrained may fetch it from the Hub).
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
class MyDataset(Dataset):
    """Dataset pairing each input sequence with its surface point-cloud feature.

    Args:
        dict_data: mapping with key ``"text"`` (list of sequences) and, unless
            ``structure`` is supplied, ``"structure"`` (list of structure ids whose
            PDB files ``pdb_structure`` turns into features).
        structure: optional precomputed structure features, indexable in step
            with ``dict_data["text"]``. When omitted they are computed with
            ``pdb_structure(dict_data["structure"])``.

    Items are ``(sequence, structure_feature)`` tuples, as consumed by
    ``collate_fn`` (which reads ``item[0]`` and ``item[1].tolist()``).
    """

    def __init__(self, dict_data, structure=None) -> None:
        super().__init__()
        self.dict_data = dict_data
        # NOTE(review): the original __init__ body was lost in this copy; computing
        # features from the 'structure' ids matches how ACE builds dict_data.
        if structure is None:
            structure = pdb_structure(dict_data["structure"])
        self.structure = structure

    def __getitem__(self, index):
        # Previously indexed the literal list ['text'] instead of the stored data.
        return self.dict_data["text"][index], self.structure[index]

    def __len__(self):
        # Previously len(['text']), which is always 1.
        return len(self.dict_data["text"])
def collate_fn(batch):
    """Turn a list of ``(sequence, structure_feature)`` items into model inputs.

    Returns a 3-tuple, every tensor moved to ``device``:
      * dict with the tokenizer's ``input_ids`` and ``attention_mask``;
      * structure tensor rebuilt from each item's ``tolist()``;
      * Morgan fingerprint matrix (float64) from ``peptides_to_fingerprint_matrix``.
    """
    sequences = [sample[0] for sample in batch]
    structure = torch.tensor([sample[1].tolist() for sample in batch]).to(device)
    # Two extra positions for the special tokens the tokenizer adds (presumably CLS/EOS).
    token_limit = max(len(sample[0]) for sample in batch) + 2
    fingerprint_matrix = peptides_to_fingerprint_matrix(sequences, radius, n_features)
    fingerprint = torch.tensor(fingerprint_matrix, dtype=float).to(device)
    encoded = tokenizer(
        sequences,
        padding=True,
        truncation=True,
        max_length=token_limit,
        return_tensors='pt',
    )
    model_inputs = {key: encoded[key].to(device) for key in ('input_ids', 'attention_mask')}
    return model_inputs, structure, fingerprint
class AttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention over single feature vectors.

    ``query``/``key``/``value`` are 2-D ``(batch, hid_dim)`` tensors: each is
    projected, split into ``n_heads`` chunks of ``hid_dim // n_heads`` features,
    and every feature in a chunk is treated as a length-1 token. Attention is
    therefore computed feature-to-feature within each head. The output has shape
    ``(batch, hid_dim)``.

    Args:
        hid_dim: feature width of inputs and output.
        n_heads: number of heads; must divide ``hid_dim``.
        dropout: dropout probability applied to the attention weights and the output.

    Raises:
        ValueError: if ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        if hid_dim % n_heads != 0:
            raise ValueError(f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})")
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.f_q = nn.Linear(hid_dim, hid_dim)
        self.f_k = nn.Linear(hid_dim, hid_dim)
        self.f_v = nn.Linear(hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # Non-persistent buffer: it follows module.to(...) but stays out of state_dict,
        # matching the original plain-tensor attribute for checkpoint compatibility.
        self.register_buffer(
            "scale", torch.sqrt(torch.FloatTensor([hid_dim // n_heads])), persistent=False
        )

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``; returns ``(batch, hid_dim)``.

        ``mask``, if given, must broadcast against the ``(batch, n_heads, d, d)``
        energy tensor (``d = hid_dim // n_heads``); zero entries are masked out.
        """
        batch_size = query.shape[0]
        head_dim = self.hid_dim // self.n_heads
        Q = self.f_q(query).view(batch_size, self.n_heads, head_dim).unsqueeze(3)
        K_T = self.f_k(key).view(batch_size, self.n_heads, head_dim).unsqueeze(3).transpose(2, 3)
        V = self.f_v(value).view(batch_size, self.n_heads, head_dim).unsqueeze(3)
        energy = torch.matmul(Q, K_T) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = self.dropout(F.softmax(energy, dim=-1))
        weighter_matrix = torch.matmul(attention, V)
        weighter_matrix = weighter_matrix.permute(0, 2, 1, 3).contiguous()
        weighter_matrix = weighter_matrix.view(batch_size, self.n_heads * head_dim)
        # NOTE(review): this line was truncated in this copy; fc (otherwise unused)
        # followed by dropout is the reconstruction — confirm against the checkpoint.
        return self.dropout(self.fc(weighter_matrix))
class CrossAttentionBlock(nn.Module):
    """Fuse structure, fingerprint and sequence features with attention.

    Three steps, all through the SAME ``AttentionBlock`` instance, so its weights
    are shared between them:
      1. residual cross-attention: fingerprint attends to structure;
      2. self-attention on the enriched fingerprint;
      3. cross-attention: fingerprint attends to the sequence feature.
    Inputs are presumably ``(batch, hid_dim)`` (they come from MyModel's
    ``hid_dim``-wide projections) — confirm if reused elsewhere.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): one shared block (checkpoint keys "att.*"); dropout is fixed at
        # 0.1 here rather than taken from the module-level `dropout` constant.
        self.att = AttentionBlock(hid_dim=hid_dim, n_heads=n_heads, dropout=0.1)

    def forward(self, structure_feature, fingerprint_feature, sequence_feature):
        attend = self.att
        enriched = fingerprint_feature + attend(fingerprint_feature, structure_feature, structure_feature)
        refined = attend(enriched, enriched, enriched)
        return attend(refined, sequence_feature, sequence_feature)
def peptides_to_fingerprint_matrix(peptides, radius=radius, n_features=n_features):
    """Encode peptide sequences as Morgan bit-vector fingerprints.

    Args:
        peptides: sequence of peptide strings in RDKit ``MolFromSequence`` notation.
        radius: Morgan fingerprint radius.
        n_features: fingerprint length in bits.

    Returns:
        ``np.ndarray`` of shape ``(len(peptides), n_features)`` holding 0/1 floats.

    Raises:
        ValueError: if RDKit cannot build a molecule from a sequence. Without this
            check the None molecule reaches the fingerprint call and fails with an
            opaque RDKit argument error.
    """
    features = np.zeros((len(peptides), n_features))
    for row, peptide in enumerate(peptides):
        mol = Chem.MolFromSequence(peptide)
        if mol is None:
            raise ValueError(f"RDKit could not build a molecule from sequence {peptide!r}")
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_features)
        fp_array = np.zeros((1,))  # ConvertToNumpyArray resizes it to the fingerprint length
        AllChem.DataStructs.ConvertToNumpyArray(fp, fp_array)
        features[row, :] = fp_array
    return features
class MyModel(nn.Module):
    """ACE classifier fusing ESM-2 sequence, surface-structure and fingerprint features.

    Branches (shapes follow from the layer sizes below):
      * sequence: frozen ESM-2 classification head with ``num_labels=hid_dim``
        gives ``(batch, hid_dim)`` logits;
      * fingerprint: ``(batch, 1024)`` bits -> 2-layer BiLSTM -> ``fc_fingerprint`` -> ``hid_dim``;
      * structure: assumed ``(batch, 500, 3)`` point cloud (``structure_lstm`` input 500,
        ``fc_structure`` input 1500 = 3 * 500) -> 2-layer BiLSTM -> ``fc_structure`` -> ``hid_dim``.
    The three are fused by ``CrossAttentionBlock`` and classified by an MLP head
    (300 -> 256 -> 128 -> 64 -> ``num_labels``) followed by a softmax.

    Attribute names are part of the checkpoint's state_dict keys; do not rename them.
    """

    def __init__(self):
        # Required before assigning sub-modules; without it nn.Module raises.
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=hid_dim)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(300, 256)  # 300 == hid_dim (CrossAttentionBlock output width)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc_fingerprint = nn.Linear(1024, hid_dim)
        self.fc_structure = nn.Linear(1500, hid_dim)
        self.fingerprint_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=1024, hidden_size=1024 // 2, batch_first=True)
        self.structure_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=500, hidden_size=500 // 2, batch_first=True)
        self.output_layer = nn.Linear(64, num_labels)
        self.dropout = nn.Dropout(0)
        self.CAB = CrossAttentionBlock()

    def forward(self, structure, x, fingerprint):
        """Return ``(batch, num_labels)`` class probabilities.

        Args:
            structure: structure features; permuted (0, 2, 1) before the LSTM, so
                presumably ``(batch, 500, 3)`` — confirm against pdb_structure output.
            x: tokenizer output dict with ``input_ids`` and ``attention_mask``.
            fingerprint: ``(batch, 1024)`` fingerprint matrix (float64 from collate_fn).
        """
        # (batch, 1024) -> (batch, 1, 1024): the whole fingerprint is a single LSTM step.
        fingerprint = torch.unsqueeze(fingerprint, 2).float().permute(0, 2, 1)
        structure = structure.permute(0, 2, 1)
        with torch.no_grad():  # ESM-2 encoder is used frozen
            bert_output = self.bert(input_ids=x['input_ids'].to(device), attention_mask=x['attention_mask'].to(device))
        sequence_feature = self.dropout(bert_output["logits"])
        # NOTE(review): the right-hand side of this line was lost in this copy. Casting to
        # float32 (the LSTM weight dtype) mirrors the fingerprint branch — confirm.
        structure = structure.float()
        fingerprint_feature, _ = self.fingerprint_lstm(fingerprint)
        structure_feature, _ = self.structure_lstm(structure)
        fingerprint_feature = self.fc_fingerprint(fingerprint_feature.flatten(start_dim=1))
        structure_feature = self.fc_structure(structure_feature.flatten(start_dim=1))
        output_feature = self.CAB(structure_feature, fingerprint_feature, sequence_feature)
        # BatchNorm layers need eval() mode when batch_size is 1.
        output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
        output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
        output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
        output_feature = self.dropout(self.output_layer(output_feature))
        return torch.softmax(output_feature, dim=1)
def pdb_structure(Structure_index):
    """Compute surface point-cloud features for each structure id.

    For every id, ``{pdb_path}/{id}.pdb`` is copied into a scratch folder
    ``{temp_path}/{id}``. That folder is parsed with ``utils.parsePDB``, and
    ``VolumeMaker.PointCloudSurface`` turns the atoms into a surface point cloud,
    reshaped to ``(pdb_num, -1, 3)``. Scratch folders are always removed, even on error.

    Args:
        Structure_index: iterable of structure ids (file stems of the PDB files).

    Returns:
        The per-structure features stacked along a new dim 0, with dim 1 squeezed.

    Raises:
        ValueError: if ``Structure_index`` is empty (there is nothing to stack).
    """
    created_folders = []
    point_clouds = []
    try:
        for index in Structure_index:
            # `join` was never imported; use os.path.join explicitly.
            structure_folder = os.path.join(temp_path, str(index))
            os.makedirs(structure_folder, exist_ok=True)
            created_folders.append(structure_folder)
            pdb_file = os.path.join(pdb_path, f"{index}.pdb")
            if os.path.exists(pdb_file):
                shutil.copy2(pdb_file, structure_folder)
            else:
                # Previously printed unconditionally (the else branch was missing).
                print(f"PDB file not found for structure {index}")
            coords, atname, pdbname, pdb_num = utils.parsePDB(structure_folder)
            atom_radius = utils.atomlistToRadius(atname)
            surface = VolumeMaker.PointCloudSurface(device=device)
            # NOTE(review): these lines were truncated in this copy; assumes parsePDB and
            # atomlistToRadius return torch tensors — confirm.
            coords = coords.to(device)
            atom_radius = atom_radius.to(device)
            cloud = surface(coords, atom_radius)
            point_clouds.append(cloud.view(pdb_num, -1, 3).cpu())
    finally:
        # Best-effort removal of the scratch folders; never masks the original error.
        for folder in created_folders:
            shutil.rmtree(folder, ignore_errors=True)
    if not point_clouds:
        raise ValueError("pdb_structure() needs at least one structure id")
    return torch.squeeze(torch.stack(point_clouds), dim=1)
def ACE(file):
    """Gradio handler: predict the ACE class of one peptide/protein sequence.

    Steps: fold the sequence into a PDB with a remote structure-prediction service
    (via curl), build the dataset/dataloader, load ``best_model.pth``, then predict.
    Per-sequence results are written to ``output.csv``.

    Args:
        file: raw sequence text from the input Textbox (a string, despite the name).
            Whitespace is removed and letters are upper-cased.

    Returns:
        ``('output.csv', out_text, out_prob)``. For a single sequence these are the
        predicted class index and its probability; otherwise both read
        "Please download csv".

    Raises:
        ValueError: if the input contains no sequence characters.
        RuntimeError: if structure prediction fails or returns no PDB atoms.
    """
    # NOTE(review): the endpoint literal was lost in this copy (it read ""). The curl call
    # matches the public ESMFold API — confirm this is the intended service.
    fold_url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    os.makedirs(pdb_path, exist_ok=True)

    # Textbox input may contain newlines/spaces; lowercase letters would otherwise be read
    # as D-amino acids by RDKit and as unknown tokens by the ESM tokenizer.
    sequence = "".join(str(file).split()).upper()
    if not sequence:
        raise ValueError("Please enter a peptide or protein sequence.")
    test_sequences = [sequence]
    test_Structure_index = [f"structure_{i}" for i in range(len(test_sequences))]
    test_dict = {"text": test_sequences, 'structure': test_Structure_index}

    print("=================================Structure prediction========================")
    for seq, struct_id in tqdm(zip(test_sequences, test_Structure_index), total=len(test_sequences)):
        # Argument list (no shell), so the user-supplied sequence cannot inject commands.
        # NOTE(review): "-k" disables TLS certificate verification; consider dropping it.
        command = ["curl", "-X", "POST", "-k", "--data", seq, fold_url]
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0 or "ATOM" not in result.stdout:
            raise RuntimeError(
                f"Structure prediction failed for {struct_id}: "
                f"{(result.stderr or result.stdout).strip()[:500]}"
            )
        with open(os.path.join(pdb_path, f'{struct_id}.pdb'), 'w') as pdb_file:
            pdb_file.write(result.stdout)

    test_dataloader = DataLoader(MyDataset(test_dict), batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Load the model.
    model = MyModel()
    load_result = model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')), strict=False)
    # strict=False silently ignores key mismatches; surface them so a wrong checkpoint is noticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Checkpoint mismatch - missing: {load_result.missing_keys}, unexpected: {load_result.unexpected_keys}")
    model = model.to(device)
    model.eval()  # BatchNorm must use running statistics (batch_size may be 1)

    # Predict.
    class_names = {0: "low", 1: "high"}
    probability_all = []
    Target_all = []
    output = None
    probability = None
    print("=================================Start prediction========================")
    with torch.no_grad():
        for batch, structure_fea, fingerprint in test_dataloader:
            outputs = model(structure_fea, batch, fingerprint)
            predictions = outputs.argmax(dim=1).tolist()
            # Each row is paired with its own prediction; previously only outputs[0] was
            # used and the list was overwritten by a scalar inside the loop.
            for row, output in zip(outputs.tolist(), predictions):
                Target = class_names.get(output, str(output))
                probability = row[output]
                print(Target, probability)
                Target_all.append(Target)
                probability_all.append(probability)

    summary = OrderedDict()
    summary['Seq'] = test_sequences
    summary['Target'] = Target_all
    summary['Probability'] = probability_all
    pd.DataFrame(summary).to_csv('output.csv', index=False)

    if len(test_sequences) > 1:
        out_text = "Please download csv"
        out_prob = "Please download csv"
    else:
        out_text = output
        out_prob = probability
    return 'output.csv', out_text, out_prob
# NOTE(review): the description file name was lost in this copy (it read open("")).
# "description.md" next to this script is an assumption — confirm the intended file.
# The description is cosmetic, so a missing file falls back to an empty string.
description_file = pathlib.Path(__file__).parent / "description.md"
description = description_file.read_text(encoding="utf-8") if description_file.exists() else ""

iface = gr.Interface(
    fn=ACE,
    title="🚀DeepACE: ACE classification model",
    inputs=gr.Textbox(show_label=False, placeholder="Enter peptide or protein", lines=4),
    outputs=[
        "file",
        gr.Textbox(show_label=False, placeholder="class", lines=1),
        gr.Textbox(show_label=False, placeholder="probability", lines=1),
    ],
    description=description,
)
iface.launch()