# Hugging Face Spaces status: Running
import logging
import os
import sys

# Make the parent directory importable so that ``utils.gd_model`` resolves.
# This must run before the ``utils`` import below; note that "../" is
# resolved relative to the current working directory, not this file.
sys.path.append("../")

import torch
import torch.nn as nn
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from torch.cuda.amp import autocast
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import functional as F
from tqdm import tqdm
from transformers import BertModel
from transformers import EsmConfig, EsmModel

from utils.gd_model import GDANet

LOGGER = logging.getLogger(__name__)
class FusionModule(nn.Module):
    """Bidirectional attention fusion of two token sequences.

    Each side gets its own query/key/value projections. Each side then attends
    both to itself and to the other side through one shared
    ``nn.MultiheadAttention``, and the two attention outputs are averaged.
    """

    def __init__(self, out_dim, num_head, dropout=0.1, batch_first=False):
        """FusionModule.

        Args:
            out_dim (int): model dimension of both inputs and of the output.
            num_head (int): number of attention heads; ``nn.MultiheadAttention``
                requires it to divide ``out_dim``.
            dropout (float): attention dropout. Defaults to 0.1.
            batch_first (bool): if True, inputs and outputs are laid out as
                (batch, seq, out_dim). Otherwise they are (seq, batch, out_dim),
                which is the ``nn.MultiheadAttention`` default. Defaults to False.
        """
        super(FusionModule, self).__init__()
        self.out_dim = out_dim
        self.num_head = num_head
        self.batch_first = batch_first
        # Source-side (protein) projections.
        self.WqS = nn.Linear(out_dim, out_dim)
        self.WkS = nn.Linear(out_dim, out_dim)
        self.WvS = nn.Linear(out_dim, out_dim)
        # Target-side (disease) projections.
        self.WqT = nn.Linear(out_dim, out_dim)
        self.WkT = nn.Linear(out_dim, out_dim)
        self.WvT = nn.Linear(out_dim, out_dim)
        # A single attention module is shared by all four attention directions.
        self.multi_head_attention = nn.MultiheadAttention(
            out_dim, num_head, dropout=dropout, batch_first=batch_first
        )

    def forward(self, zs, zt):
        """Fuse source (protein) and target (disease) token representations.

        Args:
            zs: source sequence, (L_s, N, out_dim), or (N, L_s, out_dim) when
                ``batch_first=True``.
            zt: target sequence, (L_t, N, out_dim), or (N, L_t, out_dim) when
                ``batch_first=True``.

        Returns:
            tuple: ``(protein_fused, dis_fused)`` with the same shapes as
            ``zs`` and ``zt`` respectively.
        """
        qs, ks, vs = self.WqS(zs), self.WkS(zs), self.WvS(zs)
        qt, kt, vt = self.WqT(zt), self.WkT(zt), self.WvT(zt)
        # nn.MultiheadAttention returns (output, attention_weights); only the
        # output is needed. Call order is kept fixed because attention dropout
        # consumes RNG state.
        zs_self, _ = self.multi_head_attention(qs, ks, vs)
        zs_cross, _ = self.multi_head_attention(qs, kt, vt)
        zt_self, _ = self.multi_head_attention(qt, kt, vt)
        zt_cross, _ = self.multi_head_attention(qt, ks, vs)
        protein_fused = 0.5 * (zs_self + zs_cross)
        dis_fused = 0.5 * (zt_self + zt_cross)
        return protein_fused, dis_fused
