import torch
import torch.nn as nn

from src.model.layers import TransformerEncoder


def _build_activation(act):
    """Return the ``nn.Module`` activation matching a name string.

    Shared by :class:`Generator`, :class:`Discriminator` and
    :class:`simple_disc` so that all three validate their ``act``
    argument identically.

    Args:
        act (str): One of ``"relu"``, ``"leaky"``, ``"sigmoid"`` or ``"tanh"``.

    Returns:
        nn.Module: A freshly constructed activation module.

    Raises:
        ValueError: If ``act`` is not one of the supported names.
    """
    activations = {
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }
    try:
        return activations[act]()
    except KeyError:
        raise ValueError("Unsupported activation function: {}".format(act)) from None


class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    The network first processes input node and edge features with separate linear layers,
    then applies a Transformer Encoder to model interactions, and finally outputs both
    transformed features and readout samples.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str): Type of activation function to use ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.

        Raises:
            ValueError: If ``act`` is not a supported activation name.
        """
        super(Generator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation name to a module; raises early on a bad name
        # instead of failing later inside nn.Sequential with a cryptic error.
        act = _build_activation(act)

        # Total flattened feature count and transformer dimensionality.
        # Kept as attributes for external consumers even though forward()
        # does not read them directly.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node embedding: nodes -> 64 -> dim
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Per-edge embedding: edges -> 64 -> dim
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        # Readout heads mapping transformer features back to edge/node spaces.
        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        # Retained for backward compatibility; not applied in forward().
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Updated node features after the transformer.
                - edge: Updated edge features after the transformer.
                - node_sample: Readout sample from node features.
                - edge_sample: Readout sample from edge features.
        """
        b, n, c = z_n.shape
        # Fourth dimension of the edge tensor (edge-feature count); unused below.
        _, _, _, d = z_e.shape

        # Embed node and edge features independently.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)

        # Symmetrize edges by averaging with the transpose over vertex axes,
        # so that edge(i, j) == edge(j, i).
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Model node/edge interactions jointly.
        node, edge = self.TransformerEncoder(node, edge)

        # Project back to the original node/edge feature spaces.
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample


class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    It processes features with linear layers, applies a Transformer Encoder to capture
    dependencies, and finally predicts a scalar value using an MLP on aggregated node
    features. This class is used in DrugGEN model.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.

        Raises:
            ValueError: If ``act`` is not a supported activation name.
        """
        super(Discriminator, self).__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation name to a module (raises on unknown names,
        # matching simple_disc's behavior).
        act = _build_activation(act)

        # Flattened feature counts; kept as attributes for external consumers.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node embedding: nodes -> 64 -> dim
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Per-edge embedding: edges -> 64 -> dim
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions.
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        # Sizes of the flattened (per-sample) node/edge feature vectors.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim

        # MLP that maps the flattened node features to a single scalar score.
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores (typically a scalar per sample).
        """
        b, n, c = z_n.shape
        # Unpack the edge tensor's shape (feature count not used further).
        _, _, _, d = z_e.shape

        # Embed node and edge features independently.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)

        # Symmetrize edges by averaging with the transpose over vertex axes.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        # Model node/edge interactions jointly.
        node, edge = self.TransformerEncoder(node, edge)

        # Flatten per-sample node features and score them with the MLP.
        node = node.view(b, -1)
        prediction = self.node_mlp(node)

        return prediction


class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score. This class is used in NoTarget model.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.

        Raises:
            ValueError: If ``act`` is not a supported activation name.
        """
        super().__init__()
        act = _build_activation(act)

        # Total input size: flattened node (atom) plus edge (bond) features.
        features = vertexes * m_dim + vertexes * vertexes * b_dim

        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor.

        Returns:
            torch.Tensor: Prediction scores.
        """
        prediction = self.predictor(x)
        return prediction