File size: 2,747 Bytes
a1af661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import sys

import torch
import torch.nn as nn

sys.path.append("../")


class GDANet(torch.nn.Module):
    """Pair a protein encoder with a disease encoder plus optional linear heads.

    No prediction head exists after construction; call
    :meth:`add_regression_head` and/or :meth:`add_classification_head` to
    create ``self.reg`` / ``self.cls``.
    """

    def __init__(
        self,
        prot_encoder,
        disease_encoder,
    ):
        """Store both encoders; the prediction heads start unset.

        Args:
            prot_encoder (nn.Module): encoder for protein inputs. Expected to
                expose ``.parameters()`` (used by :meth:`freeze_encoders`).
            disease_encoder (nn.Module): encoder for disease (textual) inputs.
                Same expectation as ``prot_encoder``.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        # Heads are created lazily by the add_*_head methods below.
        self.cls = None
        self.reg = None

    def add_regression_head(self, prot_out_dim=1024, disease_out_dim=768):
        """Create a single-output linear regression head as ``self.reg``.

        The head's input width is ``prot_out_dim + disease_out_dim``, so it
        presumably consumes the two encoder outputs concatenated along the
        feature axis (confirm against the caller). Any existing ``self.reg``
        is replaced.

        Args:
            prot_out_dim (int): protein encoder output dimension. Defaults to 1024.
            disease_out_dim (int): disease encoder output dimension. Defaults to 768.
        """
        self.reg = nn.Linear(prot_out_dim + disease_out_dim, 1)

    def add_classification_head(
        self, prot_out_dim=1024, disease_out_dim=768, out_dim=2
    ):
        """Create a linear classification head as ``self.cls``.

        The input width is ``prot_out_dim + disease_out_dim`` (presumably the
        concatenated encoder outputs; confirm against the caller). Any
        existing ``self.cls`` is replaced.

        Args:
            prot_out_dim (int): protein encoder output dimension. Defaults to 1024.
            disease_out_dim (int): disease encoder output dimension. Defaults to 768.
            out_dim (int): number of output logits. Defaults to 2.
        """
        self.cls = nn.Linear(prot_out_dim + disease_out_dim, out_dim)

    def freeze_encoders(self, freeze_prot_encoder, freeze_disease_encoder):
        """Freeze or unfreeze each encoder's parameters independently.

        Args:
            freeze_prot_encoder (bool): if truthy, set ``requires_grad=False``
                on every protein-encoder parameter; otherwise set it to True.
            freeze_disease_encoder (bool): same, for the disease encoder.
        """
        # Each encoder is toggled strictly against its own flag, so unfreezing
        # one encoder can never touch the other.
        self._set_requires_grad(self.prot_encoder, not freeze_prot_encoder)
        self._set_requires_grad(self.disease_encoder, not freeze_disease_encoder)
        print(f"freeze_prot_encoder:{freeze_prot_encoder}")
        print(f"freeze_disease_encoder:{freeze_disease_encoder}")

    @staticmethod
    def _set_requires_grad(module, requires_grad):
        """Set ``requires_grad`` on every parameter yielded by ``module.parameters()``."""
        for param in module.parameters():
            param.requires_grad = requires_grad