# FusionGDA / src/utils/gd_model.py
# ZhaohanM
# FusionGDA
# a1af661
import sys
import torch
import torch.nn as nn
sys.path.append("../")
class GDANet(torch.nn.Module):
    """Pair a protein encoder with a disease encoder and optional output heads.

    Heads are not created in ``__init__``. Call ``add_classification_head``
    or ``add_regression_head`` to attach them; until then ``self.cls`` and
    ``self.reg`` are ``None``.
    """

    def __init__(
        self,
        prot_encoder,
        disease_encoder,
    ):
        """Store the two encoders; both heads start unset.

        Args:
            prot_encoder (torch.nn.Module): protein encoder (any module
                exposing ``parameters()``).
            disease_encoder (torch.nn.Module): disease textual encoder (any
                module exposing ``parameters()``).
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        # Heads are attached later via add_classification_head /
        # add_regression_head; None marks "not added yet".
        self.cls = None
        self.reg = None

    def add_regression_head(self, prot_out_dim=1024, disease_out_dim=768):
        """Attach a single-output linear regression head as ``self.reg``.

        The head's input width is ``prot_out_dim + disease_out_dim``.
        NOTE(review): presumably it receives the concatenation of both
        encoder outputs; confirm against the forward pass / caller.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int, optional): disease encoder output dimension.
                Defaults to 768.
        """
        self.reg = nn.Linear(prot_out_dim + disease_out_dim, 1)

    def add_classification_head(
        self, prot_out_dim=1024, disease_out_dim=768, out_dim=2
    ):
        """Attach a linear classification head as ``self.cls``.

        The head's input width is ``prot_out_dim + disease_out_dim``.
        NOTE(review): presumably it receives the concatenation of both
        encoder outputs; confirm against the forward pass / caller.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int, optional): disease encoder output dimension.
                Defaults to 768.
            out_dim (int, optional): number of output units. Defaults to 2.
        """
        self.cls = nn.Linear(prot_out_dim + disease_out_dim, out_dim)

    def freeze_encoders(self, freeze_prot_encoder, freeze_disease_encoder):
        """Freeze or unfreeze each encoder's parameters independently.

        Args:
            freeze_prot_encoder (bool): if truthy, set ``requires_grad=False``
                on every protein-encoder parameter; otherwise set it to True.
            freeze_disease_encoder (bool): same, for the disease textual
                encoder.
        """
        # Each flag must only affect its own encoder.
        self._set_trainable(self.prot_encoder, not freeze_prot_encoder)
        self._set_trainable(self.disease_encoder, not freeze_disease_encoder)
        print(f"freeze_prot_encoder:{freeze_prot_encoder}")
        print(f"freeze_disease_encoder:{freeze_disease_encoder}")

    @staticmethod
    def _set_trainable(module, trainable):
        """Set ``requires_grad`` to ``trainable`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = trainable