# NOTE: removed extraction artifact ("Last commit not found") — not part of the source.
import ast
import difflib
import os
import re
import time
from multiprocessing import Pool, Process, Queue

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.nn import functional as F

import Levenshtein
from fuzzywuzzy import process
from lavis.common.registry import registry
from lavis.models import load_model_and_preprocess
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
from lavis.processors import load_processor
# import obonet
def fuzzy_match(texts, candidates=None):
    """Map each generated text to the closest known GO term string.

    Args:
        texts: iterable of candidate strings (generated function texts).
        candidates: sequence of valid term strings to match against.
            Defaults to the module-level ``choices`` list built from the
            train/val/test GO label sets (kept for backward compatibility
            with the ``pool.map(fuzzy_match, ...)`` call sites).

    Returns:
        dict mapping each text that is *not* already a valid term to its
        closest candidate. ``cutoff=0.`` guarantees ``get_close_matches``
        always returns at least one match for a non-empty candidate list.
    """
    if candidates is None:
        candidates = choices  # module-level list of known GO term strings
    text_dict = {}
    for context in texts:
        if context not in candidates:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            text_dict[context] = difflib.get_close_matches(context, candidates, n=1, cutoff=0.)[0]
    return text_dict
def txt_map(x, txt_dict):
    """Replace each generated text in *x* by its fuzzy-matched GO term.

    Args:
        x: list of strings, or the ``str`` repr of such a list (as read
            back from a CSV cell).
        txt_dict: mapping from raw generated text to its closest GO term.

    Returns:
        New list where every entry present in ``txt_dict`` is replaced by
        its matched term; unknown entries are kept unchanged.
    """
    if isinstance(x, str):
        # literal_eval is a safe replacement for eval() on list-repr strings.
        x = ast.literal_eval(x)
    return [txt_dict.get(i, i) for i in x]
def levenshtein_sim(text, label):
    """Per-label best Levenshtein similarity against the predictions.

    For every ground-truth term in *label*, take the maximum
    ``Levenshtein.ratio`` against any predicted string in *text* and
    round it to 3 decimals. An empty *text* yields 0 for every label.
    """
    best_per_label = []
    for gt in label:
        best = max((Levenshtein.ratio(gt, pred) for pred in text), default=0)
        best_per_label.append(round(best, 3))
    return best_per_label
def func(text, label):
    """Backward-compatible alias for :func:`levenshtein_sim`.

    The original body was a byte-for-byte duplicate of
    ``levenshtein_sim``; delegate instead of duplicating the logic.
    """
    return levenshtein_sim(text, label)
