Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from src.model.layers import TransformerEncoder | |
class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    Node and edge inputs are embedded by two separate MLPs, the edge embedding is
    symmetrized over its two vertex axes, both embeddings are passed through a
    TransformerEncoder, and two linear readout layers map the results back to
    ``nodes`` / ``edges`` channels.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str or nn.Module): Activation to use. Either one of the names
                "relu", "leaky", "sigmoid" or "tanh", or an already constructed
                ``nn.Module`` which is used as is.
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an ``nn.Module``.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation up front. Without this check an unknown name
        # would be handed to nn.Sequential as a plain string and fail there
        # with an error that does not mention the activation at all.
        if isinstance(act, str):
            supported = {
                "relu": nn.ReLU,
                "leaky": nn.LeakyReLU,
                "sigmoid": nn.Sigmoid,
                "tanh": nn.Tanh,
            }
            if act not in supported:
                raise ValueError("Unsupported activation function: {}".format(act))
            act = supported[act]()
        elif not isinstance(act, nn.Module):
            raise ValueError("Unsupported activation function: {}".format(act))

        # Flattened sizes of the raw inputs and of the embedded representation.
        # They are not used inside this class but are kept as public attributes.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node embedding: nodes -> 64 -> dim
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Per-edge embedding: edges -> 64 -> dim
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Readouts project the dim-sized features back to the input channel counts.
        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        # Not applied in forward(); kept as an attribute of the module.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Node features returned by the Transformer Encoder.
                - edge: Edge features returned by the Transformer Encoder.
                - node_sample: Output of the node readout layer applied to ``node``
                  (no softmax is applied here).
                - edge_sample: Output of the edge readout layer applied to ``edge``
                  (no softmax is applied here).
        """
        # Embed node and edge features through their respective layers
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge features by averaging with the tensor obtained
        # by swapping the two vertex axes (dims 1 and 2)
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        # Pass the features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)
        # Readout layers to generate final outputs
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)
        return node, edge, node_sample, edge_sample
class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    It embeds node and edge features with linear layers, applies a Transformer
    Encoder to them, and predicts one value per sample with an MLP on the
    flattened node features. The edge output of the encoder is not used for
    the prediction.
    This class is used in DrugGEN model.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str or nn.Module): Activation to use. Either one of the names
                "relu", "leaky", "sigmoid" or "tanh", or an already constructed
                ``nn.Module`` which is used as is.
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.

        Raises:
            ValueError: If ``act`` is neither a supported name nor an ``nn.Module``.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation up front. Without this check an unknown name
        # would be handed to nn.Sequential as a plain string and fail there
        # with an error that does not mention the activation at all.
        if isinstance(act, str):
            supported = {
                "relu": nn.ReLU,
                "leaky": nn.LeakyReLU,
                "sigmoid": nn.Sigmoid,
                "tanh": nn.Tanh,
            }
            if act not in supported:
                raise ValueError("Unsupported activation function: {}".format(act))
            act = supported[act]()
        elif not isinstance(act, nn.Module):
            raise ValueError("Unsupported activation function: {}".format(act))

        # Flattened sizes of the raw inputs and of the embedded representation.
        # They are not used inside this class but are kept as public attributes.
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Per-node embedding: nodes -> 64 -> dim
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Per-edge embedding: edges -> 64 -> dim
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Sizes of the flattened node / edge representations. Only
        # node_features is consumed below; edge_features is kept as an attribute.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP to predict a scalar value from the flattened node features
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores, one value per sample (last dimension is 1).
        """
        batch_size = z_n.shape[0]
        # Embed node and edge features separately
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features by averaging with the tensor obtained
        # by swapping the two vertex axes (dims 1 and 2)
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2
        # Process features through the Transformer Encoder
        node, edge = self.TransformerEncoder(node, edge)
        # Flatten node features per sample; node_mlp expects vertexes * dim values
        node = node.view(batch_size, -1)
        # Predict a scalar score using the node MLP
        prediction = self.node_mlp(node)
        return prediction
class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score.
    This class is used in NoTarget model.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.

        Raises:
            ValueError: If ``act`` is not one of the supported activation names.
        """
        super().__init__()
        # Map the activation name to its module class; anything else is rejected
        supported = {
            "relu": nn.ReLU,
            "leaky": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }
        if not isinstance(act, str) or act not in supported:
            raise ValueError("Unsupported activation function: {}".format(act))
        act = supported[act]()
        # Total input width: per-vertex atom features plus per-vertex-pair bond features
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        # MLP that narrows the flattened input down to a single score
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor whose last dimension equals
                vertexes * m_dim + vertexes * vertexes * b_dim.

        Returns:
            torch.Tensor: Prediction scores (last dimension is 1).
        """
        return self.predictor(x)