class CrossAttentionBlock(nn.Module):
    """Joint self- and cross-attention between two padded sequences.

    Each output sequence is the mean of attending over its own sequence and
    over the other sequence. Masked keys are suppressed before the softmax, and
    masked query positions produce all-zero outputs.
    """

    def __init__(self, hidden_dim, num_heads):
        """
        Args:
            hidden_dim (int): feature size of both inputs; must be divisible
                by ``num_heads``.
            num_heads (int): number of attention heads.

        Raises:
            ValueError: if ``hidden_dim`` is not a multiple of ``num_heads``.
        """
        super(CrossAttentionBlock, self).__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_dim, num_heads))
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads
        self.query1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _alpha_from_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Masked softmax over the key axis.

        Args:
            logits: (N, L1, L2, H) attention scores.
            mask_row: (N, L1) query mask; True / non-zero marks a real token.
            mask_col: (N, L2) key mask; True / non-zero marks a real token.
            inf: offset subtracted from masked logits before the softmax.

        Returns:
            (N, L1, L2, H) attention weights. Rows of masked queries are zero.
        """
        N, L1, L2, H = logits.shape
        # torch.where needs a boolean condition. Tokenizer attention masks are
        # usually integer 0/1 tensors, so normalise here (no-op for bool masks).
        row = mask_row.reshape(N, L1).bool()
        col = mask_col.reshape(N, L2).bool()
        pair = (row[:, :, None] & col[:, None, :]).unsqueeze(-1)  # (N, L1, L2, 1)
        logits = torch.where(pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        # Padding queries must not produce output.
        return torch.where(row[:, :, None, None], alpha, torch.zeros_like(alpha))

    def _heads(self, x, n_heads, n_ch):
        """Split the last dimension into (n_heads, n_ch)."""
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    @staticmethod
    def _attend(alpha, value):
        """Apply (N, Lq, Lk, H) weights to (N, Lk, H, D) values -> (N, Lq, H * D)."""
        return torch.einsum('blkh, bkhd->blhd', alpha, value).flatten(-2)

    def forward(self, input1, input2, mask1, mask2):
        """Mix two sequences with joint self- and cross-attention.

        Args:
            input1: (N, L1, hidden_dim) first sequence.
            input2: (N, L2, hidden_dim) second sequence.
            mask1: (N, L1) mask for ``input1`` (bool or 0/1 numbers).
            mask2: (N, L2) mask for ``input2`` (bool or 0/1 numbers).

        Returns:
            tuple: ``(output1, output2)`` of shapes (N, L1, hidden_dim) and
            (N, L2, hidden_dim).
        """
        query1 = self._heads(self.query1(input1), self.num_heads, self.head_size)
        key1 = self._heads(self.key1(input1), self.num_heads, self.head_size)
        query2 = self._heads(self.query2(input2), self.num_heads, self.head_size)
        key2 = self._heads(self.key2(input2), self.num_heads, self.head_size)
        # alphaXY: queries from sequence X attending to keys of sequence Y.
        alpha11 = self._alpha_from_logits(
            torch.einsum('blhd, bkhd->blkh', query1, key1), mask1, mask1)
        alpha12 = self._alpha_from_logits(
            torch.einsum('blhd, bkhd->blkh', query1, key2), mask1, mask2)
        alpha21 = self._alpha_from_logits(
            torch.einsum('blhd, bkhd->blkh', query2, key1), mask2, mask1)
        alpha22 = self._alpha_from_logits(
            torch.einsum('blhd, bkhd->blkh', query2, key2), mask2, mask2)
        value1 = self._heads(self.value1(input1), self.num_heads, self.head_size)
        value2 = self._heads(self.value2(input2), self.num_heads, self.head_size)
        output1 = (self._attend(alpha11, value1) + self._attend(alpha12, value2)) / 2
        output2 = (self._attend(alpha21, value1) + self._attend(alpha22, value2)) / 2
        return output1, output2
class GDA_Metric_Learning(GDANet):
    """Protein-disease metric-learning model with cross-attention fusion.

    Protein and disease token sequences are encoded separately and projected
    into a shared 1024-d space. They are then fused with
    ``CrossAttentionBlock``, pooled according to ``agg_mode``, and scored with
    a pytorch-metric-learning loss.
    """

    def __init__(
        self, prot_encoder, disease_encoder, prot_out_dim, disease_out_dim, args
    ):
        """Constructor for the model.

        Args:
            prot_encoder: Protein encoder. It is called with ``input_ids``,
                ``attention_mask`` and ``return_dict=True``, and its
                ``last_hidden_state`` is used.
            disease_encoder: Disease textual encoder, called the same way.
            prot_out_dim (int): Hidden size of the protein encoder.
            disease_out_dim (int): Hidden size of the disease encoder.
            args: Namespace providing ``loss``, ``use_miner``, ``miner_margin``,
                ``agg_mode`` and ``dropout``.
        """
        super(GDA_Metric_Learning, self).__init__(
            prot_encoder,
            disease_encoder,
        )
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        # Project both modalities into a shared 1024-d space before fusion.
        self.prot_reg = nn.Linear(prot_out_dim, 1024)
        self.dis_reg = nn.Linear(disease_out_dim, 1024)
        # NOTE(review): fusion_layer is not used by predict/forward (they use
        # cross_attention_layer). It is presumably kept so that saved
        # state_dicts containing its weights still load.
        self.fusion_layer = FusionModule(1024, num_head=8)
        self.cross_attention_layer = CrossAttentionBlock(1024, 8)
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            # Tuning notes (alpha_beta: value): 1_40=1.5141 1_50=1.4988
            # 1_60=1.4905 2_60=1.1786 2_50=1.1874 2_40=1.2008 3_40=1.1146 3_50=1.1012
            self.loss = losses.MultiSimilarityLoss(
                alpha=2, beta=50, base=0.5
            )
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss(
                m=0.4, gamma=80
            )
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss(
                margin=0.05, swap=False, smooth_loss=False,
                triplets_per_anchor="all")
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss(
                neg_margin=1, pos_margin=0
            )
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss(
                softmax_scale=1
            )
        # NOTE(review): an unrecognised args.loss leaves self.loss as the raw
        # string, so forward() will fail when it tries to call it.
        self.fusion = False
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def non_adapters(
        self,
        prot_model_path,
        disease_model_path,
    ):
        """Replace the encoders with the ones stored in fully pickled models.

        Each path is optional: a missing file is reported on stdout and the
        current encoder is kept.

        Args:
            prot_model_path (str): file holding a pickled model with a
                ``prot_encoder`` attribute.
            disease_model_path (str): file holding a pickled model with a
                ``disease_encoder`` attribute.
        """
        # SECURITY: torch.load unpickles whole model objects, so only load
        # trusted checkpoint files here.
        if os.path.exists(prot_model_path):
            prot_model = torch.load(prot_model_path)
            self.prot_encoder = prot_model.prot_encoder
            print(f"load protein from: {prot_model_path}")
        else:
            print(f"{prot_model_path} does not exist")
        if os.path.exists(disease_model_path):
            disease_model = torch.load(disease_model_path)
            self.disease_encoder = disease_model.disease_encoder
            print(f"load disease from: {disease_model_path}")
        else:
            print(f"{disease_model_path} does not exist")

    def _encode_and_fuse(self, query_toks1, query_toks2):
        """Encode both inputs, project them to 1024-d and apply cross-attention.

        Returns:
            tuple: ``(prot_fused, dis_fused)`` token-level fused representations.
        """
        prot_attention_mask = query_toks1["attention_mask"]
        dis_attention_mask = query_toks2["attention_mask"]
        prot_hidden = self.prot_encoder(
            input_ids=query_toks1["input_ids"],
            attention_mask=prot_attention_mask,
            return_dict=True,
        ).last_hidden_state
        prot_hidden = self.prot_reg(prot_hidden)
        dis_hidden = self.disease_encoder(
            input_ids=query_toks2["input_ids"],
            attention_mask=dis_attention_mask,
            return_dict=True,
        ).last_hidden_state
        dis_hidden = self.dis_reg(dis_hidden)
        return self.cross_attention_layer(
            prot_hidden, dis_hidden, prot_attention_mask, dis_attention_mask
        )

    def _pool(self, hidden, mask):
        """Reduce (batch, seq, hidden) token states to (batch, hidden) per agg_mode.

        Raises:
            NotImplementedError: for an unknown ``agg_mode``.
        """
        if self.agg_mode == "cls":
            return hidden[:, 0]
        if self.agg_mode == "mean_all_tok":
            return hidden.mean(1)
        if self.agg_mode == "mean":
            # Average over real (unmasked) tokens only.
            return (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        raise NotImplementedError()

    def predict(self, query_toks1, query_toks2):
        """Embed protein/disease pairs.

        Args:
            query_toks1: protein tokens, a mapping with "input_ids" and
                "attention_mask".
            query_toks2: disease tokens, with the same keys.

        Returns:
            Tensor of shape (batch, 2 * 1024): the pooled protein embedding
            followed by the pooled disease embedding.
        """
        prot_fused, dis_fused = self._encode_and_fuse(query_toks1, query_toks2)
        query_embed1 = self._pool(prot_fused, query_toks1["attention_mask"])
        query_embed2 = self._pool(dis_fused, query_toks2["attention_mask"])
        return torch.cat([query_embed1, query_embed2], dim=1)

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of protein/disease pairs.

        The i-th protein and the i-th disease share class id ``i``, so each
        pair is its own class and the rest of the batch acts as negatives.
        ``labels`` is only used for its length; its values are ignored.

        Args:
            query_toks1: protein tokens ("input_ids", "attention_mask").
            query_toks2: disease tokens ("input_ids", "attention_mask").
            labels: per-pair labels; only ``len(labels)`` is used.

        Returns:
            The scalar loss produced by ``self.loss``.
        """
        prot_fused, dis_fused = self._encode_and_fuse(query_toks1, query_toks2)
        query_embed1 = self._pool(prot_fused, query_toks1["attention_mask"])
        query_embed2 = self._pool(dis_fused, query_toks2["attention_mask"])
        # Stack proteins and diseases along the batch axis: (2 * batch, hidden).
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        pair_ids = torch.arange(len(labels), device=query_embed.device)
        labels = torch.cat([pair_ids, pair_ids], dim=0)
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        return self.loss(query_embed, labels)

    def get_embeddings(self, mentions, batch_size=1024):
        """Compute all embeddings from mention tokens, in batches, on the CPU.

        NOTE(review): relies on ``self.vectorizer``, which is not defined in
        this class (presumably provided by GDANet); verify before use.
        """
        embedding_table = []
        with torch.no_grad():
            for start in tqdm(range(0, len(mentions), batch_size)):
                end = min(start + batch_size, len(mentions))
                batch = mentions[start:end]
                batch_embedding = self.vectorizer(batch)
                embedding_table.append(batch_embedding.cpu())
        return torch.cat(embedding_table, dim=0)