def stage2_output(df_test):
    """Generate candidate function texts for every protein in *df_test*.

    Loads the stage-2 (Q-Former + OPT) checkpoint, runs beam search per
    batch and appends lines of the form ``protein|txt1;txt2;...`` to
    ``output{fix}.txt``. Uses the module-level ``device`` and ``fix``
    globals; the output file is opened in append mode, so callers are
    expected to remove any stale file first.
    """
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20230924220/checkpoint_5.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 25}
    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 12
    # Ceiling division. The original `n // bsz + 1` produced one extra
    # *empty* batch whenever n was a multiple of bsz (feeding [] to the
    # encoder), and the name `iter` shadowed the builtin.
    n_batches = (n + bsz - 1) // bsz
    for b in range(n_batches):
        batch = images[b * bsz: min(n, (b + 1) * bsz)]
        # ESM-style input: list of (name, sequence) tuples.
        batch = [('protein{}'.format(i), x) for i, x in enumerate(batch)]
        with model.maybe_autocast():
            _, _, batch_tokens = model.visual_encoder(batch)
            # Per-residue representations taken at layer `model.vis_layers`.
            image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers],
                                           return_contacts=True)["representations"][model.vis_layers].contiguous()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

        # Q-Former: learned query tokens attend over the protein embeddings.
        query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = model.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)
        model.opt_tokenizer.padding_side = "right"
        # Empty prompts: generation is conditioned only on the query embeddings.
        text = ['' for _ in range(len(batch))]
        opt_tokens = model.opt_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=model.max_txt_len,
        ).to(device)
        inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

        num_txt = 10        # beam width
        return_num_txt = 5  # candidate texts kept per protein
        with model.maybe_autocast():
            outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                               min_length=3, max_length=30,
                                               repetition_penalty=5., num_beams=num_txt, eos_token_id=50118,
                                               length_penalty=1., num_return_sequences=return_num_txt,
                                               temperature=1.)
        output_text = model.opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [t.strip() for t in output_text]
        # Regroup the flat beam output: return_num_txt candidates per input.
        output_text_ = []
        for i in range(len(batch)):
            output_text_.append(';'.join(output_text[i * return_num_txt:(i + 1) * return_num_txt]))
        with open('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), 'a+') as f:
            for i in range(len(batch)):
                f.write(batch[i][1] + "|" + output_text_[i] + '\n')
cat = 'mf'
# Path suffix for every data/output file; '_mf' is the default category.
fix = {'bp': '_bp', 'cc': '_cc'}.get(cat, '_mf')
# model_pth = {'mf': 'uniprot_swissprot_mf_stage1_epo19.pth', 'bp': 'checkpoint17_GO_swissprot_reviewed_bp_stage1.pth', 'cc': ''}
# graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")
# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# device = 'cpu'

### Levenshtein similarity
# Ground truth: one row per (protein, function); capped at 10k rows.
test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test{}.csv'.format(fix), sep='|')[:10000]
test['function'] = test['function'].apply(lambda x: x.lower())

# stage2_output appends, so start from a clean prediction file.
if os.path.exists('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix)):
    os.remove('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix))

print("stage 2 predict starting")
stage2_output(test)
print("stage 2 predict completed")
# Read back the stage-2 predictions: "protein|txt1;txt2;..." per line.
df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output{}.txt'.format(fix), sep='|', header=None, on_bad_lines='warn')
df_pred.columns = ['protein', 'function']
df_pred = df_pred.drop_duplicates()
# Split the ';'-joined candidates and deduplicate per protein.
df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])

# Group the ground-truth functions per protein and join onto the
# predictions; drop proteins with no ground truth.
# (Removed the dead no-op expression `test.columns` from the original.)
test_g = test.groupby(['protein']).agg({'function': lambda x: list(x)}).reset_index()
test_g.columns = ['protein', 'label']
data = pd.merge(df_pred, test_g, on='protein', how='left')
data = data[data['label'].notnull()]

# Per-label best Levenshtein ratio between predictions and ground truth.
sim = []
for text, label in zip(data['function'].tolist(), data['label'].tolist()):
    sim.append(func(text, label))
data['sim'] = sim
data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
# data.to_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), index=False, sep='|')
def _load_go_mapping(split):
    """Load one data split and return its {function text -> GO id} dict.

    The original repeated this read/lower/dedupe/zip sequence verbatim
    for the test, val and train splits.
    """
    df = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/{}{}.csv'.format(split, fix),
                     sep='|', usecols=['function', 'GO_label'])
    df['function'] = df['function'].apply(lambda x: x.lower())
    df = df.drop_duplicates()
    return dict(zip(df['function'], df['GO_label']))

test_dict = _load_go_mapping('test')
val_dict = _load_go_mapping('val')
train_dict = _load_go_mapping('train')
# go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
# # go_des = pd.read_csv('/home/nilin/LAVIS/data/go_descriptions.txt', sep='|', header=None)
# go_des.columns = ['GO', 'function']
# go_des = go_des[go_des['function'].notnull()]
# go_des['function'] = go_des['function'].apply(lambda x: x.lower())
# GO_dict = dict(zip(go_des['function'], go_des['GO']))

# Merge into one {function text -> GO id} dict. Later updates win, so
# precedence is test over val over train for duplicate texts.
GO_dict = {}
GO_dict.update(train_dict)
GO_dict.update(val_dict)
GO_dict.update(test_dict)
choices = list(GO_dict.keys())

