# Copyright 2024 the Llamole team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_add_pool, global_max_pool


class GraphCLIP(nn.Module):
    """Molecular graph encoder plus projection head producing L2-normalized graph embeddings."""

    def __init__(
        self,
        graph_num_layer,
        graph_hidden_size,
        dropout,
        model_config,
    ):
        super().__init__()
        self.model_config = model_config
        self.hidden_size = graph_hidden_size
        self.molecule_encoder = GNNEncoder(
            num_layer=graph_num_layer,
            hidden_size=graph_hidden_size,
            drop_ratio=dropout,
        )
        self.molecule_projection = ProjectionHead(
            embedding_dim=graph_hidden_size,
            projection_dim=graph_hidden_size,
            dropout=dropout,
        )

    def forward(self, x, edge_index, edge_attr, batch):
        molecule_features = self.molecule_encoder(x, edge_index, edge_attr, batch)
        molecule_embeddings = self.molecule_projection(molecule_features)
        # Normalize so the embeddings lie on the unit hypersphere.
        molecule_embeddings = molecule_embeddings / molecule_embeddings.norm(dim=-1, keepdim=True)
        return molecule_embeddings

    def save_pretrained(self, output_dir):
        """
        Save the molecule encoder, projection head, and model_config to the output directory.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        molecule_path = os.path.join(output_dir, 'model.pt')
        proj_path = molecule_path.replace('model', 'model_proj')
        config_path = os.path.join(output_dir, 'model_config.json')

        torch.save(self.molecule_encoder.state_dict(), molecule_path)
        torch.save(self.molecule_projection.state_dict(), proj_path)

        # Save model_config to a JSON file.
        with open(config_path, 'w') as f:
            json.dump(self.model_config, f, indent=2)

    def disable_grads(self):
        """
        Disable gradients for all parameters in the model.
        """
        for param in self.parameters():
            param.requires_grad = False

    def init_model(self, model_path, verbose=True):
        """Load the encoder and projection weights saved by `save_pretrained`."""
        molecule_path = os.path.join(model_path, 'model.pt')
        proj_path = molecule_path.replace('model', 'model_proj')

        if os.path.exists(molecule_path):
            self.molecule_encoder.load_state_dict(
                torch.load(molecule_path, map_location='cpu', weights_only=False)
            )
        else:
            raise FileNotFoundError(f"Molecule encoder file not found: {molecule_path}")

        if os.path.exists(proj_path):
            self.molecule_projection.load_state_dict(
                torch.load(proj_path, map_location='cpu', weights_only=False)
            )
        else:
            raise FileNotFoundError(f"Molecule projection file not found: {proj_path}")

        if verbose:
            print('GraphCLIP Models initialized.')
            print('Molecule model:\n', self.molecule_encoder)
            print('Molecule projection:\n', self.molecule_projection)


class GNNEncoder(nn.Module):
    """GIN-style message-passing encoder with a virtual node, pooled into a graph-level embedding."""

    def __init__(self, num_layer, hidden_size, drop_ratio):
        super(GNNEncoder, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Node features are atom-type indices, embedded via a 118-entry lookup table.
        self.atom_encoder = nn.Embedding(118, hidden_size)

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = nn.Embedding(1, hidden_size)
        nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.mlp_virtualnode_list = nn.ModuleList()

        for layer in range(num_layer):
            self.convs.append(GINConv(hidden_size, drop_ratio))
            self.norms.append(nn.LayerNorm(hidden_size, elementwise_affine=True))
            if layer < num_layer - 1:
                self.mlp_virtualnode_list.append(
                    nn.Sequential(
                        nn.Linear(hidden_size, 4 * hidden_size),
                        nn.LayerNorm(4 * hidden_size),
                        nn.GELU(),
                        nn.Dropout(drop_ratio),
                        nn.Linear(4 * hidden_size, hidden_size),
                    )
                )

    def initialize_weights(self):
        # Initialize linear layers with Xavier-uniform weights and zero biases.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, x, edge_index, edge_attr, batch):
        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)
        )

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.norms[layer](h)
            if layer < self.num_layer - 1:
                h = F.gelu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
            h = h + h_list[layer]  # residual connection
            h_list.append(h)

            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtual_pool = global_max_pool(h_list[layer], batch)
                virtualnode_embedding = virtualnode_embedding + F.dropout(
                    self.mlp_virtualnode_list[layer](virtual_pool),
                    self.drop_ratio,
                    training=self.training,
                )

        h_node = h_list[-1]
        h_graph = global_add_pool(h_node, batch)
        return h_graph


class GINConv(MessagePassing):
    """GIN convolution with bond-type edge embeddings and sum aggregation."""

    def __init__(self, hidden_size, drop_ratio):
        '''
        hidden_size (int)
        '''
        super(GINConv, self).__init__(aggr="add")

        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.LayerNorm(4 * hidden_size),
            nn.GELU(),
            nn.Dropout(drop_ratio),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = nn.Embedding(5, hidden_size)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.gelu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


class ProjectionHead(nn.Module):
    """Two-layer MLP projection head: Linear -> LayerNorm -> GELU -> Dropout -> Linear."""

    def __init__(
        self,
        embedding_dim,
        projection_dim,
        dropout,
        act_layer=nn.GELU,
        hidden_features=None,
        bias=True,
    ):
        super().__init__()
        projection_dim = projection_dim or embedding_dim
        hidden_features = hidden_features or embedding_dim
        linear_layer = nn.Linear

        self.fc1 = linear_layer(embedding_dim, hidden_features, bias=bias)
        self.norm1 = nn.LayerNorm(hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = linear_layer(hidden_features, projection_dim, bias=bias)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        return x
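

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the layer count, hidden size, dropout, and
# toy graph below are assumptions, not part of the original training setup).
# It builds a single-molecule batch with atom-type indices as node features and
# bond-type indices as edge attributes, then runs a forward pass to obtain one
# L2-normalized graph embedding.
if __name__ == "__main__":
    model = GraphCLIP(
        graph_num_layer=2,     # example value
        graph_hidden_size=64,  # example value
        dropout=0.1,           # example value
        model_config={"graph_num_layer": 2, "graph_hidden_size": 64, "dropout": 0.1},
    )
    model.eval()

    # One graph with 3 atoms and 2 bonds (each bond stored in both directions).
    x = torch.tensor([6, 8, 1])                # atom-type indices into the 118-entry atom embedding
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])  # COO connectivity
    edge_attr = torch.tensor([0, 0, 1, 1])     # bond-type indices into the 5-entry bond embedding
    batch = torch.zeros(3, dtype=torch.long)   # all nodes belong to graph 0

    with torch.no_grad():
        embeddings = model(x, edge_index, edge_attr, batch)
    print(embeddings.shape)  # torch.Size([1, 64]); each row has unit L2 norm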