# FusionGDA / src /utils /metric_learning_models.py
# ZhaohanM
# FusionGDA
# 7c46397
import logging
import os
import sys
sys.path.append("../")
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from pytorch_metric_learning import losses
import torch
import torch.nn as nn
from torch.nn import functional as F
from pytorch_metric_learning import losses, miners
from torch.cuda.amp import autocast
from torch.nn import Module
from tqdm import tqdm
from utils.gd_model import GDANet
from torch.nn import MultiheadAttention
from transformers import BertModel
from transformers import EsmModel, EsmConfig
LOGGER = logging.getLogger(__name__)
class FusionModule(nn.Module):
    """Fuse two sequence representations through one shared attention module.

    Each input owns its query/key/value linear projections. A single
    ``nn.MultiheadAttention`` then computes, for each input, one
    self-attention pass and one cross-attention pass (queries from that
    input, keys/values from the other one); the two passes are averaged.

    Args:
        out_dim: feature size of both inputs and of both fused outputs.
        num_head: number of attention heads.
        dropout: dropout probability inside the attention module (default 0.1).
    """

    def __init__(self, out_dim, num_head, dropout= 0.1):
        super(FusionModule, self).__init__()
        self.out_dim = out_dim
        self.num_head = num_head
        # Projections are created in a fixed order (S = first input,
        # T = second input) so parameter names and seeded initialisation
        # stay stable.
        for layer_name in ("WqS", "WkS", "WvS", "WqT", "WkT", "WvT"):
            setattr(self, layer_name, nn.Linear(out_dim, out_dim))
        self.multi_head_attention = nn.MultiheadAttention(
            out_dim, num_head, dropout=dropout
        )

    def forward(self, zs, zt):
        """Return the fused versions of ``zs`` and ``zt``.

        The attention module is built without ``batch_first``, so inputs are
        interpreted as (token_length, batch_size, out_dim); any permutation
        from a batch-first layout is left to the caller.

        Returns:
            Tuple ``(protein_fused, dis_fused)`` with the shapes of ``zs`` and
            ``zt`` respectively.
        """
        attend = self.multi_head_attention

        # Query, key and value projections of each input.
        q_s, k_s, v_s = self.WqS(zs), self.WkS(zs), self.WvS(zs)
        q_t, k_t, v_t = self.WqT(zt), self.WkT(zt), self.WvT(zt)

        # The attention call returns (representation, attention weights);
        # only the representation (index 0) is used.
        s_self = attend(q_s, k_s, v_s)[0]
        s_cross = attend(q_s, k_t, v_t)[0]
        t_self = attend(q_t, k_t, v_t)[0]
        t_cross = attend(q_t, k_s, v_s)[0]

        return 0.5 * (s_self + s_cross), 0.5 * (t_self + t_cross)
