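# Stage-2 inference script: load a pretrained BLIP2-style protein-to-text model,
# generate function descriptions for the protein sequences in a test split, and
# write the predictions alongside the ground-truth labels to a '|'-separated CSV.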
import os
import re
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from lavis.common.registry import registry
from torch.nn import functional as F
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
import numpy as np
import pandas as pd
import time
from fuzzywuzzy import process
from multiprocessing import Pool, Queue, Process
import difflib
import Levenshtein
# import obonet
# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda")
def txt_map(x, txt_dict):
    # Map each entry of x through txt_dict, keeping unmapped entries unchanged.
    # x may arrive as a stringified list, in which case it is evaluated first.
    if isinstance(x, str):
        x = eval(x)
    x_ = []
    for i in x:
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_
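# The two helpers below score predictions against labels with Levenshtein ratios:
# levenshtein_sim() returns the best match score per label, func() per predicted text.
# A rough usage sketch (both are only referenced in the commented-out evaluation below):
#   func(['protein binding'], ['protein binding', 'atp binding'])  # -> [1.0]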
def levenshtein_sim(text, label):
    # For each ground-truth label, keep the best Levenshtein ratio against any predicted text.
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s
def func(text, label):
    # For each predicted text, keep the best Levenshtein ratio against any ground-truth label.
    all_s = []
    for x in text:
        s = 0
        for y in label:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]
    return all_s
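# stage2_output(): batch the protein sequences, encode them with the frozen sequence
# encoder, feed the embeddings to the Q-Former, project the query outputs into the OPT
# input space, and let the OPT decoder generate candidate function texts, appending
# "protein sequence|generated text" lines to the output file. It relies on the globals
# split/fix/type_fix that are set in the __main__ block below.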
def stage2_output(df_test, return_num_txt=1):
    config = {'arch': 'blip2_protein_opt', 'load_finetuned': False,
              'pretrained': '/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20231029182/checkpoint_0.pth',
              'finetuned': '', 'num_query_token': 32, 'opt_model': 'facebook/opt-2.7b', 'prompt': '',
              'model_type': 'pretrain_protein_opt2.7b', 'load_pretrained': True, 'freeze_vit': True,
              'max_protein_len': 600,
              'max_txt_len': 256}
    model_cls = registry.get_model_class(config['arch'])
    model = model_cls.from_config(config)
    model.to(device)
    model.eval()

    images = df_test['protein'].tolist()
    n = len(images)
    bsz = 8
    n_iter = (n + bsz - 1) // bsz  # ceil division; avoids an empty trailing batch when n % bsz == 0
    with open('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix), 'a+') as f:
        for i in range(n_iter):
            image = images[i * bsz: min(n, (i + 1) * bsz)]
            image = [('protein{}'.format(k), x) for k, x in enumerate(image)]
            with model.maybe_autocast():
                # Encode the protein sequences with the frozen sequence encoder.
                _, _, batch_tokens = model.visual_encoder(image)
                image_embeds = \
                    model.ln_vision(batch_tokens.to(device), repr_layers=[model.vis_layers], return_contacts=True)[
                        "representations"][model.vis_layers].contiguous()
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

                # Cross-attend the learned query tokens to the protein embeddings via the Q-Former.
                query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = model.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                # Project the query outputs into the OPT embedding space and prepend them to an empty prompt.
                inputs_opt = model.opt_proj(query_output.last_hidden_state)
                atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(device)

                model.opt_tokenizer.padding_side = "right"
                text = ['' for _ in range(len(image))]
                opt_tokens = model.opt_tokenizer(
                    text,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                    max_length=model.max_txt_len,
                ).to(device)
                inputs_embeds = model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
                inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
                attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

            num_txt = 5
            with model.maybe_autocast():
                outputs = model.opt_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                                   min_length=1, max_length=256,
                                                   repetition_penalty=1., num_beams=num_txt, eos_token_id=50118,
                                                   length_penalty=1., num_return_sequences=return_num_txt,
                                                   temperature=1.)
            output_text = model.opt_tokenizer.batch_decode(outputs)
            output_text = [re.sub('\t', '', str(x)) for x in output_text]
            output_text = [t.strip() for t in output_text]
            output_text_ = []
            for j in range(len(image)):
                output_text_.append(';'.join(output_text[j * return_num_txt:(j + 1) * return_num_txt]))
            for j in range(len(image)):
                f.write(image[j][1] + "|" + output_text_[j] + '\n')
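# Driver: pick the split and GO category, run stage-2 generation on the test CSV, then
# merge the generated texts with the ground-truth labels and save them for evaluation.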
if __name__ == "__main__":
    split = 'test'
    cat = 'bp'
    fix = '_mf'
    type_fix = ''
    if cat == 'bp':
        fix = '_bp'
    if cat == 'cc':
        fix = '_cc'
    print(device)
    return_num_txt = 1
    # graph = obonet.read_obo("http://purl.obolibrary.org/obo/go.obo")

    ### Levenshtein similarity
    print("reading file ...")
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split_concat/{}{}.csv'.format(split, fix),
                       usecols=['name', 'protein', 'function'], sep='|')
    # test['function'] = test['function'].apply(lambda x: x.lower().split('; '))
    test.columns = ['name', 'protein', 'label']

    # Remove any stale output file, since stage2_output appends to it.
    if os.path.exists('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix)):
        os.remove('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix))
    print("stage 2 predict starting")
    stage2_output(test, return_num_txt=return_num_txt)
    print("stage 2 predict completed")

    df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_concat_{}{}{}.txt'.format(split, fix, type_fix),
                          sep='|', header=None, on_bad_lines='warn')
    df_pred.columns = ['protein', 'pred']
    df_pred = df_pred.drop_duplicates()
    # df_pred['function'] = df_pred['function'].apply(lambda x: str(x).split(';'))
    # df_pred['function'] = df_pred['function'].apply(lambda x: [i.strip() for i in list(set(x))])

    data = pd.merge(df_pred, test, on='protein', how='left')
    data = data[data['label'].notnull()]
    # sim = []
    # for text, label in zip(data['function'].tolist(), data['label'].tolist()):
    #     sim.append(func(text, label))
    # data['sim'] = sim
    # data['avg_score'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
    # data['count'] = data['sim'].apply(lambda x: x.count(1.))
    # print("average similarity score: {}".format(round(data['avg_score'].mean(), 3)))
    # print("Return texts: {}; Accuracy: {}".format(return_num_txt, data['count'].sum()/(return_num_txt*data.shape[0])))

    data[['name', 'label', 'pred']].to_csv(
        '/cluster/home/wenkai/LAVIS/output/predict_concat_{}{}{}.csv'.format(split, cat, type_fix), index=False, sep='|')