# data = pd.read_csv('/home/nilin/LAVIS/predict_{}.csv'.format(cat), sep='|')
data = data.sort_values(by='protein')
data = data.drop_duplicates('protein')
# data = data.sample(1000)
### Predicted texts not present among the GO label strings are mapped to
### the most similar GO label (fuzzy match, in parallel).
t0 = time.time()
txt_dict = {}
all_txt = []
for txt in data['function']:
    if type(txt) == str:
        # Cell read back from CSV as the repr of a list.
        all_txt.extend(eval(txt))
    else:
        all_txt.extend(txt)
all_txt = list(set(all_txt))

n = len(all_txt)
thread = 20
# Guard against n < thread: the original `int(n/thread)` gave size 0 and
# `range(0, n, 0)` raises ValueError.
size = max(1, n // thread)
# Equivalent to the original inds/append(n) chunking: range() already
# covers the tail because slicing clamps past the end of the list.
all_txt_sep = [all_txt[i: i + size] for i in range(0, n, size)]
with Pool(processes=thread) as pool:
    # pool.map blocks until all chunks are done; the explicit
    # close()/join() of the original were redundant inside `with`.
    result = pool.map(fuzzy_match, all_txt_sep)
for d in result:
    txt_dict.update(d)
# for txt in all_txt[:10]:
#     fuzzy_match(txt)
data['function'] = data['function'].apply(lambda x: txt_map(x, txt_dict))
data['function'] = data['function'].apply(lambda x: list(set(x)))
print("fuzzy matching time: {}".format(time.time() - t0))
### Find the generated GO text that not included in the ground truth. Then generate pairs between them. | |
# pair_a, pair_b = [], [] | |
# for preds, labels in zip(data['function'], data['label']): | |
# if type(preds) == str: | |
# preds = eval(preds) | |
# if type(labels) == str: | |
# labels = eval(labels) | |
# l = len(labels) | |
# for pred in preds: | |
# if pred not in labels: | |
# pair_a.extend([pred]*l) | |
# pair_b.extend(labels[:]) | |
# pair_a = [re.sub('_', ':', GO_dict[i]) for i in pair_a] | |
# pair_b = [re.sub('_', ':', GO_dict[i]) for i in pair_b] | |
# with open('/home/nilin/LAVIS/examples/GO_pair{}.txt'.format(fix), 'w+') as f: | |
# for i, j in zip(pair_a, pair_b): | |
# f.write(i+' '+j+'\n') | |
# load model: stage-1 BLIP2-protein checkpoint, used below to score
# protein/text similarity.
model_config = {
    'arch': 'blip2_protein',
    'load_finetuned': False,
    'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage1/20230922185/checkpoint_15.pth',
    'finetuned': '',
    'num_query_token': 32,
    'prompt': '',
    'model_type': 'pretrain',
    'load_pretrained': True,
    'freeze_vit': False,
    'max_protein_len': 512,
    'max_txt_len': 25,
}
model_cls = registry.get_model_class(model_config['arch'])
model = model_cls.from_config(model_config)
model = model.to(device)
model.eval()
# evaluate
# Score every (protein, candidate texts) pair with the stage-1 model:
# protein -> per-query-token features via the Q-Former, text -> [CLS]
# feature, then take the max query-text similarity per candidate.
t0 = time.time()
proteins = list(data['protein'])
txts = list(data['function'])
scores = []
for seq, txt in zip(proteins, txts):
    # ESM-style input: a single (name, sequence) tuple.
    image = [('protein1', seq)]
    _, _, batch_tokens = model.visual_encoder(image)
    # NOTE(review): representation layer 30 is hard-coded here, while
    # stage2_output uses model.vis_layers — confirm they refer to the
    # same encoder layer for this checkpoint.
    image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][
        30].contiguous()
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_output = model.Qformer.bert(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        use_cache=True,
        return_dict=True,
    )
    # L2-normalized features per learned query token.
    image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1)
    image_feats_all = concat_all_gather(image_feats)
    if type(txt) == str:
        # Cell read back from CSV as the repr of a list.
        txt = eval(txt)
    length = len(txt)
    with torch.no_grad():
        text_tokens = model.tokenizer(
            txt,
            padding="max_length",
            truncation=True,
            max_length=model.max_txt_len,
            return_tensors="pt",
        ).to(device)
        text_output = model.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )
        # [CLS] token projection as the text embedding.
        text_feat = F.normalize(
            model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )
        text_feat_all = concat_all_gather(text_feat)
        # Query-to-text similarity, then max over the query-token axis.
        sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()
        sim_i2t, _ = sim_q2t.max(-1)
        # print('sim_i2t: {}'.format(sim_i2t))
    # One score per candidate text; squeeze() above makes the
    # single-candidate case a 0-d tensor, hence the item() branch.
    if length > 1:
        scores.append(list(sim_i2t.detach().cpu().numpy()))
    else:
        scores.append([sim_i2t.item()])
