oucgc1996 commited on
Commit
76df63f
·
verified ·
1 Parent(s): 8ec8d66

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import pandas as pd
4
+ from transformers import set_seed
5
+ import torch
6
+ import torch.nn as nn
7
+ from collections import OrderedDict
8
+ import warnings
9
+ import gradio as gr
10
+
11
+ warnings.filterwarnings('ignore')
12
+ set_seed(4)
13
+ device = "cuda:0"
14
+ model_checkpoint = "facebook/esm2_t30_150M_UR50D"
15
+
16
+ class MyModel(nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.bert = AutoModelForSequenceClassification.from_pretrained(model_checkpoint,num_labels=320)
20
+ self.bn1 = nn.BatchNorm1d(256)
21
+ self.bn2 = nn.BatchNorm1d(128)
22
+ self.bn3 = nn.BatchNorm1d(64)
23
+ self.relu = nn.ReLU()
24
+ self.fc1 = nn.Linear(320,256)
25
+ self.fc2 = nn.Linear(256,128)
26
+ self.fc3 = nn.Linear(128,64)
27
+ self.output_layer = nn.Linear(64,2)
28
+ self.dropout = nn.Dropout(0)
29
+
30
+ def forward(self,x):
31
+ with torch.no_grad():
32
+ bert_output = self.bert(input_ids=x['input_ids'].to(device),attention_mask=x['attention_mask'].to(device))
33
+ output_feature = self.dropout(bert_output["logits"])
34
+ output_feature = self.relu(self.bn1(self.fc1(output_feature)))
35
+ output_feature = self.relu(self.bn2(self.fc2(output_feature)))
36
+ output_feature = self.relu(self.bn3(self.fc3(output_feature)))
37
+ output_feature = self.output_layer(output_feature)
38
+ return torch.softmax(output_feature,dim=1)
39
+
40
+ def Kmers_funct(seq,num):
41
+ for i in range(len(seq)):
42
+ a = seq[i]
43
+ l = []
44
+ for index in range(len(a)):
45
+ t = a[index:index + num]
46
+ if (len(t)) == num:
47
+ l.append(t)
48
+ return l
49
+
50
+ def ACE(file):
51
+ model = MyModel()
52
+ model.load_state_dict(torch.load("best_model.pth"))
53
+ model = model.to(device)
54
+ model.eval()
55
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
56
+ test_sequences = file
57
+ max_len = 30
58
+ test_data = tokenizer(test_sequences, max_length=max_len, padding="max_length",truncation=True, return_tensors='pt')
59
+ out_probability = []
60
+ with torch.no_grad():
61
+ predict = model(test_data)
62
+ out_probability.extend(np.max(np.array(predict.cpu()),axis=1).tolist())
63
+ test_argmax = np.argmax(predict.cpu(), axis=1).tolist()
64
+ id2str = {0:"non-ACE", 1:"ACE"}
65
+ return id2str[test_argmax[0]], out_probability[0]
66
+
67
+ def main(file):
68
+ test_seq = file
69
+ all = []
70
+ seq_all = []
71
+ output_all = []
72
+ probability_all = []
73
+ for j in range(2, 11):
74
+ X = Kmers_funct([test_seq], j)
75
+ all.extend(X)
76
+ for seq in all:
77
+ output, probability = ACE(str(seq))
78
+ seq_all.append(seq)
79
+ output_all.append(output)
80
+ probability_all.append(probability)
81
+
82
+ summary = OrderedDict()
83
+ summary['Seq'] = seq_all
84
+ summary['Class'] = output_all
85
+ summary['Probability'] = probability_all
86
+ summary_df = pd.DataFrame(summary)
87
+ summary_df.to_csv('output.csv', index=False)
88
+ return 'outputs.csv'
89
+
90
+
91
+ iface = gr.Interface(fn=main,
92
+ inputs="text",
93
+ outputs= "file")
94
+ iface.launch()