Raykarr committed · verified
Commit 925085f
1 Parent(s): f5d3c94

Upload 11 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ qm9.csv filter=lfs diff=lfs merge=lfs -text
GCN_final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33163e1500e271e15a4d815475329b732598b081b1fedf2ca4dc12afb39e63af
+ size 142176
GCN_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:015533d8925e17077100a7dadccd3363e9208cc11c544aec398786089f0eddd2
+ size 142104
GIN_final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a45333977c150f0068b4221fefcbab0fd75502ebefb4df5269561de3ee4508b
+ size 117472
GIN_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18fb6277f2001e1b2b6b32c4395d89d793fa7de81107843dc3714408784de8f9
+ size 117244
app.py ADDED
@@ -0,0 +1,680 @@
+ # app.py
+
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import pickle
+ import numpy as np
+ import pandas as pd
+ from typing import List, Dict, Tuple, Optional
+
+ # RDKit for molecule handling
+ from rdkit import Chem
+ from rdkit.Chem import Draw, Descriptors
+ from rdkit import RDLogger
+ RDLogger.DisableLog('rdApp.*')
+
+ # Visualization libraries
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ # For generating images in Streamlit
+ from PIL import Image
+
+ # Suppress warnings in RDKit
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # Set Seaborn style
+ sns.set_style('whitegrid')
+
+ # Additional imports for GNN
+ import torch.nn.functional as F
+ from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
+
+ from torch_geometric.data import Data
+ from torch_geometric.nn import GCNConv, GINConv
+ from torch_geometric.nn import global_mean_pool, global_add_pool
+
+ # Function to load the VAE model
+ @st.cache_resource
+ def load_vae_model(device):
+     # Load the vocabulary
+     with open('vae_vocab.pkl', 'rb') as f:
+         vocab = pickle.load(f)
+     vocab_size = len(vocab)
+
+     # Initialize the model with the same parameters
+     hidden_dim = 256  # Ensure this matches your trained model
+     latent_dim = 64   # Ensure this matches your trained model
+
+     # Define the VAE class (same as in your training script)
+     class VAE(nn.Module):
+         def __init__(self, vocab_size: int, hidden_dim: int, latent_dim: int):
+             super(VAE, self).__init__()
+             self.vocab_size = vocab_size
+             self.hidden_dim = hidden_dim
+             self.latent_dim = latent_dim
+
+             self.encoder = nn.GRU(vocab_size, hidden_dim, batch_first=True)
+             self.fc_mu = nn.Linear(hidden_dim, latent_dim)
+             self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
+
+             self.decoder = nn.GRU(vocab_size + latent_dim, hidden_dim, batch_first=True)
+             self.fc_output = nn.Linear(hidden_dim, vocab_size)
+
+         def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+             _, h = self.encoder(x)
+             h = h.squeeze(0)
+             return self.fc_mu(h), self.fc_logvar(h)
+
+         def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+             std = torch.exp(0.5 * logvar)
+             eps = torch.randn_like(std)
+             return mu + eps * std
+
+         def decode(self, z: torch.Tensor, max_length: int) -> torch.Tensor:
+             batch_size = z.size(0)
+             h = torch.zeros(1, batch_size, self.hidden_dim).to(z.device)
+             x = torch.zeros(batch_size, 1, self.vocab_size).to(z.device)
+             x[:, 0, vocab['<']] = 1  # Start token
+             outputs = []
+
+             for _ in range(max_length):
+                 z_input = z.unsqueeze(1)
+                 decoder_input = torch.cat([x, z_input], dim=2)
+                 output, h = self.decoder(decoder_input, h)
+                 output = self.fc_output(output)
+                 outputs.append(output)
+                 x = torch.softmax(output, dim=-1)
+
+             return torch.cat(outputs, dim=1)
+
+         def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+             mu, logvar = self.encode(x)
+             z = self.reparameterize(mu, logvar)
+             return self.decode(z, x.size(1)), mu, logvar
+
+     model = VAE(vocab_size, hidden_dim, latent_dim)
+     model.load_state_dict(torch.load('vae_model.pth', map_location=device))
+     model.to(device)
+     model.eval()
+     return model, vocab
+
+ # Function to generate molecules using VAE
+ def generate_smiles_vae(model, vocab, num_samples=10, max_length=100):
+     model.eval()
+     inv_vocab = {v: k for k, v in vocab.items()}
+     generated_smiles = []
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+         for _ in range(num_samples):
+             z = torch.randn(1, model.latent_dim).to(device)
+             x = torch.zeros(1, 1, model.vocab_size).to(device)
+             x[0, 0, vocab['<']] = 1
+             h = torch.zeros(1, 1, model.hidden_dim).to(device)
+
+             smiles = ''
+             for _ in range(max_length):
+                 z_input = z.unsqueeze(1)
+                 decoder_input = torch.cat([x, z_input], dim=2)
+                 output, h = model.decoder(decoder_input, h)
+                 output = model.fc_output(output)
+
+                 probs = torch.softmax(output.squeeze(0), dim=-1)
+                 next_char = torch.multinomial(probs, 1).item()
+
+                 if next_char == vocab['>']:
+                     break
+
+                 smiles += inv_vocab.get(next_char, '')
+                 x = torch.zeros(1, 1, model.vocab_size).to(device)
+                 x[0, 0, next_char] = 1
+
+             generated_smiles.append(smiles)
+
+     return generated_smiles
+
+ # Function to post-process and validate SMILES strings
+ def enhanced_post_process_smiles(smiles: str) -> str:
+     smiles = smiles.replace('<', '').replace('>', '')
+     allowed_chars = set('CNOPSFIBrClcnops()[]=@+-#0123456789')
+     smiles = ''.join(c for c in smiles if c in allowed_chars)
+
+     # Balance parentheses
+     open_count = smiles.count('(')
+     close_count = smiles.count(')')
+     if open_count > close_count:
+         smiles += ')' * (open_count - close_count)
+     elif close_count > open_count:
+         smiles = '(' * (close_count - open_count) + smiles
+
+     # Replace invalid double bonds
+     smiles = smiles.replace('==', '=')
+
+     # Attempt to close unclosed rings
+     for i in range(1, 10):
+         if smiles.count(str(i)) % 2 != 0:
+             smiles += str(i)
+
+     return smiles
+
+ def validate_and_correct_smiles(smiles: str) -> Tuple[bool, str]:
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is not None:
+         try:
+             Chem.SanitizeMol(mol)
+             return True, Chem.MolToSmiles(mol, isomericSmiles=True)
+         except:
+             pass
+     return False, smiles
+
+ # Function to analyze molecules
+ def analyze_molecules(smiles_list: List[str], training_smiles_set: set) -> Dict:
+     results = {
+         'total': len(smiles_list),
+         'valid': 0,
+         'invalid': 0,
+         'unique': 0,
+         'corrected': 0,
+         'novel': 0,
+         'valid_properties': [],
+         'novel_properties': [],
+         'invalid_smiles': []
+     }
+
+     unique_smiles = set()
+     novel_smiles = set()
+
+     for smiles in smiles_list:
+         processed_smiles = enhanced_post_process_smiles(smiles)
+         is_valid, corrected_smiles = validate_and_correct_smiles(processed_smiles)
+
+         if is_valid:
+             results['valid'] += 1
+             unique_smiles.add(corrected_smiles)
+             if corrected_smiles != processed_smiles:
+                 results['corrected'] += 1
+
+             mol = Chem.MolFromSmiles(corrected_smiles)
+             if mol:
+                 props = {
+                     'smiles': corrected_smiles,
+                     'MolWt': Descriptors.ExactMolWt(mol),
+                     'LogP': Descriptors.MolLogP(mol),
+                     'NumHDonors': Descriptors.NumHDonors(mol),
+                     'NumHAcceptors': Descriptors.NumHAcceptors(mol)
+                 }
+
+                 if corrected_smiles not in training_smiles_set:
+                     novel_smiles.add(corrected_smiles)
+                     results['novel'] += 1
+                     results['novel_properties'].append(props)
+                 else:
+                     results['valid_properties'].append(props)
+         else:
+             results['invalid'] += 1
+             results['invalid_smiles'].append(smiles)
+
+     results['unique'] = len(unique_smiles)
+     return results
+
+ # Function to visualize molecules
+ def visualize_molecules(smiles_list: List[str], n: int = 5) -> Optional[Image.Image]:
+     valid_mols = []
+     for smiles in smiles_list:
+         smiles = smiles.strip().strip('<>').strip()
+         if not smiles:
+             continue
+         try:
+             mol = Chem.MolFromSmiles(smiles)
+             if mol is not None:
+                 valid_mols.append(mol)
+                 if len(valid_mols) == n:
+                     break
+         except Exception:
+             continue
+
+     if not valid_mols:
+         return None
+
+     try:
+         img = Draw.MolsToGridImage(
+             valid_mols,
+             molsPerRow=min(3, len(valid_mols)),
+             subImgSize=(200, 200),
+             legends=[f"Mol {i+1}" for i in range(len(valid_mols))]
+         )
+         return img
+     except Exception:
+         return None
+
+ # GCN and GIN model definitions
+ class GCN(torch.nn.Module):
+     """Graph Convolutional Network class with 3 convolutional layers and a linear layer"""
+
+     def __init__(self, dim_h):
+         """init method for GCN
+
+         Args:
+             dim_h (int): the dimension of hidden layers
+         """
+         super().__init__()
+         self.conv1 = GCNConv(11, dim_h)
+         self.conv2 = GCNConv(dim_h, dim_h)
+         self.conv3 = GCNConv(dim_h, dim_h)
+         self.lin = torch.nn.Linear(dim_h, 1)
+
+     def forward(self, data):
+         e = data.edge_index
+         x = data.x
+
+         x = self.conv1(x, e)
+         x = x.relu()
+         x = self.conv2(x, e)
+         x = x.relu()
+         x = self.conv3(x, e)
+         x = global_mean_pool(x, data.batch)
+
+         x = F.dropout(x, p=0.5, training=self.training)
+         x = self.lin(x)
+
+         return x
+
+ class GIN(torch.nn.Module):
+     """Graph Isomorphism Network class with 3 GINConv layers and 2 linear layers"""
+
+     def __init__(self, dim_h):
+         """Initializing GIN class
+
+         Args:
+             dim_h (int): the dimension of hidden layers
+         """
+         super(GIN, self).__init__()
+         nn1 = Sequential(Linear(11, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
+         self.conv1 = GINConv(nn1)
+         nn2 = Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
+         self.conv2 = GINConv(nn2)
+         nn3 = Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
+         self.conv3 = GINConv(nn3)
+         self.lin1 = Linear(dim_h, dim_h)
+         self.lin2 = Linear(dim_h, 1)
+
+     def forward(self, data):
+         x = data.x
+         edge_index = data.edge_index
+         batch = data.batch
+
+         # Node embeddings
+         h = self.conv1(x, edge_index)
+         h = h.relu()
+         h = self.conv2(h, edge_index)
+         h = h.relu()
+         h = self.conv3(h, edge_index)
+
+         # Graph-level readout
+         h = global_add_pool(h, batch)
+
+         h = self.lin1(h)
+         h = h.relu()
+         h = F.dropout(h, p=0.5, training=self.training)
+         h = self.lin2(h)
+
+         return h
+
+ # Function to load GNN models
+ @st.cache_resource
+ def load_gnn_models(device):
+     # Load GCN model
+     gcn_model = GCN(dim_h=128)
+     gcn_model.load_state_dict(torch.load("GCN_model.pth", map_location=device))
+     gcn_model.to(device)
+     gcn_model.eval()
+
+     # Load GIN model
+     gin_model = GIN(dim_h=64)
+     gin_model.load_state_dict(torch.load("GIN_model.pth", map_location=device))
+     gin_model.to(device)
+     gin_model.eval()
+
+     return gcn_model, gin_model
+
+ # Function to load normalization parameters
+ @st.cache_resource
+ def load_data_norm(device):
+     data_norm = torch.load('data_norm.pth', map_location=device)
+     data_mean = data_norm['mean']
+     data_std = data_norm['std']
+     return data_mean, data_std
+
+ # Function to convert SMILES to Data object
+ def smiles_to_data(smiles):
+     mol = Chem.MolFromSmiles(smiles)
+     if mol is None:
+         return None
+
+     atoms = mol.GetAtoms()
+     num_atoms = len(atoms)
+
+     atom_type_list = ['H', 'C', 'N', 'O', 'F']
+     hybridization_list = [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3]
+
+     x = []
+     for atom in atoms:
+         atom_type = atom.GetSymbol()
+         atom_type_feature = [int(atom_type == s) for s in atom_type_list]  # 5 features
+
+         # Atom degree (scalar between 0 and 4)
+         degree = atom.GetDegree()
+         degree_feature = [degree / 4]  # Normalize degree to [0,1]  # 1 feature
+
+         # Formal charge
+         formal_charge = atom.GetFormalCharge()
+         formal_charge_feature = [formal_charge / 4]  # Assume max formal charge is 4  # 1 feature
+
+         # Aromaticity
+         is_aromatic = atom.GetIsAromatic()
+         aromatic_feature = [int(is_aromatic)]  # 1 feature
+
+         # Hybridization
+         hybridization = atom.GetHybridization()
+         hybridization_feature = [int(hybridization == hyb) for hyb in hybridization_list]  # 3 features
+
+         # Total features: 5 + 1 + 1 + 1 + 3 = 11
+         atom_feature = atom_type_feature + degree_feature + formal_charge_feature + aromatic_feature + hybridization_feature
+         x.append(atom_feature)
+
+     x = torch.tensor(x, dtype=torch.float)
+
+     # Build edge indices
+     edge_index = []
+     for bond in mol.GetBonds():
+         i = bond.GetBeginAtomIdx()
+         j = bond.GetEndAtomIdx()
+         edge_index.append([i, j])
+         edge_index.append([j, i])  # Since it's undirected
+
+     edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
+
+     # Build batch tensor (since batch size is 1)
+     batch = torch.zeros(num_atoms, dtype=torch.long)
+
+     # Build Data object
+     data = Data(x=x, edge_index=edge_index, batch=batch)
+
+     return data
+
+ # Streamlit app
+ def main():
+     st.set_page_config(
+         page_title="🧪 Molecule Generator and Property Predictor",
+         page_icon="🧪",
+         layout="wide",
+         initial_sidebar_state="expanded",
+     )
+
+     # Main Title and Description
+     st.title("🧪 Molecular Generation and Analysis using VAE and GNN")
+     st.markdown("""
+     This application allows you to generate novel molecular structures using a Variational Autoencoder (VAE) model trained on the QM9 dataset.
+     You can also predict molecular properties using pre-trained Graph Neural Network (GNN) models (GCN and GIN).
+     """)
+
+     # Initialize session state variables
+     if 'analysis' not in st.session_state:
+         st.session_state.analysis = None
+     if 'generated_smiles' not in st.session_state:
+         st.session_state.generated_smiles = []
+     if 'vae_generated' not in st.session_state:
+         st.session_state.vae_generated = False
+
+     # Sidebar configuration
+     st.sidebar.title("🔧 Configuration")
+     st.sidebar.markdown("Adjust the settings below to generate molecules or predict properties.")
+
+     # Load training data and canonicalize SMILES
+     @st.cache_data
+     def load_training_data():
+         df = pd.read_csv("qm9.csv")
+         smiles_list_raw = df['smiles'].tolist()
+         # Canonicalize SMILES for accurate comparison
+         smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(s), isomericSmiles=True) for s in smiles_list_raw]
+         return set(smiles_list)
+
+     training_smiles_set = load_training_data()
+
+     # Device selection
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Model selection
+     st.sidebar.title("📌 Model Selection")
+     model_option = st.sidebar.selectbox("Choose a functionality", ("Generate Molecules (VAE)", "Predict Property (GNN)"))
+
+     if model_option == "Generate Molecules (VAE)":
+         # Number of samples
+         num_samples = st.sidebar.slider("Number of Molecules to Generate", min_value=5, max_value=500, value=50, step=5)
+
+         # Random seed
+         seed = st.sidebar.number_input("Random Seed", value=42, step=1)
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+
+         if st.sidebar.button("🚀 Generate Molecules"):
+             with st.spinner("Generating molecules..."):
+                 # Load VAE model
+                 model, vocab = load_vae_model(device)
+                 generated_smiles = generate_smiles_vae(model, vocab, num_samples=num_samples)
+                 # Analyze molecules
+                 analysis = analyze_molecules(generated_smiles, training_smiles_set)
+                 # Store results in session state
+                 st.session_state.generated_smiles = generated_smiles
+                 st.session_state.analysis = analysis
+                 st.session_state.vae_generated = True
+
+             # Display summary
+             st.success("✅ Molecule generation completed!")
+             st.subheader("Summary of Generated Molecules")
+             col1, col2, col3, col4 = st.columns(4)
+             col1.metric("Total Generated", analysis['total'])
+             col2.metric("Valid Molecules", f"{analysis['valid']} ({(analysis['valid']/analysis['total'])*100:.2f}%)")
+             col3.metric("Unique Molecules", f"{analysis['unique']} ({(analysis['unique']/analysis['total'])*100:.2f}%)")
+             col4.metric("Corrected Molecules", f"{analysis['corrected']} ({(analysis['corrected']/analysis['total'])*100:.2f}%)")
+
+             col1, col2 = st.columns(2)
+             col1.metric("Novel Molecules", f"{analysis['novel']} ({(analysis['novel']/analysis['total'])*100:.2f}%)")
+             col2.metric("Invalid Molecules", f"{analysis['invalid']} ({(analysis['invalid']/analysis['total'])*100:.2f}%)")
+
+             # Display properties
+             if analysis['valid_properties'] or analysis['novel_properties']:
+                 st.subheader("Properties of Generated Molecules")
+
+                 tab1, tab2 = st.tabs(["✅ Valid Molecules", "🌟 Novel Molecules"])
+                 with tab1:
+                     if analysis['valid_properties']:
+                         df_valid = pd.DataFrame(analysis['valid_properties'])
+                         st.dataframe(df_valid)
+                         # Visualize valid molecules (limit to 9 for performance)
+                         st.subheader("Sample Valid Molecules")
+                         mol_image_valid = visualize_molecules([prop['smiles'] for prop in analysis['valid_properties']], n=9)
+                         if mol_image_valid:
+                             st.image(mol_image_valid)
+                         else:
+                             st.write("No valid molecules to display.")
+                     else:
+                         st.write("No valid molecules found.")
+
+                 with tab2:
+                     if analysis['novel_properties']:
+                         df_novel = pd.DataFrame(analysis['novel_properties'])
+                         st.dataframe(df_novel)
+                         # Visualize novel molecules (limit to 9 for performance)
+                         st.subheader("Sample Novel Molecules")
+                         mol_image_novel = visualize_molecules([prop['smiles'] for prop in analysis['novel_properties']], n=9)
+                         if mol_image_novel:
+                             st.image(mol_image_novel)
+                         else:
+                             st.write("No novel molecules to display.")
+                     else:
+                         st.write("No novel molecules found.")
+
+                 # Property distributions
+                 st.subheader("Property Distributions")
+                 fig, axs = plt.subplots(2, 2, figsize=(14, 10))
+                 if analysis['valid_properties']:
+                     sns.histplot(df_valid['MolWt'], bins=20, ax=axs[0, 0], kde=True, color='skyblue', label='Valid')
+                 if analysis['novel_properties']:
+                     sns.histplot(df_novel['MolWt'], bins=20, ax=axs[0, 0], kde=True, color='orange', label='Novel')
+                 axs[0, 0].set_title('Molecular Weight Distribution')
+                 axs[0, 0].legend()
+
+                 if analysis['valid_properties']:
+                     sns.histplot(df_valid['LogP'], bins=20, ax=axs[0, 1], kde=True, color='skyblue', label='Valid')
+                 if analysis['novel_properties']:
+                     sns.histplot(df_novel['LogP'], bins=20, ax=axs[0, 1], kde=True, color='orange', label='Novel')
+                 axs[0, 1].set_title('LogP Distribution')
+                 axs[0, 1].legend()
+
+                 if analysis['valid_properties']:
+                     sns.histplot(df_valid['NumHDonors'], bins=range(0, max(df_valid['NumHDonors'].max(),
+                                                                            df_novel['NumHDonors'].max()) + 2),
+                                  ax=axs[1, 0], kde=False, color='skyblue', label='Valid')
+                 if analysis['novel_properties']:
+                     sns.histplot(df_novel['NumHDonors'], bins=range(0, max(df_valid['NumHDonors'].max(),
+                                                                            df_novel['NumHDonors'].max()) + 2),
+                                  ax=axs[1, 0], kde=False, color='orange', label='Novel')
+                 axs[1, 0].set_title('Number of H Donors')
+                 axs[1, 0].legend()
+
+                 if analysis['valid_properties']:
+                     sns.histplot(df_valid['NumHAcceptors'], bins=range(0, max(df_valid['NumHAcceptors'].max(),
+                                                                               df_novel['NumHAcceptors'].max()) + 2),
+                                  ax=axs[1, 1], kde=False, color='skyblue', label='Valid')
+                 if analysis['novel_properties']:
+                     sns.histplot(df_novel['NumHAcceptors'], bins=range(0, max(df_valid['NumHAcceptors'].max(),
+                                                                               df_novel['NumHAcceptors'].max()) + 2),
+                                  ax=axs[1, 1], kde=False, color='orange', label='Novel')
+                 axs[1, 1].set_title('Number of H Acceptors')
+                 axs[1, 1].legend()
+
+                 plt.tight_layout()
+                 st.pyplot(fig)
+
+                 # Download options
+                 csv_valid = df_valid.to_csv(index=False).encode('utf-8')
+                 csv_novel = df_novel.to_csv(index=False).encode('utf-8')
+                 col1, col2 = st.columns(2)
+                 with col1:
+                     st.download_button(
+                         label="💾 Download Valid Molecules CSV",
+                         data=csv_valid,
+                         file_name='valid_molecules.csv',
+                         mime='text/csv'
+                     )
+                 with col2:
+                     st.download_button(
+                         label="💾 Download Novel Molecules CSV",
+                         data=csv_novel,
+                         file_name='novel_molecules.csv',
+                         mime='text/csv'
+                     )
+             else:
+                 st.warning("No valid or novel molecules were generated.")
+
+     elif model_option == "Predict Property (GNN)":
+         # Load GNN models
+         with st.spinner("Loading GNN models..."):
+             gcn_model, gin_model = load_gnn_models(device)
+             # Load normalization parameters
+             data_mean, data_std = load_data_norm(device)
+
+         # GNN Model selection
+         gnn_model_option = st.sidebar.selectbox("Choose a GNN model", ("GCN", "GIN"))
+
+         st.subheader("🔍 Predict Molecular Property using GNN")
+         st.markdown("""
+         Input a SMILES string to predict the dipole moment using the selected GNN model.
+         """)
+
+         # User inputs a SMILES string
+         user_smiles = st.text_input("Enter a SMILES string for property prediction:", "")
+
+         if user_smiles:
+             data = smiles_to_data(user_smiles)
+             if data:
+                 data = data.to(device)
+                 if gnn_model_option == "GCN":
+                     prediction = gcn_model(data)
+                     prediction = prediction.item()
+                 elif gnn_model_option == "GIN":
+                     prediction = gin_model(data)
+                     prediction = prediction.item()
+                 # Unnormalize the prediction
+                 prediction = prediction * data_std.item() + data_mean.item()
+                 st.success(f"**Predicted Dipole Moment ({gnn_model_option}):** {prediction:.4f}")
+                 # Display molecule
+                 mol = Chem.MolFromSmiles(user_smiles)
+                 if mol:
+                     st.subheader("Molecular Structure")
+                     st.image(Draw.MolToImage(mol, size=(300, 300)))
+             else:
+                 st.error("❌ Invalid SMILES string.")
+
+         st.markdown("---")
+         st.markdown("### Or select a molecule from the generated molecules (if any).")
+
+         # Check if molecules have been generated
+         if st.session_state.vae_generated and st.session_state.analysis is not None:
+             # Combine valid and novel properties
+             all_properties = st.session_state.analysis['valid_properties'] + st.session_state.analysis['novel_properties']
+             if all_properties:
+                 smiles_options = [prop['smiles'] for prop in all_properties]
+                 selected_smiles = st.selectbox("Select a molecule", smiles_options)
+                 if selected_smiles:
+                     data = smiles_to_data(selected_smiles)
+                     if data:
+                         data = data.to(device)
+                         if gnn_model_option == "GCN":
+                             prediction = gcn_model(data)
+                             prediction = prediction.item()
+                         elif gnn_model_option == "GIN":
+                             prediction = gin_model(data)
+                             prediction = prediction.item()
+                         # Unnormalize the prediction
+                         prediction = prediction * data_std.item() + data_mean.item()
+                         st.success(f"**Predicted Dipole Moment ({gnn_model_option}):** {prediction:.4f}")
+                         # Display molecule
+                         mol = Chem.MolFromSmiles(selected_smiles)
+                         if mol:
+                             st.subheader("Molecular Structure")
+                             st.image(Draw.MolToImage(mol, size=(300, 300)))
+                     else:
+                         st.error("❌ Invalid SMILES string.")
+             else:
+                 st.info("🔍 No valid or novel molecules available.")
+         else:
+             st.info("🔍 No generated molecules available. Generate molecules using the VAE first.")
+
+     # About section
+     st.sidebar.title("ℹ️ About")
+     st.sidebar.info("""
+     **Molecule Generator and Property Predictor App**
+
+     This app uses a Variational Autoencoder (VAE) model and Graph Neural Networks (GNNs) to generate novel molecular structures and predict molecular properties.
+
+     - **Developed by**: Arjun, Kaustubh, and Nachiket
+     - **Hugging Face Repository**: [Your Hugging Face Repository](https://huggingface.co/YourRepositoryLink)
+     """)
+
+     # Hide Streamlit footer and header
+     hide_streamlit_style = """
+         <style>
+         footer {visibility: hidden;}
+         header {visibility: hidden;}
+         </style>
+     """
+     st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+ # Run the app
+ if __name__ == "__main__":
+     main()
data_norm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d86e7264dc9cf6b98296acc1ac0d1180511bc714f7d5a3336212c27c9df32ff7
+ size 1444
gan_mol_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd75fc4b2d59d77b7abfc85d1cbe65c7caff987661a7423db8b8848044c99e7f
+ size 697168
qm9.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e668f8c34e4bc392a90d417a50a5eed3b64b842a817a633024bdc054c68ccb4
+ size 29856825
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ streamlit
+ torch
+ torchvision
+ torch-scatter
+ torch-sparse
+ torch-cluster
+ torch-spline-conv
+ torch-geometric
+ rdkit-pypi
+ pandas
+ numpy
+ matplotlib
+ seaborn
+ Pillow
vae_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60b3c9608612eeebfda8659bc017caeca6a26caebc3fa8b07ee5c9abd8af7f03
+ size 2082136
vae_vocab.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a3784f5b486a116e95054673663ef811802663affb4b1180c36f53090cc2f00
+ size 154