Spaces:
Running
Running
import sys | |
import torch | |
import torch.nn as nn | |
# Put the parent directory on the module search path. A relative entry such
# as "../" is resolved against the process's current working directory, not
# against this file's location.
sys.path.extend(["../"])
class GDANet(torch.nn.Module):
    """Model pairing a protein encoder with a disease encoder.

    The encoders are supplied by the caller. Linear prediction heads are not
    created in the constructor; attach them afterwards with
    ``add_regression_head`` (stored as ``self.reg``) and/or
    ``add_classification_head`` (stored as ``self.cls``).
    """

    def __init__(
        self,
        prot_encoder,
        disease_encoder,
    ):
        """Store both encoders and leave the prediction heads unset.

        Args:
            prot_encoder: protein encoder. Presumably an ``nn.Module``; at
                minimum it must provide ``parameters()``, which
                ``freeze_encoders`` iterates. TODO confirm against callers.
            disease_encoder: disease (textual) encoder; same requirements as
                ``prot_encoder``.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        # Heads are attached later via add_classification_head /
        # add_regression_head; None means "not attached yet".
        self.cls = None
        self.reg = None

    def add_regression_head(self, prot_out_dim=1024, disease_out_dim=768):
        """Attach a single-output linear regression head as ``self.reg``.

        Any previously attached regression head is replaced.

        Args:
            prot_out_dim (int): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int): disease encoder output dimension.
                Defaults to 768.

        The head's input size is ``prot_out_dim + disease_out_dim``.
        Presumably the two encoder outputs are concatenated before being fed
        to it; confirm against the forward pass.
        """
        self.reg = nn.Linear(prot_out_dim + disease_out_dim, 1)

    def add_classification_head(
        self, prot_out_dim=1024, disease_out_dim=768, out_dim=2
    ):
        """Attach a linear classification head as ``self.cls``.

        Any previously attached classification head is replaced.

        Args:
            prot_out_dim (int): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int): disease encoder output dimension.
                Defaults to 768.
            out_dim (int): number of output units (e.g. classes).
                Defaults to 2.

        The head's input size is ``prot_out_dim + disease_out_dim``.
        Presumably the two encoder outputs are concatenated before being fed
        to it; confirm against the forward pass.
        """
        self.cls = nn.Linear(prot_out_dim + disease_out_dim, out_dim)

    @staticmethod
    def _set_requires_grad(module, requires_grad):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = requires_grad

    def freeze_encoders(self, freeze_prot_encoder, freeze_disease_encoder):
        """Freeze or unfreeze each encoder's parameters independently.

        Args:
            freeze_prot_encoder (bool): if truthy, disable gradients for every
                protein-encoder parameter; otherwise enable them.
            freeze_disease_encoder (bool): if truthy, disable gradients for
                every disease (textual) encoder parameter; otherwise enable
                them.
        """
        # Each flag toggles only its own encoder. ``not flag`` yields a real
        # bool, which is what ``requires_grad`` expects.
        self._set_requires_grad(self.prot_encoder, not freeze_prot_encoder)
        self._set_requires_grad(
            self.disease_encoder, not freeze_disease_encoder
        )
        print(f"freeze_prot_encoder:{freeze_prot_encoder}")
        print(f"freeze_disease_encoder:{freeze_disease_encoder}")