class DDA_Metric_Learning(Module):
    """Disease-disease metric learning with one encoder shared by both inputs."""

    def __init__(self, disease_encoder, args):
        """Constructor for the model.

        Args:
            disease_encoder: disease encoder applied to both sides of a pair.
            args: Namespace providing ``loss``, ``use_miner``, ``miner_margin``,
                ``agg_mode`` and ``dropout``.
        """
        super(DDA_Metric_Learning, self).__init__()
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.disease_adapter_name = None
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()
        # NOTE(review): an unrecognised args.loss leaves self.loss as the raw
        # string, so forward() will fail when it tries to call it.
        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, disease_out_dim=768, out_dim=2):
        """Add a linear classification head over a concatenated pair embedding.

        Args:
            disease_out_dim (int, optional): disease encoder output dimension.
                Defaults to 768.
            out_dim (int, optional): number of output classes. Defaults to 2.
        """
        # The head sees two disease embeddings side by side.
        self.cls = nn.Linear(disease_out_dim * 2, out_dim)

    def load_disease_adapter(
        self,
        disease_model_path,
        disease_adapter_name="disease_adapter",
    ):
        """Load a saved adapter and activate it, if ``disease_model_path`` exists.

        A missing path is reported on stdout, not raised.
        """
        if os.path.exists(disease_model_path):
            self.disease_adapter_name = disease_adapter_name
            self.disease_encoder.load_adapter(
                disease_model_path, load_as=disease_adapter_name
            )
            self.disease_encoder.set_active_adapters(disease_adapter_name)
            print(
                f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
            )
        else:
            print(f"{disease_model_path} does not exist")

    def init_adapters(
        self,
        disease_adapter_name="disease_adapter",
        reduction_factor=16,
    ):
        """Initialise a trainable "pfeiffer" adapter on the disease encoder.

        Args:
            disease_adapter_name (str, optional): adapter name. Defaults to "disease_adapter".
            reduction_factor (int, optional): adapter bottleneck reduction factor. Defaults to 16.

        Raises:
            ImportError: if the installed ``transformers`` does not provide
                ``AdapterConfig``. This presumably requires the
                adapter-transformers fork.
        """
        # AdapterConfig was never imported at module level (NameError before).
        # Importing it here keeps the rest of the module usable without adapters.
        from transformers import AdapterConfig

        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.disease_adapter_name = disease_adapter_name
        self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
        self.disease_encoder.train_adapter([disease_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save the active disease adapter to ``<prefix>/disease_adapter_step_<step>``.

        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        disease_save_dir = os.path.join(
            save_path_prefix, f"disease_adapter_step_{total_step}"
        )
        os.makedirs(disease_save_dir, exist_ok=True)
        self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)

    def predict(self, x1, x2):
        """Embed two batches of disease input ids and concatenate them.

        "cls" uses the first token. Any other agg_mode averages over all
        tokens, because no attention mask is available here.

        Returns:
            Tensor (batch, 2 * hidden).
        """
        if self.agg_mode == "cls":
            x1 = self.disease_encoder(x1).last_hidden_state[:, 0]
            x2 = self.disease_encoder(x2).last_hidden_state[:, 0]
        else:
            x1 = self.disease_encoder(x1).last_hidden_state.mean(1)
            x2 = self.disease_encoder(x2).last_hidden_state.mean(1)
        return torch.cat((x1, x2), 1)

    def module_predict(self, x1, x2):
        """Same as ``predict`` but calls ``self.disease_encoder.module``.

        NOTE(review): presumably for when the encoder is wrapped in
        DataParallel / DistributedDataParallel.
        """
        if self.agg_mode == "cls":
            x1 = self.disease_encoder.module(x1).last_hidden_state[:, 0]
            x2 = self.disease_encoder.module(x2).last_hidden_state[:, 0]
        else:
            x1 = self.disease_encoder.module(x1).last_hidden_state.mean(1)
            x2 = self.disease_encoder.module(x2).last_hidden_state.mean(1)
        return torch.cat((x1, x2), 1)

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of disease pairs.

        Both sides are pooled per ``agg_mode`` and stacked along the batch axis.
        Both halves reuse ``labels``, so the two members of a pair share a
        label.

        Returns:
            The value produced by ``self.loss``.
        """
        last_hidden_state1 = self.disease_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.disease_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state
        if self.agg_mode == "cls":
            query_embed1 = last_hidden_state1[:, 0]
            query_embed2 = last_hidden_state2[:, 0]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = last_hidden_state1.mean(1)
            query_embed2 = last_hidden_state2.mean(1)
        elif self.agg_mode == "mean":
            # Average over real (unmasked) tokens only.
            query_embed1 = (
                last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)
        # Miner state is reported once in __init__; no per-step printing here.
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        return self.loss(query_embed, labels)
class PPI_Metric_Learning(Module):
    """Protein-protein metric learning with one encoder shared by both inputs."""

    def __init__(self, prot_encoder, args):
        """Constructor for the model.

        Args:
            prot_encoder: protein encoder applied to both proteins of a pair.
            args: Namespace providing ``loss``, ``use_miner``, ``miner_margin``,
                ``agg_mode`` and ``dropout``.
        """
        super(PPI_Metric_Learning, self).__init__()
        self.prot_encoder = prot_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.prot_adapter_name = None
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()
        # NOTE(review): an unrecognised args.loss leaves self.loss as the raw
        # string, so forward() will fail when it tries to call it.
        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, prot_out_dim=1024, out_dim=2):
        """Add a linear classification head over a concatenated pair embedding.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            out_dim (int, optional): number of output classes. Defaults to 2.
        """
        # The head sees two protein embeddings side by side.
        self.cls = nn.Linear(prot_out_dim + prot_out_dim, out_dim)

    def load_prot_adapter(
        self,
        prot_model_path,
        prot_adapter_name="prot_adapter",
    ):
        """Load a saved adapter and activate it, if ``prot_model_path`` exists.

        A missing path is reported on stdout, not raised.
        """
        if os.path.exists(prot_model_path):
            self.prot_adapter_name = prot_adapter_name
            self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
            self.prot_encoder.set_active_adapters(prot_adapter_name)
            print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
        else:
            print(f"{prot_model_path} does not exist")

    def init_adapters(
        self,
        prot_adapter_name="prot_adapter",
        reduction_factor=16,
    ):
        """Initialise a trainable "pfeiffer" adapter on the protein encoder.

        Args:
            prot_adapter_name (str, optional): adapter name. Defaults to "prot_adapter".
            reduction_factor (int, optional): adapter bottleneck reduction factor. Defaults to 16.

        Raises:
            ImportError: if the installed ``transformers`` does not provide
                ``AdapterConfig``. This presumably requires the
                adapter-transformers fork.
        """
        # AdapterConfig was never imported at module level (NameError before).
        # Importing it here keeps the rest of the module usable without adapters.
        from transformers import AdapterConfig

        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.prot_adapter_name = prot_adapter_name
        self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
        self.prot_encoder.train_adapter([prot_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save the active protein adapter to ``<prefix>/prot_adapter_step_<step>``.

        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        prot_save_dir = os.path.join(
            save_path_prefix, f"prot_adapter_step_{total_step}"
        )
        os.makedirs(prot_save_dir, exist_ok=True)
        self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)

    def predict(self, x1, x2):
        """Embed two batches of protein input ids and concatenate them.

        "cls" uses the first token. Any other agg_mode averages over all
        tokens, because no attention mask is available here.

        Returns:
            Tensor (batch, 2 * hidden).
        """
        if self.agg_mode == "cls":
            x1 = self.prot_encoder(x1).last_hidden_state[:, 0]
            x2 = self.prot_encoder(x2).last_hidden_state[:, 0]
        else:
            x1 = self.prot_encoder(x1).last_hidden_state.mean(1)
            x2 = self.prot_encoder(x2).last_hidden_state.mean(1)
        return torch.cat((x1, x2), 1)

    def module_predict(self, x1, x2):
        """Same as ``predict`` but calls ``self.prot_encoder.module``.

        NOTE(review): presumably for when the encoder is wrapped in
        DataParallel / DistributedDataParallel.
        """
        if self.agg_mode == "cls":
            x1 = self.prot_encoder.module(x1).last_hidden_state[:, 0]
            x2 = self.prot_encoder.module(x2).last_hidden_state[:, 0]
        else:
            x1 = self.prot_encoder.module(x1).last_hidden_state.mean(1)
            x2 = self.prot_encoder.module(x2).last_hidden_state.mean(1)
        return torch.cat((x1, x2), 1)

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of protein pairs.

        Both sides are pooled per ``agg_mode`` and stacked along the batch axis.
        Both halves reuse ``labels``, so the two members of a pair share a
        label.

        Returns:
            The value produced by ``self.loss``.
        """
        last_hidden_state1 = self.prot_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.prot_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state
        if self.agg_mode == "cls":
            query_embed1 = last_hidden_state1[:, 0]
            query_embed2 = last_hidden_state2[:, 0]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = last_hidden_state1.mean(1)
            query_embed2 = last_hidden_state2.mean(1)
        elif self.agg_mode == "mean":
            # Average over real (unmasked) tokens only.
            query_embed1 = (
                last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        return self.loss(query_embed, labels)