print("model evaluate time: {}".format(time.time() - t0))
data['score'] = scores
# precision and recall top-k: sweep (topk, score threshold) and report
# macro recall / precision / F1 over the set of ground-truth GO terms.
topk = 2
threshould = 0.1
labels = []
pred_labels = []
for l in data['label']:
    if type(l) == str:
        l = eval(l)
    labels.extend(l)
labels = list(set(labels))
total = len(labels)

for topk in range(1, 7):
    for threshould in range(1, 25, 1):
        threshould /= 100
        filter_txts = []
        recalls = []
        precisions = []
        f1 = []
        # Per-term true-positive / false-positive / false-negative counts.
        tp_dict = dict(zip(labels, [0] * len(labels)))
        fp_dict = dict(zip(labels, [0] * len(labels)))
        fn_dict = dict(zip(labels, [0] * len(labels)))
        for txts, scores, label in zip(data['function'], data['score'], data['label']):
            if type(label) == str:
                label = eval(label)
            txts_ = np.array(txts)
            scores = np.array(scores)
            txts = txts_[scores > threshould]
            if len(txts) < 1:
                # Nothing above threshold: fall back to the single best
                # prediction. BUGFIX: the original indexed with a scalar
                # (txts_[np.argmax(scores)]), yielding a plain string, so
                # the loops below iterated over its characters and the
                # membership test became a substring search. Index with a
                # list to keep a 1-element array.
                txts = txts_[[int(np.argmax(scores))]]
            scores = scores[scores > threshould]
            l = len(scores)
            ll = len(label)
            if l <= topk:
                filter_txts.append(list(txts))
            else:
                # Keep only the topk highest-scoring predictions.
                ind = np.argpartition(scores, -topk)[-topk:]
                txts = txts[ind]
                filter_txts.append(list(txts))
                l = topk
            for t in label:
                if t in txts:
                    tp_dict[t] += 1
                else:
                    fn_dict[t] += 1
            for p in txts:
                if p not in label:
                    if p in fp_dict:
                        fp_dict[p] += 1
                    else:
                        fp_dict[p] = 1
            pred_labels.extend(txts)
        p_total = len(set(pred_labels))
        # Macro-averaged recall/precision; 1e-8 avoids division by zero.
        # (Renamed the accumulator from `re`, which shadowed the re module.)
        re_sum, pr_sum = 0., 0.
        for x in labels:
            re_sum += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
            pr_sum += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
        r = re_sum / total
        p = pr_sum / total
        # NOTE(review): printed as "micro_f1" but computed from the macro
        # averages — label kept to avoid breaking any log parsing.
        f1 = 2 * p * r / (p + r)
        print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, r, p, f1))