class CrossAttentionBlock(nn.Module):
    """Multi-head self- and cross-attention between two padded sequences.

    For each of the two inputs the block computes attention over itself and
    over the other input, then averages the two attended results. Attention
    logits are plain per-head dot products of the projected queries and keys
    (no additional scaling is applied).

    Args:
        hidden_dim: feature size of both inputs and outputs.
        num_heads: number of attention heads; must divide ``hidden_dim``.

    Raises:
        ValueError: if ``hidden_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionBlock, self).__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_dim, num_heads))

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        # Bias-free projections; suffix 1 / 2 = first / second input.
        self.query1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value1 = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.query2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _alpha_from_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Turn logits into attention weights that respect padding masks.

        Args:
            logits: tensor of shape (N, L1, L2, H).
            mask_row: (N, L1) mask of valid query positions. Any dtype is
                accepted; non-zero means valid.
            mask_col: (N, L2) mask of valid key positions, same convention.
            inf: value subtracted from the logits of masked pairs before the
                softmax.

        Returns:
            Tensor of shape (N, L1, L2, H): softmax over the key axis, with
            all weights of masked query rows set to zero.
        """
        N, L1, L2, _ = logits.shape
        # torch.where requires a boolean condition, while tokenizer attention
        # masks are usually integer tensors, so cast explicitly. The size-1
        # axes broadcast over keys/heads, which avoids materialising repeats.
        row = mask_row.bool().reshape(N, L1, 1, 1)
        col = mask_col.bool().reshape(N, 1, L2, 1)

        logits = torch.where(row & col, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        # Padded query positions must not attend to anything.
        return torch.where(row, alpha, torch.zeros_like(alpha))

    def _heads(self, x, n_heads, n_ch):
        """Split the last axis of ``x`` into (n_heads, n_ch)."""
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def forward(self, input1, input2, mask1, mask2):
        """Attend each input over itself and over the other input.

        Args:
            input1: (N, L1, hidden_dim) tensor.
            input2: (N, L2, hidden_dim) tensor.
            mask1: (N, L1) validity mask for ``input1``.
            mask2: (N, L2) validity mask for ``input2``.

        Returns:
            Tuple ``(output1, output2)`` of shapes (N, L1, hidden_dim) and
            (N, L2, hidden_dim). Positions masked out in ``mask1`` / ``mask2``
            are zero in the corresponding output.
        """
        query1 = self._heads(self.query1(input1), self.num_heads, self.head_size)
        key1 = self._heads(self.key1(input1), self.num_heads, self.head_size)
        query2 = self._heads(self.query2(input2), self.num_heads, self.head_size)
        key2 = self._heads(self.key2(input2), self.num_heads, self.head_size)

        # logitsXY: queries from input X against keys from input Y.
        logits11 = torch.einsum('blhd, bkhd->blkh', query1, key1)
        logits12 = torch.einsum('blhd, bkhd->blkh', query1, key2)
        logits21 = torch.einsum('blhd, bkhd->blkh', query2, key1)
        logits22 = torch.einsum('blhd, bkhd->blkh', query2, key2)

        alpha11 = self._alpha_from_logits(logits11, mask1, mask1)
        alpha12 = self._alpha_from_logits(logits12, mask1, mask2)
        alpha21 = self._alpha_from_logits(logits21, mask2, mask1)
        alpha22 = self._alpha_from_logits(logits22, mask2, mask2)

        value1 = self._heads(self.value1(input1), self.num_heads, self.head_size)
        value2 = self._heads(self.value2(input2), self.num_heads, self.head_size)

        # Average the self-attended and cross-attended values, merging the
        # head axis back into the feature axis.
        output1 = (torch.einsum('blkh, bkhd->blhd', alpha11, value1).flatten(-2) +
                   torch.einsum('blkh, bkhd->blhd', alpha12, value2).flatten(-2)) / 2
        output2 = (torch.einsum('blkh, bkhd->blhd', alpha21, value1).flatten(-2) +
                   torch.einsum('blkh, bkhd->blhd', alpha22, value2).flatten(-2)) / 2

        return output1, output2
class GDA_Metric_Learning(GDANet):
    """Protein/disease pair model trained with a metric-learning loss.

    Token embeddings from the protein and the disease encoder are projected
    to a shared 1024-dimensional space, exchanged through a
    ``CrossAttentionBlock`` and pooled into one vector per sequence according
    to ``agg_mode`` ("cls", "mean_all_tok" or "mean").
    """

    def __init__(
        self, prot_encoder, disease_encoder, prot_out_dim, disease_out_dim, args
    ):
        """Constructor for the model.

        Args:
            prot_encoder: protein encoder; called with ``input_ids``,
                ``attention_mask`` and ``return_dict=True`` and expected to
                return an object with ``last_hidden_state``.
            disease_encoder: disease text encoder, same calling convention.
            prot_out_dim: hidden size of the protein encoder output.
            disease_out_dim: hidden size of the disease encoder output.
            args: configuration object; the attributes ``loss``,
                ``use_miner``, ``miner_margin``, ``agg_mode`` and ``dropout``
                are read.
        """
        super(GDA_Metric_Learning, self).__init__(
            prot_encoder,
            disease_encoder,
        )
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode

        # Project both encoders to a common 1024-dimensional space.
        self.prot_reg = nn.Linear(prot_out_dim, 1024)
        self.dis_reg = nn.Linear(disease_out_dim, 1024)

        # NOTE: ``fusion_layer`` is constructed but not called by predict() or
        # forward(), which use ``cross_attention_layer``. It is kept so the
        # module's parameter set (and saved state) does not change.
        self.fusion_layer = FusionModule(1024, num_head=8)
        self.cross_attention_layer = CrossAttentionBlock(1024, 8)

        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None

        # Replace the loss name by the loss object. An unrecognised name is
        # left untouched (self.loss stays whatever ``args.loss`` was).
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=2, beta=50, base=0.5
            )  # 1,2,3; 40,50,60
            # 1_40=1.5141 50=1.4988 60=1.4905 2_60=1.1786 50=1.1874 40=1.2008 3_40=1.1146 50=1.1012
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss(
                m=0.4, gamma=80
            )
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss(
                margin=0.05, swap=False, smooth_loss=False,
                triplets_per_anchor="all")
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss(
                neg_margin=1, pos_margin=0
            )
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss(
                softmax_scale=1
            )

        self.fusion = False
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def non_adapters(
        self,
        prot_model_path,
        disease_model_path,
    ):
        """Replace the encoders with those stored in whole-model checkpoints.

        Each path is expected to hold a complete pickled model exposing a
        ``prot_encoder`` / ``disease_encoder`` attribute. A missing path is
        reported on stdout and the corresponding encoder is left unchanged.

        NOTE(review): ``torch.load`` unpickles arbitrary objects; only load
        checkpoints from trusted sources.

        Args:
            prot_model_path: checkpoint providing ``prot_encoder``.
            disease_model_path: checkpoint providing ``disease_encoder``.
        """
        if os.path.exists(prot_model_path):
            # Load the entire model and take over its protein encoder.
            prot_model = torch.load(prot_model_path)
            self.prot_encoder = prot_model.prot_encoder
            print(f"load protein from: {prot_model_path}")
        else:
            print(f"{prot_model_path} not exits")

        if os.path.exists(disease_model_path):
            # Load the entire model and take over its disease encoder.
            disease_model = torch.load(disease_model_path)
            self.disease_encoder = disease_model.disease_encoder
            print(f"load disease from: {disease_model_path}")
        else:
            print(f"{disease_model_path} not exits")

    def _fused_hidden_states(self, query_toks1, query_toks2):
        """Encode both inputs, project them and apply cross-attention.

        Args:
            query_toks1: mapping with "input_ids" and "attention_mask" for
                the protein encoder.
            query_toks2: same for the disease encoder.

        Returns:
            Tuple ``(prot_fused, dis_fused)`` as returned by the
            cross-attention layer.
        """
        prot_attention_mask = query_toks1["attention_mask"]
        dis_attention_mask = query_toks2["attention_mask"]

        prot_hidden = self.prot_encoder(
            input_ids=query_toks1["input_ids"],
            attention_mask=prot_attention_mask,
            return_dict=True,
        ).last_hidden_state
        prot_hidden = self.prot_reg(prot_hidden)

        dis_hidden = self.disease_encoder(
            input_ids=query_toks2["input_ids"],
            attention_mask=dis_attention_mask,
            return_dict=True,
        ).last_hidden_state
        dis_hidden = self.dis_reg(dis_hidden)

        # The cross-attention layer selects with torch.where, which needs a
        # boolean condition; .bool() is a no-op for masks that already are.
        return self.cross_attention_layer(
            prot_hidden,
            dis_hidden,
            prot_attention_mask.bool(),
            dis_attention_mask.bool(),
        )

    def _pool(self, hidden, attention_mask):
        """Reduce (batch, tokens, hidden) to (batch, hidden) per ``agg_mode``.

        "cls" takes the first token, "mean_all_tok" averages every position,
        and "mean" averages only positions where ``attention_mask`` is set.

        Raises:
            NotImplementedError: for any other ``agg_mode``.
        """
        if self.agg_mode == "cls":
            return hidden[:, 0]
        if self.agg_mode == "mean_all_tok":
            return hidden.mean(1)
        if self.agg_mode == "mean":
            summed = (hidden * attention_mask.unsqueeze(-1)).sum(1)
            return summed / attention_mask.sum(-1).unsqueeze(-1)
        raise NotImplementedError()

    def predict(self, query_toks1, query_toks2):
        """Return one joint embedding per protein/disease pair.

        Args:
            query_toks1: tokenized proteins ("input_ids", "attention_mask").
            query_toks2: tokenized diseases ("input_ids", "attention_mask").

        Returns:
            Tensor of shape (batch_size, 2 * 1024): the pooled protein vector
            concatenated with the pooled disease vector along the feature
            axis.
        """
        prot_fused, dis_fused = self._fused_hidden_states(query_toks1, query_toks2)
        query_embed1 = self._pool(prot_fused, query_toks1["attention_mask"])
        query_embed2 = self._pool(dis_fused, query_toks2["attention_mask"])
        return torch.cat([query_embed1, query_embed2], dim=1)

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of pairs.

        The pooled protein and disease vectors are stacked along the batch
        axis and the i-th protein and i-th disease both receive label ``i``.
        Only the length of ``labels`` is used; its values are ignored.

        Args:
            query_toks1: tokenized proteins ("input_ids", "attention_mask").
            query_toks2: tokenized diseases ("input_ids", "attention_mask").
            labels: sequence whose length equals the batch size.

        Returns:
            The value returned by ``self.loss``.
        """
        prot_fused, dis_fused = self._fused_hidden_states(query_toks1, query_toks2)
        # Pool the fused states, exactly as predict() does. (The previous code
        # referenced undefined names here for "cls" and "mean".)
        query_embed1 = self._pool(prot_fused, query_toks1["attention_mask"])
        query_embed2 = self._pool(dis_fused, query_toks2["attention_mask"])

        query_embed = torch.cat([query_embed1, query_embed2], dim=0)

        # Create the pair indices on the embeddings' device so labels and
        # embeddings can be combined without a device mismatch.
        pair_ids = torch.arange(len(labels), device=query_embed.device)
        labels = torch.cat([pair_ids, pair_ids], dim=0)

        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        else:
            loss = self.loss(query_embed, labels)
            return loss

    def get_embeddings(self, mentions, batch_size=1024):
        """Compute all embeddings from mention tokens.

        ``mentions`` is processed in slices of ``batch_size`` under
        ``torch.no_grad()``; each batch result is moved to the CPU and the
        results are concatenated along the first axis.

        NOTE(review): relies on ``self.vectorizer``, which is not defined in
        this class — presumably provided by ``GDANet``; confirm.
        """
        embedding_table = []
        with torch.no_grad():
            for start in tqdm(range(0, len(mentions), batch_size)):
                end = min(start + batch_size, len(mentions))
                batch = mentions[start:end]
                batch_embedding = self.vectorizer(batch)
                batch_embedding = batch_embedding.cpu()
                embedding_table.append(batch_embedding)
        embedding_table = torch.cat(embedding_table, dim=0)
        return embedding_table
class DDA_Metric_Learning(Module):
    """Metric-learning model over pairs of inputs to one disease encoder."""

    def __init__(self, disease_encoder, args):
        """Constructor for the model.

        Args:
            disease_encoder: disease encoder; its output must expose
                ``last_hidden_state``.
            args: configuration object; the attributes ``loss``,
                ``use_miner``, ``miner_margin``, ``agg_mode`` and ``dropout``
                are read.
        """
        super(DDA_Metric_Learning, self).__init__()
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.disease_adapter_name = None

        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None

        # Replace the loss name by the loss object. An unrecognised name is
        # left untouched.
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()

        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, disease_out_dim=768, out_dim=2):
        """Attach a linear classification head as ``self.cls``.

        The head takes two concatenated disease embeddings as input.

        Args:
            disease_out_dim (int, optional): disease encoder output
                dimension. Defaults to 768.
            out_dim (int, optional): output dimension. Defaults to 2.
        """
        self.cls = nn.Linear(disease_out_dim * 2, out_dim)

    def load_disease_adapter(
        self,
        disease_model_path,
        disease_adapter_name="disease_adapter",
    ):
        """Load a saved adapter into the disease encoder and activate it.

        If ``disease_model_path`` does not exist, the missing path is
        reported on stdout and nothing is loaded.

        Args:
            disease_model_path: directory/file holding the saved adapter.
            disease_adapter_name: name under which the adapter is registered.
        """
        if os.path.exists(disease_model_path):
            self.disease_adapter_name = disease_adapter_name
            self.disease_encoder.load_adapter(
                disease_model_path, load_as=disease_adapter_name
            )
            self.disease_encoder.set_active_adapters(disease_adapter_name)
            print(
                f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
            )
        else:
            # Report the path that is missing (previously the adapter name
            # was printed, which hid the actual problem).
            print(f"{disease_model_path} not exits")

    def init_adapters(
        self,
        disease_adapter_name="disease_adapter",
        reduction_factor=16,
    ):
        """Initialise adapters

        Adds a new "pfeiffer" adapter to the disease encoder and switches the
        encoder to train that adapter.

        Args:
            disease_adapter_name (str, optional): adapter name. Defaults to
                "disease_adapter".
            reduction_factor (int, optional): reduction factor passed to the
                adapter config. Defaults to 16.
        """
        # NOTE(review): ``AdapterConfig`` is not imported at module level, so
        # the bare name used to raise NameError. It is imported lazily here;
        # the adapter API used in this class (AdapterConfig.load,
        # add_adapter, train_adapter) presumably comes from the
        # adapter-transformers distribution of ``transformers`` — confirm the
        # installed package.
        from transformers import AdapterConfig

        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.disease_adapter_name = disease_adapter_name
        self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
        self.disease_encoder.train_adapter([disease_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save adapters into file.

        The adapter is written to
        ``<save_path_prefix>/disease_adapter_step_<total_step>``, which is
        created if needed.

        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        disease_save_dir = os.path.join(
            save_path_prefix, f"disease_adapter_step_{total_step}"
        )
        os.makedirs(disease_save_dir, exist_ok=True)
        self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)

    def _pair_embedding(self, encoder, x1, x2):
        """Embed ``x1`` and ``x2`` with ``encoder`` and join them feature-wise.

        With ``agg_mode == "cls"`` the first token is used; for every other
        mode the unmasked mean over all tokens is used.
        """
        hidden1 = encoder(x1).last_hidden_state
        hidden2 = encoder(x2).last_hidden_state
        if self.agg_mode == "cls":
            x1, x2 = hidden1[:, 0], hidden2[:, 0]
        else:
            x1, x2 = hidden1.mean(1), hidden2.mean(1)
        return torch.cat((x1, x2), 1)

    def predict(self, x1, x2):
        """Return the concatenated embeddings of ``x1`` and ``x2``.

        Returns:
            Tensor of shape (batch_size, 2 * hidden).
        """
        return self._pair_embedding(self.disease_encoder, x1, x2)

    def module_predict(self, x1, x2):
        """Same as :meth:`predict`, but calls ``self.disease_encoder.module``.

        Intended for an encoder wrapped in a container exposing ``.module``.
        """
        return self._pair_embedding(self.disease_encoder.module, x1, x2)

    def _pool(self, hidden, toks):
        """Reduce (batch, tokens, hidden) to (batch, hidden) per ``agg_mode``.

        ``toks["attention_mask"]`` is only read in "mean" mode.

        Raises:
            NotImplementedError: for an unknown ``agg_mode``.
        """
        if self.agg_mode == "cls":
            return hidden[:, 0]
        if self.agg_mode == "mean_all_tok":
            return hidden.mean(1)
        if self.agg_mode == "mean":
            mask = toks["attention_mask"]
            return (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        raise NotImplementedError()

    @autocast()
    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of input pairs.

        Both inputs are encoded with the disease encoder and pooled; the two
        sets of embeddings are stacked along the batch axis and ``labels`` is
        repeated to match.

        Args:
            query_toks1: keyword arguments for the encoder (first inputs).
            query_toks2: keyword arguments for the encoder (second inputs).
            labels: tensor with one label per pair.

        Returns:
            The value returned by ``self.loss``.
        """
        last_hidden_state1 = self.disease_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.disease_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state

        query_embed1 = self._pool(last_hidden_state1, query_toks1)
        query_embed2 = self._pool(last_hidden_state2, query_toks2)

        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)

        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            print('miner used')
            return self.loss(query_embed, labels, hard_pairs)
        else:
            print('no miner')
            return self.loss(query_embed, labels)
