jinysun committed on
Commit
2927024
·
1 Parent(s): 86eba7b

Delete screen.py

Browse files
Files changed (1) hide show
  1. screen.py +0 -121
screen.py DELETED
@@ -1,121 +0,0 @@
1
- import os
2
- import pandas as pd
3
-
4
- import torch
5
- from torch.nn import functional as F
6
- from transformers import AutoTokenizer
7
-
8
- from util.utils import *
9
-
10
- from tqdm import tqdm
11
- from train import markerModel
12
-
13
# Module-level inference setup: CUDA environment, tokenizers, configuration
# and the checkpointed model used by the prediction helpers in this file.

# Both variables are set before the first CUDA query below.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '0 ',
})

device_count = torch.cuda.device_count()

# Device type that decides whether the model is wrapped in DataParallel.
if torch.cuda.is_available():
    device_biomarker = torch.device('cuda')
else:
    device_biomarker = torch.device("cpu")

# Device that tokenized inputs are moved to by marker_prediction.
device = torch.device('cpu')

a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

# Hyper-parameters are read from JSON and wrapped in DictX; the fields are
# read as attributes below.
config = DictX(load_hparams('config/predict.json'))

# NOTE(review): this instance is rebound by load_from_checkpoint right after
# construction -- confirm whether building it first is still required.
model = markerModel(
    config.d_model_name,
    config.p_model_name,
    config.lr,
    config.dropout,
    config.layer_features,
    config.loss_fn,
    config.layer_limit,
    config.pretrained['chem'],
    config.pretrained['prot'],
)
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    # NOTE(review): the model is not explicitly moved to the GPU before being
    # wrapped, and inputs are placed on `device` (CPU) -- confirm this works
    # with the installed torch / lightning versions.
    model = torch.nn.DataParallel(model)
40
-
41
def get_marker(drug_inputs, prot_inputs):
    """Run the module-level ``model`` on one batch and return its output as a list.

    Args:
        drug_inputs: First positional argument forwarded to ``model``.
        prot_inputs: Second positional argument forwarded to ``model``.

    Returns:
        list: The model output with size-1 dimensions squeezed away and
        converted via ``Tensor.tolist()``. An output that squeezes down to a
        single scalar is returned as a one-element list, so callers can
        always index the result per sample.
    """
    output_preds = model(drug_inputs, prot_inputs)

    predict = torch.squeeze(output_preds)
    if predict.dim() == 0:
        # A batch holding a single sample squeezes to a 0-d tensor, whose
        # tolist() is a bare number; restore one dimension so the result
        # stays indexable.
        predict = predict.unsqueeze(0)

    return predict.tolist()
51
-
52
-
53
def marker_prediction(smiles, aas):
    """Tokenize one batch of paired inputs and collect the model predictions.

    Args:
        smiles: Sequence of strings encoded with the module-level ``tokenizer``.
        aas: Sequence of strings; each entry is split into single characters
            joined by spaces and encoded with the module-level ``d_tokenizer``.

    Returns:
        On success, a list with one dict per entry of ``aas`` holding the keys
        ``'acceptor'`` (from ``smiles``), ``'donor'`` (from ``aas``) and
        ``'predict'`` (from ``get_marker``). If anything raises, the exception
        is printed and ``{'Error_message': <exception>}`` is returned instead.
    """
    try:
        # Space-separate every character of each entry before tokenizing.
        spaced_aas = [' '.join(list(entry)) for entry in aas]

        encoded_a = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        a_inputs = {
            'input_ids': encoded_a['input_ids'].to(device),
            'attention_mask': encoded_a['attention_mask'].to(device),
        }

        encoded_d = d_tokenizer(spaced_aas, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        d_inputs = {
            'input_ids': encoded_d['input_ids'].to(device),
            'attention_mask': encoded_d['attention_mask'].to(device),
        }

        output_predict = get_marker(a_inputs, d_inputs)

        # One record per position, driven by the length of ``aas``.
        output_list = []
        for idx in range(len(aas)):
            output_list.append({
                'acceptor': smiles[idx],
                'donor': aas[idx],
                'predict': output_predict[idx],
            })
        return output_list

    except Exception as e:
        print(e)
        return {'Error_message': e}
80
-
81
-
82
def smiles_aas_test(file, batch_size=80):
    """Predict every pair listed in a CSV file, batch by batch.

    Each row contributes the pair ``(row[2], row[1])`` (zero-based column
    positions). Pairs are grouped into batches of ``batch_size`` and each
    batch is passed to ``marker_prediction``.

    Args:
        file: Path or buffer accepted by ``pandas.read_csv``.
        batch_size: Number of pairs per ``marker_prediction`` call; must be
            at least 1. Defaults to 80.

    Returns:
        A ``pandas.DataFrame`` built from the records of all batches. If any
        step raises, the exception is printed and
        ``{'Error_message': <exception>}`` is returned. If
        ``marker_prediction`` reports a failure, its error dict is returned
        unchanged so the original cause is not masked.
    """
    try:
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        smiles_aas = pd.read_csv(file)

        # NOTE(review): column 2 is passed as ``smiles`` and column 1 as
        # ``aas``; presumably donor/acceptor columns -- verify against the CSV.
        pairs = [[row[2], row[1]] for row in smiles_aas.values]
        batches = [
            pairs[start:start + batch_size]
            for start in range(0, len(pairs), batch_size)
        ]

        datas = []
        for batch in tqdm(batches, total=len(batches)):
            smiles_d, smiles_a = zip(*batch)
            output_pred = marker_prediction(list(smiles_d), list(smiles_a))
            if isinstance(output_pred, dict):
                # marker_prediction signals failure with an error dict;
                # merging it into the results would only fail later with an
                # unrelated error, so return it directly.
                return output_pred
            datas.extend(output_pred)

        return pd.DataFrame(datas)

    except Exception as e:
        print(e)
        return {'Error_message': e}
121
-