Raykarr committed on
Commit c3e7177 · verified
1 Parent(s): 64b81e8

Update app.py

Files changed (1)
  1. app.py +680 -680
app.py CHANGED
@@ -1,680 +1,680 @@
Every line of the file was re-written in this commit, but the only content change shown is the Hugging Face repository link in the sidebar "About" section:

-    - **Hugging Face Repository**: [Your Hugging Face Repository](https://huggingface.co/YourRepositoryLink)
+    - **Hugging Face Repository**: https://huggingface.co/spaces/Raykarr/SMILES_Generation_and_Prediction

The updated app.py in full:
# app.py

import streamlit as st
import torch
import torch.nn as nn
import pickle
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple, Optional

# RDKit for molecule handling
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns

# For generating images in Streamlit
from PIL import Image

# Suppress warnings in RDKit
import warnings
warnings.filterwarnings('ignore')

# Set Seaborn style
sns.set_style('whitegrid')

# Additional imports for GNN
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.nn import global_mean_pool, global_add_pool

# Function to load the VAE model
@st.cache_resource
def load_vae_model(device):
    # Load the vocabulary
    with open('vae_vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)

    # Initialize the model with the same parameters
    hidden_dim = 256  # Ensure this matches your trained model
    latent_dim = 64   # Ensure this matches your trained model

    # Define the VAE class (same as in your training script)
    class VAE(nn.Module):
        def __init__(self, vocab_size: int, hidden_dim: int, latent_dim: int):
            super(VAE, self).__init__()
            self.vocab_size = vocab_size
            self.hidden_dim = hidden_dim
            self.latent_dim = latent_dim

            self.encoder = nn.GRU(vocab_size, hidden_dim, batch_first=True)
            self.fc_mu = nn.Linear(hidden_dim, latent_dim)
            self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

            self.decoder = nn.GRU(vocab_size + latent_dim, hidden_dim, batch_first=True)
            self.fc_output = nn.Linear(hidden_dim, vocab_size)

        def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            _, h = self.encoder(x)
            h = h.squeeze(0)
            return self.fc_mu(h), self.fc_logvar(h)

        def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def decode(self, z: torch.Tensor, max_length: int) -> torch.Tensor:
            batch_size = z.size(0)
            h = torch.zeros(1, batch_size, self.hidden_dim).to(z.device)
            x = torch.zeros(batch_size, 1, self.vocab_size).to(z.device)
            x[:, 0, vocab['<']] = 1  # Start token
            outputs = []

            for _ in range(max_length):
                z_input = z.unsqueeze(1)
                decoder_input = torch.cat([x, z_input], dim=2)
                output, h = self.decoder(decoder_input, h)
                output = self.fc_output(output)
                outputs.append(output)
                x = torch.softmax(output, dim=-1)

            return torch.cat(outputs, dim=1)

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            mu, logvar = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z, x.size(1)), mu, logvar

    model = VAE(vocab_size, hidden_dim, latent_dim)
    model.load_state_dict(torch.load('vae_model.pth', map_location=device))
    model.to(device)
    model.eval()
    return model, vocab

# Function to generate molecules using VAE
def generate_smiles_vae(model, vocab, num_samples=10, max_length=100):
    model.eval()
    inv_vocab = {v: k for k, v in vocab.items()}
    generated_smiles = []
    device = next(model.parameters()).device

    with torch.no_grad():
        for _ in range(num_samples):
            z = torch.randn(1, model.latent_dim).to(device)
            x = torch.zeros(1, 1, model.vocab_size).to(device)
            x[0, 0, vocab['<']] = 1
            h = torch.zeros(1, 1, model.hidden_dim).to(device)

            smiles = ''
            for _ in range(max_length):
                z_input = z.unsqueeze(1)
                decoder_input = torch.cat([x, z_input], dim=2)
                output, h = model.decoder(decoder_input, h)
                output = model.fc_output(output)

                probs = torch.softmax(output.squeeze(0), dim=-1)
                next_char = torch.multinomial(probs, 1).item()

                if next_char == vocab['>']:
                    break

                smiles += inv_vocab.get(next_char, '')
                x = torch.zeros(1, 1, model.vocab_size).to(device)
                x[0, 0, next_char] = 1

            generated_smiles.append(smiles)

    return generated_smiles
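For orientation, a minimal sketch of how the loader and the sampler above could be exercised outside the Streamlit UI. It assumes vae_vocab.pkl and vae_model.pth are in the working directory; outside a "streamlit run" session the cache decorator may only log a warning.

# Illustrative only, not part of app.py: sample a few raw strings from the trained VAE.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae, vocab = load_vae_model(device)
for raw in generate_smiles_vae(vae, vocab, num_samples=3, max_length=80):
    print(repr(raw))  # raw samples may still be invalid SMILES at this point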

# Function to post-process and validate SMILES strings
def enhanced_post_process_smiles(smiles: str) -> str:
    smiles = smiles.replace('<', '').replace('>', '')
    allowed_chars = set('CNOPSFIBrClcnops()[]=@+-#0123456789')
    smiles = ''.join(c for c in smiles if c in allowed_chars)

    # Balance parentheses
    open_count = smiles.count('(')
    close_count = smiles.count(')')
    if open_count > close_count:
        smiles += ')' * (open_count - close_count)
    elif close_count > open_count:
        smiles = '(' * (close_count - open_count) + smiles

    # Replace invalid double bonds
    smiles = smiles.replace('==', '=')

    # Attempt to close unclosed rings
    for i in range(1, 10):
        if smiles.count(str(i)) % 2 != 0:
            smiles += str(i)

    return smiles

def validate_and_correct_smiles(smiles: str) -> Tuple[bool, str]:
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
            return True, Chem.MolToSmiles(mol, isomericSmiles=True)
        except:
            pass
    return False, smiles
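To show how the two helpers compose, a small hypothetical example (the raw string is invented for illustration):

# Illustrative only: repair one raw decoder sample and canonicalize it.
raw = '<C1CCCC'                                # invented raw output: start token plus an unclosed ring
cleaned = enhanced_post_process_smiles(raw)    # strips tokens and appends the missing '1' -> 'C1CCCC1'
ok, canonical = validate_and_correct_smiles(cleaned)
print(ok, canonical)                           # True C1CCCC1 (cyclopentane)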

# Function to analyze molecules
def analyze_molecules(smiles_list: List[str], training_smiles_set: set) -> Dict:
    results = {
        'total': len(smiles_list),
        'valid': 0,
        'invalid': 0,
        'unique': 0,
        'corrected': 0,
        'novel': 0,
        'valid_properties': [],
        'novel_properties': [],
        'invalid_smiles': []
    }

    unique_smiles = set()
    novel_smiles = set()

    for smiles in smiles_list:
        processed_smiles = enhanced_post_process_smiles(smiles)
        is_valid, corrected_smiles = validate_and_correct_smiles(processed_smiles)

        if is_valid:
            results['valid'] += 1
            unique_smiles.add(corrected_smiles)
            if corrected_smiles != processed_smiles:
                results['corrected'] += 1

            mol = Chem.MolFromSmiles(corrected_smiles)
            if mol:
                props = {
                    'smiles': corrected_smiles,
                    'MolWt': Descriptors.ExactMolWt(mol),
                    'LogP': Descriptors.MolLogP(mol),
                    'NumHDonors': Descriptors.NumHDonors(mol),
                    'NumHAcceptors': Descriptors.NumHAcceptors(mol)
                }

                if corrected_smiles not in training_smiles_set:
                    novel_smiles.add(corrected_smiles)
                    results['novel'] += 1
                    results['novel_properties'].append(props)
                else:
                    results['valid_properties'].append(props)
        else:
            results['invalid'] += 1
            results['invalid_smiles'].append(smiles)

    results['unique'] = len(unique_smiles)
    return results
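A hypothetical call on a toy batch (with an empty training set, so every valid molecule is counted as novel) shows the bookkeeping this dictionary carries:

# Illustrative only: analyze a toy batch against an empty training set.
toy = ['<CCO>', '<C1CCCC', 'not-a-smiles(((']
report = analyze_molecules(toy, training_smiles_set=set())
print(report['valid'], report['novel'], report['invalid'])  # expected: 2 2 1
print([p['smiles'] for p in report['novel_properties']])    # canonical SMILES of the novel molecules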

# Function to visualize molecules
def visualize_molecules(smiles_list: List[str], n: int = 5) -> Optional[Image.Image]:
    valid_mols = []
    for smiles in smiles_list:
        smiles = smiles.strip().strip('<>').strip()
        if not smiles:
            continue
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                valid_mols.append(mol)
                if len(valid_mols) == n:
                    break
        except Exception:
            continue

    if not valid_mols:
        return None

    try:
        img = Draw.MolsToGridImage(
            valid_mols,
            molsPerRow=min(3, len(valid_mols)),
            subImgSize=(200, 200),
            legends=[f"Mol {i+1}" for i in range(len(valid_mols))]
        )
        return img
    except Exception:
        return None
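Outside Streamlit, the returned grid (a PIL image, per the type hint) can simply be written to disk; a hypothetical example:

# Illustrative only: render a small grid and save it as a PNG.
grid = visualize_molecules(['CCO', 'c1ccccc1', 'CC(=O)O'], n=3)
if grid is not None:
    grid.save('sample_molecules.png')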

# GCN and GIN model definitions
class GCN(torch.nn.Module):
    """Graph Convolutional Network class with 3 convolutional layers and a linear layer"""

    def __init__(self, dim_h):
        """init method for GCN

        Args:
            dim_h (int): the dimension of hidden layers
        """
        super().__init__()
        self.conv1 = GCNConv(11, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.lin = torch.nn.Linear(dim_h, 1)

    def forward(self, data):
        e = data.edge_index
        x = data.x

        x = self.conv1(x, e)
        x = x.relu()
        x = self.conv2(x, e)
        x = x.relu()
        x = self.conv3(x, e)
        x = global_mean_pool(x, data.batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

class GIN(torch.nn.Module):
    """Graph Isomorphism Network class with 3 GINConv layers and 2 linear layers"""

    def __init__(self, dim_h):
        """Initializing GIN class

        Args:
            dim_h (int): the dimension of hidden layers
        """
        super(GIN, self).__init__()
        nn1 = Sequential(Linear(11, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
        self.conv1 = GINConv(nn1)
        nn2 = Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
        self.conv2 = GINConv(nn2)
        nn3 = Sequential(Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU())
        self.conv3 = GINConv(nn3)
        self.lin1 = Linear(dim_h, dim_h)
        self.lin2 = Linear(dim_h, 1)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        # Node embeddings
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = self.conv2(h, edge_index)
        h = h.relu()
        h = self.conv3(h, edge_index)

        # Graph-level readout
        h = global_add_pool(h, batch)

        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return h
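Both networks consume 11-dimensional node features and return a single value per graph. A minimal shape check on a random two-node graph, with untrained weights, so the numbers themselves mean nothing:

# Illustrative only: shape check on a tiny random graph with untrained weights.
toy_graph = Data(
    x=torch.randn(2, 11),                       # 2 nodes, 11 features each
    edge_index=torch.tensor([[0, 1], [1, 0]]),  # a single undirected edge
    batch=torch.zeros(2, dtype=torch.long),     # both nodes belong to graph 0
)
print(GCN(dim_h=32)(toy_graph).shape)           # torch.Size([1, 1])
print(GIN(dim_h=32)(toy_graph).shape)           # torch.Size([1, 1])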

# Function to load GNN models
@st.cache_resource
def load_gnn_models(device):
    # Load GCN model
    gcn_model = GCN(dim_h=128)
    gcn_model.load_state_dict(torch.load("GCN_model.pth", map_location=device))
    gcn_model.to(device)
    gcn_model.eval()

    # Load GIN model
    gin_model = GIN(dim_h=64)
    gin_model.load_state_dict(torch.load("GIN_model.pth", map_location=device))
    gin_model.to(device)
    gin_model.eval()

    return gcn_model, gin_model

# Function to load normalization parameters
@st.cache_resource
def load_data_norm(device):
    data_norm = torch.load('data_norm.pth', map_location=device)
    data_mean = data_norm['mean']
    data_std = data_norm['std']
    return data_mean, data_std
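load_data_norm expects data_norm.pth to hold a dict with 'mean' and 'std' entries for the regression target. A hypothetical snippet showing how such a file could have been written at training time (the random targets are placeholders):

# Illustrative only: write a data_norm.pth compatible with load_data_norm.
targets = torch.rand(1000) * 5.0  # placeholder for the training dipole moments
torch.save({'mean': targets.mean(), 'std': targets.std()}, 'data_norm.pth')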

# Function to convert SMILES to Data object
def smiles_to_data(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    atoms = mol.GetAtoms()
    num_atoms = len(atoms)

    atom_type_list = ['H', 'C', 'N', 'O', 'F']
    hybridization_list = [Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3]

    x = []
    for atom in atoms:
        atom_type = atom.GetSymbol()
        atom_type_feature = [int(atom_type == s) for s in atom_type_list]  # 5 features

        # Atom degree (scalar between 0 and 4)
        degree = atom.GetDegree()
        degree_feature = [degree / 4]  # Normalize degree to [0,1]  # 1 feature

        # Formal charge
        formal_charge = atom.GetFormalCharge()
        formal_charge_feature = [formal_charge / 4]  # Assume max formal charge is 4  # 1 feature

        # Aromaticity
        is_aromatic = atom.GetIsAromatic()
        aromatic_feature = [int(is_aromatic)]  # 1 feature

        # Hybridization
        hybridization = atom.GetHybridization()
        hybridization_feature = [int(hybridization == hyb) for hyb in hybridization_list]  # 3 features

        # Total features: 5 + 1 + 1 + 1 + 3 = 11
        atom_feature = atom_type_feature + degree_feature + formal_charge_feature + aromatic_feature + hybridization_feature
        x.append(atom_feature)

    x = torch.tensor(x, dtype=torch.float)

    # Build edge indices
    edge_index = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_index.append([i, j])
        edge_index.append([j, i])  # Since it's undirected

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    # Build batch tensor (since batch size is 1)
    batch = torch.zeros(num_atoms, dtype=torch.long)

    # Build Data object
    data = Data(x=x, edge_index=edge_index, batch=batch)

    return data
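Putting the pieces together outside the UI, a hypothetical end-to-end prediction (it assumes GCN_model.pth and data_norm.pth are present, as the loaders above expect):

# Illustrative only: predict an un-normalized dipole moment for one SMILES string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gcn_model, gin_model = load_gnn_models(device)
data_mean, data_std = load_data_norm(device)
graph = smiles_to_data('CCO')
if graph is not None:
    with torch.no_grad():
        pred = gcn_model(graph.to(device)).item() * data_std.item() + data_mean.item()
    print(f"Predicted dipole moment (GCN): {pred:.4f}")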

# Streamlit app
def main():
    st.set_page_config(
        page_title="🧪 Molecule Generator and Property Predictor",
        page_icon="🧪",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Main Title and Description
    st.title("🧪 Molecular Generation and Analysis using VAE and GNN")
    st.markdown("""
    This application allows you to generate novel molecular structures using a Variational Autoencoder (VAE) model trained on the QM9 dataset.
    You can also predict molecular properties using pre-trained Graph Neural Network (GNN) models (GCN and GIN).
    """)

    # Initialize session state variables
    if 'analysis' not in st.session_state:
        st.session_state.analysis = None
    if 'generated_smiles' not in st.session_state:
        st.session_state.generated_smiles = []
    if 'vae_generated' not in st.session_state:
        st.session_state.vae_generated = False

    # Sidebar configuration
    st.sidebar.title("🔧 Configuration")
    st.sidebar.markdown("Adjust the settings below to generate molecules or predict properties.")

    # Load training data and canonicalize SMILES
    @st.cache_data
    def load_training_data():
        df = pd.read_csv("qm9.csv")
        smiles_list_raw = df['smiles'].tolist()
        # Canonicalize SMILES for accurate comparison
        smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(s), isomericSmiles=True) for s in smiles_list_raw]
        return set(smiles_list)

    training_smiles_set = load_training_data()

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model selection
    st.sidebar.title("📌 Model Selection")
    model_option = st.sidebar.selectbox("Choose a functionality", ("Generate Molecules (VAE)", "Predict Property (GNN)"))

    if model_option == "Generate Molecules (VAE)":
        # Number of samples
        num_samples = st.sidebar.slider("Number of Molecules to Generate", min_value=5, max_value=500, value=50, step=5)

        # Random seed
        seed = st.sidebar.number_input("Random Seed", value=42, step=1)
        torch.manual_seed(seed)
        np.random.seed(seed)

        if st.sidebar.button("🚀 Generate Molecules"):
            with st.spinner("Generating molecules..."):
                # Load VAE model
                model, vocab = load_vae_model(device)
                generated_smiles = generate_smiles_vae(model, vocab, num_samples=num_samples)
                # Analyze molecules
                analysis = analyze_molecules(generated_smiles, training_smiles_set)
                # Store results in session state
                st.session_state.generated_smiles = generated_smiles
                st.session_state.analysis = analysis
                st.session_state.vae_generated = True

            # Display summary
            st.success("✅ Molecule generation completed!")
            st.subheader("Summary of Generated Molecules")
            col1, col2, col3, col4 = st.columns(4)
            col1.metric("Total Generated", analysis['total'])
            col2.metric("Valid Molecules", f"{analysis['valid']} ({(analysis['valid']/analysis['total'])*100:.2f}%)")
            col3.metric("Unique Molecules", f"{analysis['unique']} ({(analysis['unique']/analysis['total'])*100:.2f}%)")
            col4.metric("Corrected Molecules", f"{analysis['corrected']} ({(analysis['corrected']/analysis['total'])*100:.2f}%)")

            col1, col2 = st.columns(2)
            col1.metric("Novel Molecules", f"{analysis['novel']} ({(analysis['novel']/analysis['total'])*100:.2f}%)")
            col2.metric("Invalid Molecules", f"{analysis['invalid']} ({(analysis['invalid']/analysis['total'])*100:.2f}%)")

            # Display properties
            if analysis['valid_properties'] or analysis['novel_properties']:
                st.subheader("Properties of Generated Molecules")

                tab1, tab2 = st.tabs(["✅ Valid Molecules", "🌟 Novel Molecules"])
                with tab1:
                    if analysis['valid_properties']:
                        df_valid = pd.DataFrame(analysis['valid_properties'])
                        st.dataframe(df_valid)
                        # Visualize valid molecules (limit to 9 for performance)
                        st.subheader("Sample Valid Molecules")
                        mol_image_valid = visualize_molecules([prop['smiles'] for prop in analysis['valid_properties']], n=9)
                        if mol_image_valid:
                            st.image(mol_image_valid)
                        else:
                            st.write("No valid molecules to display.")
                    else:
                        st.write("No valid molecules found.")

                with tab2:
                    if analysis['novel_properties']:
                        df_novel = pd.DataFrame(analysis['novel_properties'])
                        st.dataframe(df_novel)
                        # Visualize novel molecules (limit to 9 for performance)
                        st.subheader("Sample Novel Molecules")
                        mol_image_novel = visualize_molecules([prop['smiles'] for prop in analysis['novel_properties']], n=9)
                        if mol_image_novel:
                            st.image(mol_image_novel)
                        else:
                            st.write("No novel molecules to display.")
                    else:
                        st.write("No novel molecules found.")

                # Property distributions
                st.subheader("Property Distributions")
                # Build both frames here (possibly empty) so the shared bin ranges and the
                # download buttons below never reference an undefined DataFrame.
                df_valid = pd.DataFrame(analysis['valid_properties'])
                df_novel = pd.DataFrame(analysis['novel_properties'])
                max_donors = int(max(
                    df_valid['NumHDonors'].max() if not df_valid.empty else 0,
                    df_novel['NumHDonors'].max() if not df_novel.empty else 0
                ))
                max_acceptors = int(max(
                    df_valid['NumHAcceptors'].max() if not df_valid.empty else 0,
                    df_novel['NumHAcceptors'].max() if not df_novel.empty else 0
                ))

                fig, axs = plt.subplots(2, 2, figsize=(14, 10))
                if not df_valid.empty:
                    sns.histplot(df_valid['MolWt'], bins=20, ax=axs[0, 0], kde=True, color='skyblue', label='Valid')
                if not df_novel.empty:
                    sns.histplot(df_novel['MolWt'], bins=20, ax=axs[0, 0], kde=True, color='orange', label='Novel')
                axs[0, 0].set_title('Molecular Weight Distribution')
                axs[0, 0].legend()

                if not df_valid.empty:
                    sns.histplot(df_valid['LogP'], bins=20, ax=axs[0, 1], kde=True, color='skyblue', label='Valid')
                if not df_novel.empty:
                    sns.histplot(df_novel['LogP'], bins=20, ax=axs[0, 1], kde=True, color='orange', label='Novel')
                axs[0, 1].set_title('LogP Distribution')
                axs[0, 1].legend()

                if not df_valid.empty:
                    sns.histplot(df_valid['NumHDonors'], bins=range(0, max_donors + 2),
                                 ax=axs[1, 0], kde=False, color='skyblue', label='Valid')
                if not df_novel.empty:
                    sns.histplot(df_novel['NumHDonors'], bins=range(0, max_donors + 2),
                                 ax=axs[1, 0], kde=False, color='orange', label='Novel')
                axs[1, 0].set_title('Number of H Donors')
                axs[1, 0].legend()

                if not df_valid.empty:
                    sns.histplot(df_valid['NumHAcceptors'], bins=range(0, max_acceptors + 2),
                                 ax=axs[1, 1], kde=False, color='skyblue', label='Valid')
                if not df_novel.empty:
                    sns.histplot(df_novel['NumHAcceptors'], bins=range(0, max_acceptors + 2),
                                 ax=axs[1, 1], kde=False, color='orange', label='Novel')
                axs[1, 1].set_title('Number of H Acceptors')
                axs[1, 1].legend()

                plt.tight_layout()
                st.pyplot(fig)

                # Download options
                csv_valid = df_valid.to_csv(index=False).encode('utf-8')
                csv_novel = df_novel.to_csv(index=False).encode('utf-8')
                col1, col2 = st.columns(2)
                with col1:
                    st.download_button(
                        label="💾 Download Valid Molecules CSV",
                        data=csv_valid,
                        file_name='valid_molecules.csv',
                        mime='text/csv'
                    )
                with col2:
                    st.download_button(
                        label="💾 Download Novel Molecules CSV",
                        data=csv_novel,
                        file_name='novel_molecules.csv',
                        mime='text/csv'
                    )
            else:
                st.warning("No valid or novel molecules were generated.")

    elif model_option == "Predict Property (GNN)":
        # Load GNN models
        with st.spinner("Loading GNN models..."):
            gcn_model, gin_model = load_gnn_models(device)
            # Load normalization parameters
            data_mean, data_std = load_data_norm(device)

        # GNN Model selection
        gnn_model_option = st.sidebar.selectbox("Choose a GNN model", ("GCN", "GIN"))

        st.subheader("🔍 Predict Molecular Property using GNN")
        st.markdown("""
        Input a SMILES string to predict the dipole moment using the selected GNN model.
        """)

        # User inputs a SMILES string
        user_smiles = st.text_input("Enter a SMILES string for property prediction:", "")

        if user_smiles:
            data = smiles_to_data(user_smiles)
            if data:
                data = data.to(device)
                if gnn_model_option == "GCN":
                    prediction = gcn_model(data)
                    prediction = prediction.item()
                elif gnn_model_option == "GIN":
                    prediction = gin_model(data)
                    prediction = prediction.item()
                # Unnormalize the prediction
                prediction = prediction * data_std.item() + data_mean.item()
                st.success(f"**Predicted Dipole Moment ({gnn_model_option}):** {prediction:.4f}")
                # Display molecule
                mol = Chem.MolFromSmiles(user_smiles)
                if mol:
                    st.subheader("Molecular Structure")
                    st.image(Draw.MolToImage(mol, size=(300, 300)))
            else:
                st.error("❌ Invalid SMILES string.")

        st.markdown("---")
        st.markdown("### Or select a molecule from the generated molecules (if any).")

        # Check if molecules have been generated
        if st.session_state.vae_generated and st.session_state.analysis is not None:
            # Combine valid and novel properties
            all_properties = st.session_state.analysis['valid_properties'] + st.session_state.analysis['novel_properties']
            if all_properties:
                smiles_options = [prop['smiles'] for prop in all_properties]
                selected_smiles = st.selectbox("Select a molecule", smiles_options)
                if selected_smiles:
                    data = smiles_to_data(selected_smiles)
                    if data:
                        data = data.to(device)
                        if gnn_model_option == "GCN":
                            prediction = gcn_model(data)
                            prediction = prediction.item()
                        elif gnn_model_option == "GIN":
                            prediction = gin_model(data)
                            prediction = prediction.item()
                        # Unnormalize the prediction
                        prediction = prediction * data_std.item() + data_mean.item()
                        st.success(f"**Predicted Dipole Moment ({gnn_model_option}):** {prediction:.4f}")
                        # Display molecule
                        mol = Chem.MolFromSmiles(selected_smiles)
                        if mol:
                            st.subheader("Molecular Structure")
                            st.image(Draw.MolToImage(mol, size=(300, 300)))
                    else:
                        st.error("❌ Invalid SMILES string.")
            else:
                st.info("🔍 No valid or novel molecules available.")
        else:
            st.info("🔍 No generated molecules available. Generate molecules using the VAE first.")

    # About section
    st.sidebar.title("ℹ️ About")
    st.sidebar.info("""
    **Molecule Generator and Property Predictor App**

    This app uses a Variational Autoencoder (VAE) model and Graph Neural Networks (GNNs) to generate novel molecular structures and predict molecular properties.

    - **Developed by**: Arjun, Kaustubh, and Nachiket
    - **Hugging Face Repository**: https://huggingface.co/spaces/Raykarr/SMILES_Generation_and_Prediction
    """)

    # Hide Streamlit footer and header
    hide_streamlit_style = """
        <style>
        footer {visibility: hidden;}
        header {visibility: hidden;}
        </style>
    """
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)

# Run the app
if __name__ == "__main__":
    main()
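The app is launched with "streamlit run app.py" and assumes the following artifacts sit next to it: vae_model.pth, vae_vocab.pkl, GCN_model.pth, GIN_model.pth, data_norm.pth, and qm9.csv. A small, optional pre-flight check (illustrative only):

# Illustrative only: verify the expected artifacts exist before launching the app.
from pathlib import Path

required = ['vae_model.pth', 'vae_vocab.pkl', 'GCN_model.pth',
            'GIN_model.pth', 'data_norm.pth', 'qm9.csv']
missing = [name for name in required if not Path(name).exists()]
print("All artifacts present." if not missing else f"Missing: {missing}")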