class PPI_Metric_Learning(Module):
    """Metric-learning model over pairs of inputs to one protein encoder."""

    def __init__(self, prot_encoder, args):
        """Constructor for the model.

        Args:
            prot_encoder: protein encoder; its output must expose
                ``last_hidden_state``.
            args: configuration object; the attributes ``loss``,
                ``use_miner``, ``miner_margin``, ``agg_mode`` and ``dropout``
                are read.
        """
        super(PPI_Metric_Learning, self).__init__()
        self.prot_encoder = prot_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.prot_adapter_name = None

        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None

        # Replace the loss name by the loss object. An unrecognised name is
        # left untouched.
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()

        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, prot_out_dim=1024, out_dim=2):
        """Attach a linear classification head as ``self.cls``.

        The head takes two concatenated protein embeddings as input.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            out_dim (int, optional): output dimension. Defaults to 2.
        """
        self.cls = nn.Linear(prot_out_dim + prot_out_dim, out_dim)

    def load_prot_adapter(
        self,
        prot_model_path,
        prot_adapter_name="prot_adapter",
    ):
        """Load a saved adapter into the protein encoder and activate it.

        If ``prot_model_path`` does not exist, the missing path is reported
        on stdout and nothing is loaded.

        Args:
            prot_model_path: directory/file holding the saved adapter.
            prot_adapter_name: name under which the adapter is registered.
        """
        if os.path.exists(prot_model_path):
            self.prot_adapter_name = prot_adapter_name
            self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
            self.prot_encoder.set_active_adapters(prot_adapter_name)
            print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
        else:
            print(f"{prot_model_path} not exits")

    def init_adapters(
        self,
        prot_adapter_name="prot_adapter",
        reduction_factor=16,
    ):
        """Initialise adapters

        Adds a new "pfeiffer" adapter to the protein encoder and switches the
        encoder to train that adapter.

        Args:
            prot_adapter_name (str, optional): adapter name. Defaults to
                "prot_adapter".
            reduction_factor (int, optional): reduction factor passed to the
                adapter config. Defaults to 16.
        """
        # NOTE(review): ``AdapterConfig`` is not imported at module level, so
        # the bare name used to raise NameError. It is imported lazily here;
        # the adapter API used in this class (AdapterConfig.load,
        # add_adapter, train_adapter) presumably comes from the
        # adapter-transformers distribution of ``transformers`` — confirm the
        # installed package.
        from transformers import AdapterConfig

        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.prot_adapter_name = prot_adapter_name
        self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
        self.prot_encoder.train_adapter([prot_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save adapters into file.

        The adapter is written to
        ``<save_path_prefix>/prot_adapter_step_<total_step>``, which is
        created if needed.

        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        prot_save_dir = os.path.join(
            save_path_prefix, f"prot_adapter_step_{total_step}"
        )
        os.makedirs(prot_save_dir, exist_ok=True)
        self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)

    def _pair_embedding(self, encoder, x1, x2):
        """Embed ``x1`` and ``x2`` with ``encoder`` and join them feature-wise.

        With ``agg_mode == "cls"`` the first token is used; for every other
        mode the unmasked mean over all tokens is used.
        """
        hidden1 = encoder(x1).last_hidden_state
        hidden2 = encoder(x2).last_hidden_state
        if self.agg_mode == "cls":
            x1, x2 = hidden1[:, 0], hidden2[:, 0]
        else:
            x1, x2 = hidden1.mean(1), hidden2.mean(1)
        return torch.cat((x1, x2), 1)

    def predict(self, x1, x2):
        """Return the concatenated embeddings of ``x1`` and ``x2``.

        Returns:
            Tensor of shape (batch_size, 2 * hidden).
        """
        return self._pair_embedding(self.prot_encoder, x1, x2)

    def module_predict(self, x1, x2):
        """Same as :meth:`predict`, but calls ``self.prot_encoder.module``.

        Intended for an encoder wrapped in a container exposing ``.module``.
        """
        return self._pair_embedding(self.prot_encoder.module, x1, x2)

    def _pool(self, hidden, toks):
        """Reduce (batch, tokens, hidden) to (batch, hidden) per ``agg_mode``.

        ``toks["attention_mask"]`` is only read in "mean" mode.

        Raises:
            NotImplementedError: for an unknown ``agg_mode``.
        """
        if self.agg_mode == "cls":
            return hidden[:, 0]
        if self.agg_mode == "mean_all_tok":
            return hidden.mean(1)
        if self.agg_mode == "mean":
            mask = toks["attention_mask"]
            return (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        raise NotImplementedError()

    @autocast()
    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of input pairs.

        Both inputs are encoded with the protein encoder and pooled; the two
        sets of embeddings are stacked along the batch axis and ``labels`` is
        repeated to match.

        Args:
            query_toks1: keyword arguments for the encoder (first inputs).
            query_toks2: keyword arguments for the encoder (second inputs).
            labels: tensor with one label per pair.

        Returns:
            The value returned by ``self.loss``.
        """
        last_hidden_state1 = self.prot_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.prot_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state

        query_embed1 = self._pool(last_hidden_state1, query_toks1)
        query_embed2 = self._pool(last_hidden_state2, query_toks2)

        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)

        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        else:
            return self.loss(query_embed, labels)