# num_r = 0 | |
# num_p = 0 | |
# for x in label: | |
# if x in txts: | |
# num_r += 1 | |
# for x in txts: | |
# if x in label: | |
# num_p += 1 | |
# recall = num_r/ll | |
# precision = num_p/(l+0.0001) | |
# recalls.append(recall) | |
# precisions.append(precision) | |
# f1.append((2*recall*precision)/(recall+precision+0.0001)) | |
# | |
# data['predict'] = filter_txts | |
# data['precision'] = precisions | |
# data['recall'] = recalls | |
# data['f1'] = f1 | |
# print("Topk: {}, threshould: {}, macro_recall: {}, macro_precision: {}, micro_f1: {}".format(topk, threshould, round(data['recall'].mean(), 4), round(data['precision'].mean(), 4), round(data['f1'].mean(), 4))) | |
# sim = [] | |
# for text, label in zip(data['predict'].tolist(), data['label'].tolist()): | |
# sim.append(levenshtein_sim(text, label)) | |
# | |
# data['sim_filter'] = sim | |
# data['avg_score'] = data['sim_filter'].apply(lambda x: round(np.mean(x), 3)) | |
# data['function'] = data['function'].apply(lambda x: eval(re.sub(';', ',', str(x)))) | |
# data['label'] = data['label'].apply(lambda x: eval(re.sub(';', ',', str(x)))) | |
# data['sim'] = data['sim'].apply(lambda x: eval(re.sub(';', ',', str(x)))) | |
# | |
# data['function'] = data['function'].apply(lambda x: re.sub(',', ';', str(x))) | |
# data['label'] = data['label'].apply(lambda x: re.sub(',', ';', str(x))) | |
# data['sim'] = data['sim'].apply(lambda x: re.sub(',', ';', str(x))) | |
# data['predict'] = data['predict'].apply(lambda x: re.sub(',', ';', str(x))) | |
# data['sim_filter'] = data['sim_filter'].apply(lambda x: re.sub(',', ';', str(x))) | |
# Persist the fuzzy-matched predictions with their scores and similarities.
data.to_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|', index=False)
# data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_sim{}.csv'.format(fix), sep='|') | |
# | |
# # example | |
# image = ['MIELKHVTFGYNKKQMVLQDINITIPDGENVGILGESGCGKSTLASLVLGLFKPVKGEIYLSDNAVLTIFQHPLTSFNPDWTIETSLKEALYYYRGLTDNTAQDQLLLQHLSTFELNAQLLTKLPSEVSGGQLQRFNVMRSLLAQPRVLICDEITSNLDVIAEQNVINILKAQTITNLNHFIVISHDLSVLQRLVNRIIVLKDGMIVDDFAIEELFNVDRHPYTKELVQTFSY'] | |
# image = [('protein{}'.format(i), x) for i, x in enumerate(image)] | |
# | |
# _, _, batch_tokens = model.visual_encoder(image) | |
# image_embeds = model.ln_vision(batch_tokens.to(device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous() | |
# | |
# image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) | |
# | |
# query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1) | |
# | |
# query_output = model.Qformer.bert( | |
# query_embeds=query_tokens, | |
# encoder_hidden_states=image_embeds, | |
# encoder_attention_mask=image_atts, | |
# use_cache=True, | |
# return_dict=True, | |
# ) | |
# | |
# image_feats = F.normalize(model.vision_proj(query_output.last_hidden_state), dim=-1) | |
# | |
# image_feats_all = concat_all_gather(image_feats) | |
# | |
# functions = ['transmembrane transporter activity', 'nickel cation transmembrane transporter activity', 'nickel cation binding', 'atp hydrolysis activity', 'atp hydrolysis', 'cadmium binding', 'abc-type nickel transmembrane transporter activity', 'abc-type nickel transporter activity', 'nickel transmembrane transporter activity', 'atp binding'] | |
# for text in functions: | |
# with torch.no_grad(): | |
# # text = 'flavin adenine dinucleotide binding' | |
# text_tokens = model.tokenizer( | |
# text, | |
# padding="max_length", | |
# truncation=True, | |
# max_length=model.max_txt_len, | |
# return_tensors="pt", | |
# ).to(device) | |
# text_output = model.Qformer.bert( | |
# text_tokens.input_ids, | |
# attention_mask=text_tokens.attention_mask, | |
# return_dict=True, | |
# ) | |
# | |
# text_feat = F.normalize( | |
# model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 | |
# ) | |
# | |
# text_feat_all = concat_all_gather(text_feat) | |
# sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze() | |
# sim_i2t, _ = sim_q2t.max(-1) | |
# print('sim_i2t: {}'.format(sim_i2t)) | |
# | |
# # # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] | |
# # sim_t2q = torch.matmul( | |
# # text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1) | |
# # ).squeeze() | |
# # | |
# # # text-image similarity: aggregate across all query tokens | |
# # sim_t2i, _ = sim_t2q.max(-1) | |
# # print('sim_t2i: {}'.format(sim_t2i)) | |