metadata
license: cc-by-nc-sa-4.0
library_name: transformers
tags:
  - chemistry
  - selfies

chemfie-gpt-experiment-1

This model is part of my hands-on learning and experimentation in molecule generation, aimed at determining which type of model is best suited to SELFIES (GPT-2, T5, or a fill-mask approach). It also serves as a baseline for future ablation and customization studies of model architecture, dataset augmentation, and training procedures.

Model Details

  • Model Type: GPT-2
  • Architecture: L8, A6, H384 (8 layers, 6 attention heads, hidden size 384)
  • Task: Generation of SELFIES strings
  • Language: N/A (Chemical representation)
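The "L8, A6, H384" shorthand corresponds to GPT-2's `n_layer`, `n_head`, and `n_embd` settings. As a rough sanity check on model size, the sketch below counts the non-embedding parameters of such a model. It assumes a standard GPT-2 block (MLP width 4 × `n_embd`, biases on every linear layer, two LayerNorms per block plus a final one); the helper names are hypothetical, and the embedding tables are left out because the vocabulary size and context length are not stated here.

```python
# Hypothetical helpers: read "L8, A6, H384" as a standard GPT-2 configuration
# (n_layer=8, n_head=6, n_embd=384) and count non-embedding parameters.
# Assumes GPT-2's default MLP width of 4 * n_embd and biases on all linear layers.


def head_dim(n_embd: int, n_head: int) -> int:
    """Width of each attention head."""
    if n_embd % n_head != 0:
        raise ValueError("n_embd must be divisible by n_head")
    return n_embd // n_head


def gpt2_block_params(n_embd: int) -> int:
    """Parameters in one GPT-2 transformer block (12*h^2 + 13*h)."""
    h = n_embd
    attn_qkv = h * 3 * h + 3 * h   # fused Q/K/V projection (c_attn)
    attn_out = h * h + h           # attention output projection
    mlp_in = h * 4 * h + 4 * h     # MLP expansion (c_fc)
    mlp_out = 4 * h * h + h        # MLP contraction (c_proj)
    layer_norms = 2 * (2 * h)      # ln_1 and ln_2, each with scale and shift
    return attn_qkv + attn_out + mlp_in + mlp_out + layer_norms


def gpt2_nonembedding_params(n_layer: int, n_embd: int) -> int:
    """All transformer blocks plus the final LayerNorm (ln_f)."""
    return n_layer * gpt2_block_params(n_embd) + 2 * n_embd


print("head dim:", head_dim(384, 6))
print("non-embedding params:", gpt2_nonembedding_params(8, 384))
```

Adding `vocab_size × 384` for token embeddings and `n_positions × 384` for position embeddings gives the full count; GPT-2's language-model head shares its weights with the token embeddings by default, so it adds nothing extra.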

Personal Intended Use

  • Hands-on learning, research and experimentation in molecular generation
  • Baseline for ablation studies and comparisons with more advanced models

Usage

Direct Use

Because this model does not use a proper GPT-2-format tokenizer, the special tokens still have to be set up manually (the next experiment will use a proper one):

from transformers import PreTrainedTokenizerFast, AutoModelForCausalLM
import torch

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="gpt2_tokenizer.json",
    model_max_length=512,
    unk_token="<unk>",
    pad_token="<pad>",
    eos_token="</s>",
    bos_token="<s>",
    mask_token="<mask>",
)

model = AutoModelForCausalLM.from_pretrained("gbyuvd/chemfie-gpt-experiment-1")

# Generate some sample outputs
def generate_molecules(model, tokenizer, num_samples=5, max_length=100):
    model.eval()
    generated = []
    for _ in range(num_samples):
        input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(model.device)
        output = model.generate(input_ids, max_length=max_length, num_return_sequences=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)
        generated.append(tokenizer.decode(output[0], skip_special_tokens=True))
    return generated

sample_molecules = generate_molecules(model, tokenizer)
print("Sample generated molecules:")
for i, mol in enumerate(sample_molecules, 1):
    print(f"{i}. {mol}")

"""
Example output:
....
2. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [Branch1] [C] [C]
3. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [=C] [Ring1] [N]
4. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N]
5. [C] [Branch1] [C] [Branch1] [C] [C] [=N] [C] [Branch1] [C] [=N] [Branch1] [C] [N] [Branch1] [C]
"""


Tokenized SELFIES to SMILES:

import selfies as sf

test = "[C] [Branch1] [O] [=C] [C] [C] [C] [C] [C] [C] [C] [=Branch1] [=O] [O] [=C] [C] [C] [C] [Ring1]"
test = test.replace(' ', '')
print(sf.decoder(test))

"""
Output:
C(CCCCCCCCO)=CCC=C
"""
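The `replace(' ', '')` step works because every SELFIES symbol is a bracketed token, so the spaces the tokenizer inserts carry no information. The sketch below does the same thing with a regular expression, which also yields the token list (handy for counting or slicing tokens). `split_tokens` and `detokenize` are hypothetical helper names; the `selfies` package also offers `sf.split_selfies` for the splitting step.

```python
import re

# Every SELFIES symbol is a bracketed token such as [C], [=O] or [Branch1].
SELFIES_TOKEN = re.compile(r"\[[^\]]*\]")


def split_tokens(selfies_str: str) -> list[str]:
    """Return the bracketed symbols, ignoring any whitespace between them."""
    return SELFIES_TOKEN.findall(selfies_str)


def detokenize(spaced: str) -> str:
    """Join space-separated model output into a plain SELFIES string."""
    return "".join(split_tokens(spaced))


spaced = "[C] [Branch1] [O] [=C] [C] [C] [C] [C] [C] [C] [C] [=Branch1] [=O] [O] [=C] [C] [C] [C] [Ring1]"
print(split_tokens(spaced)[:4])
print(detokenize(spaced))  # ready to pass to sf.decoder
```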

Generate with Different Temperature(s) and Visualization

import torch
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt


def generate_molecules(temperature, num_molecules=2):
    inputs = torch.tensor([[tokenizer.bos_token_id]]).to(model.device)
    gen = model.generate(
        inputs,
        do_sample=True,
        max_length=256,
        temperature=temperature,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=5,
        num_return_sequences=num_molecules
    )
    return tokenizer.batch_decode(gen, skip_special_tokens=True)

def selfies_to_smiles(selfies_str):
    selfies_str = selfies_str.replace(' ', '')
    try:
        return sf.decoder(selfies_str)
    except Exception:  # e.g. selfies.DecoderError on malformed input
        return None

def visualize_molecules(temperatures):
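    # One row per temperature, one column per sample: the column count (2) must match
    # num_molecules in generate_molecules. squeeze=False keeps axs 2-D even for a single row.
    fig, axs = plt.subplots(len(temperatures), 2, figsize=(20, 4*len(temperatures)), squeeze=False)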
    fig.suptitle("Generated Molecules at Different Temperatures", fontsize=16)

    for i, temp in enumerate(temperatures):
        molecules = generate_molecules(temp)
        for j, mol in enumerate(molecules):
            smiles = selfies_to_smiles(mol)
            if smiles:
                rdkit_mol = Chem.MolFromSmiles(smiles)
                if rdkit_mol:
                    img = Draw.MolToImage(rdkit_mol)
                    axs[i, j].imshow(img)
                    axs[i, j].axis('off')
                    axs[i, j].set_title(f"Temp: {temp}", fontsize=10)
                else:
                    axs[i, j].text(0.5, 0.5, "Invalid\nMolecule", ha='center', va='center')
                    axs[i, j].axis('off')
            else:
                axs[i, j].text(0.5, 0.5, "Invalid\nSELFIES", ha='center', va='center')
                axs[i, j].axis('off')

    plt.tight_layout()
    plt.show()

# Generate and visualize molecules at different temperatures
temperatures = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]
visualize_molecules(temperatures)

Output example:

[Image: "Generated Molecules at Different Temperatures", one row per temperature]
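In `transformers`, the `temperature` argument divides the next-token logits before they are turned into probabilities, so higher values flatten the distribution and rarer SELFIES symbols get picked more often. The sketch below shows that effect on a made-up set of three logits across the temperatures used above; `softmax_with_temperature` and `entropy` are hypothetical helpers, not library functions.

```python
import math


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: list[float]) -> float:
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Hypothetical next-token logits for three candidate SELFIES symbols
logits = [2.0, 1.0, 0.1]
for t in [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]:
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: top-token prob={max(probs):.3f}, entropy={entropy(probs):.3f}")
```

Note that the snippets above also set `num_beams=5` together with `do_sample=True`, which makes `generate` use beam-search multinomial sampling rather than plain sampling; the temperature rescaling still applies to the logits.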

Generate using Starting Sequence with Different Temperature(s) and Visualization

import torch
import selfies as sf
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt


def generate_molecules(seed, temperature, num_molecules=5):
    # Tokenize the seed
    seed_tokens = tokenizer.encode(seed, add_special_tokens=False, return_tensors="pt").to(model.device)
    
    # Generate from the seed
    gen = model.generate(
        seed_tokens,
        do_sample=True,
        max_length=256,
        temperature=temperature,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=5,
        num_return_sequences=num_molecules
    )
    
    # Decode the generated sequences
    generated = tokenizer.batch_decode(gen, skip_special_tokens=True)
    
    # For decoder-only models, generate() returns the prompt followed by the new tokens,
    # so each decoded sequence already starts with the seed
    return generated

def selfies_to_smiles(selfies_str):
    selfies_str = selfies_str.replace(' ', '')
    try:
        return sf.decoder(selfies_str)
    except Exception:  # e.g. selfies.DecoderError on malformed input
        return None

def visualize_molecules(seed, temperatures):
    # 5 columns = num_molecules in generate_molecules; squeeze=False keeps axs 2-D
    fig, axs = plt.subplots(len(temperatures), 5, figsize=(20, 4*len(temperatures)), squeeze=False)
    fig.suptitle(f"Generated Molecules at Different Temperatures\nSeed: {seed}", fontsize=16)

    for i, temp in enumerate(temperatures):
        molecules = generate_molecules(seed, temp)
        for j, mol in enumerate(molecules):
            smiles = selfies_to_smiles(mol)
            if smiles:
                rdkit_mol = Chem.MolFromSmiles(smiles)
                if rdkit_mol:
                    img = Draw.MolToImage(rdkit_mol)
                    axs[i, j].imshow(img)
                    axs[i, j].axis('off')
                    axs[i, j].set_title(f"Temp: {temp}", fontsize=10)
                else:
                    axs[i, j].text(0.5, 0.5, "Invalid\nMolecule", ha='center', va='center')
                    axs[i, j].axis('off')
            else:
                axs[i, j].text(0.5, 0.5, "Invalid\nSELFIES", ha='center', va='center')
                axs[i, j].axis('off')

    plt.tight_layout()
    plt.show()

# Set the seed and temperatures
seed = "[C] [C] [=Branch1] [C] [=O] [O] [C] [C] [N+1]"
temperatures = [0.5, 1.0, 1.5, 2.0, 2.5]

# Generate and visualize molecules at different temperatures
visualize_molecules(seed, temperatures)

Example output:

[Image: "Generated Molecules at Different Temperatures", seed [C] [C] [=Branch1] [C] [=O] [O] [C] [C] [N+1]]
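Because `model.generate` returns the prompt followed by the newly generated tokens for decoder-only models such as GPT-2, each decoded string begins with the seed. To inspect only what the model added, the seed can be dropped at the token level rather than by character offset, which is robust to whitespace differences in decoding. A minimal sketch; `split_continuation` is a hypothetical helper and the decoded string is invented for illustration. With tensors, the equivalent is slicing `gen[:, seed_tokens.shape[1]:]` before decoding.

```python
import re

SELFIES_TOKEN = re.compile(r"\[[^\]]*\]")


def split_continuation(seed: str, decoded: str) -> tuple[list[str], list[str]]:
    """Split a decoded sequence into (seed tokens, newly generated tokens).

    Assumes the decoded text starts with the seed's tokens, which holds when the
    seed was passed to generate() as the prompt of a decoder-only model.
    """
    seed_tokens = SELFIES_TOKEN.findall(seed)
    tokens = SELFIES_TOKEN.findall(decoded)
    if tokens[: len(seed_tokens)] != seed_tokens:
        raise ValueError("decoded sequence does not start with the seed")
    return seed_tokens, tokens[len(seed_tokens):]


seed = "[C] [C] [=Branch1] [C] [=O] [O] [C] [C] [N+1]"
# Hypothetical decoded output, for illustration only
decoded = seed + " [Branch1] [C] [C] [C]"
prefix, new = split_continuation(seed, decoded)
print(new)
```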

Training Data

  • Source: Curated and merged from the COCONUTDB (Sorokina et al., 2021), ChEMBL34 (Zdrazil et al., 2023), and SuperNatural3 (Gallo et al., 2023) databases
  • Total: 2,933,355 samples
  • Total Train: 2,346,680 samples
  • Validation: 293,336 samples
  • Per chunk: 586,670 train, 73,334 validation, 73,334 test
  • Random seed for split: 42
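The per-chunk figures are consistent with the totals above. The sketch assumes four chunks, matching the four rows (I to IV) of the training log below; with that assumption the training and validation totals come out exactly, and each chunk is split roughly 80/10/10.

```python
# Per-chunk figures from the Training Data list; NUM_CHUNKS = 4 is an assumption
# based on the four chunks (I-IV) reported in the training log.
NUM_CHUNKS = 4
PER_CHUNK = {"train": 586_670, "validation": 73_334, "test": 73_334}

totals = {split: NUM_CHUNKS * n for split, n in PER_CHUNK.items()}
print(totals)

chunk_size = sum(PER_CHUNK.values())
print("samples per chunk:", chunk_size)
print("train fraction per chunk:", round(PER_CHUNK["train"] / chunk_size, 4))
```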

Training Procedure

  • Batch Size: 64
  • Epochs per Chunk: 1
  • Learning Rate: 1.5e-5
  • Optimizer: Ranger21 (MADGRAD-Lookahead-AdaBelief with gradient centralization, linear warm-up (22%), gradient clipping, and L2 weight decay)
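A linear warm-up over 22% of the steps ramps the learning rate from 0 up to 1.5e-5. The sketch below shows that ramp under stated assumptions: the warm-up is applied within each chunk, and the last partial batch is kept, giving ceil(586,670 / 64) steps per chunk. It is a standalone illustration with hypothetical names, not Ranger21's internal schedule, which may apply further adjustments (such as a warm-down) that are not modeled here.

```python
import math

# Values from the Training Procedure list; applying the warm-up per chunk and
# keeping the last partial batch are assumptions for illustration.
PEAK_LR = 1.5e-5
BATCH_SIZE = 64
TRAIN_PER_CHUNK = 586_670
WARMUP_FRACTION = 0.22

steps_per_chunk = math.ceil(TRAIN_PER_CHUNK / BATCH_SIZE)
warmup_steps = int(WARMUP_FRACTION * steps_per_chunk)


def warmup_lr(step: int) -> float:
    """Linear ramp from 0 to PEAK_LR over warmup_steps, then constant."""
    if step >= warmup_steps:
        return PEAK_LR
    return PEAK_LR * step / warmup_steps


print("steps per chunk:", steps_per_chunk)
print("warm-up steps:", warmup_steps)
for s in [0, warmup_steps // 2, warmup_steps, steps_per_chunk - 1]:
    print(s, warmup_lr(s))
```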

Training Logs

Chunk | Training Loss | Validation Loss | Status
I     | 1.346400      | 1.065180        | Done
II    | 1.123500      | 0.993118        | Done
III   | 1.058300      | 0.948303        | Done
IV    | 1.016600      | 0.921706        | Done
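Cross-entropy losses are often easier to read as perplexity, exp(loss). The sketch below converts the logged losses, assuming they are mean per-token cross-entropy in nats (the usual Hugging Face `Trainer` loss), and shows that the validation loss falls after every chunk, with a smaller gain each time.

```python
import math

# Losses copied from the training log; interpreting them as mean per-token
# cross-entropy in nats is an assumption.
log = {
    "I": (1.346400, 1.065180),
    "II": (1.123500, 0.993118),
    "III": (1.058300, 0.948303),
    "IV": (1.016600, 0.921706),
}

for chunk, (train_loss, val_loss) in log.items():
    print(f"{chunk:>3}: train ppl={math.exp(train_loss):.3f}  val ppl={math.exp(val_loss):.3f}")

val = [v for _, v in log.values()]
drops = [a - b for a, b in zip(val, val[1:])]
print("validation-loss drop per chunk:", [round(d, 6) for d in drops])
```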

Evaluation Results

[To be filled after model evaluation]
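Validity (the fraction of samples that decode to a molecule) and uniqueness (the fraction of valid samples that are distinct) are common first metrics for generative chemistry models. A minimal sketch of how they could be computed; the helper names are hypothetical, and the converter is passed in so the metric logic stays independent of RDKit. The toy converter below is for illustration only.

```python
from typing import Callable, Optional


def validity_and_uniqueness(
    samples: list[str],
    to_canonical: Callable[[str], Optional[str]],
) -> tuple[float, float]:
    """Return (validity, uniqueness) for a list of generated strings.

    `to_canonical` should return a canonical string for a valid sample, or None.
    """
    if not samples:
        return 0.0, 0.0
    canonical = [to_canonical(s) for s in samples]
    valid = [c for c in canonical if c]
    validity = len(valid) / len(samples)
    uniqueness = len(set(valid)) / len(valid) if valid else 0.0
    return validity, uniqueness


# Toy converter for illustration only: treats bracketed strings as valid
def toy_converter(s: str) -> Optional[str]:
    return s.replace(" ", "") if s.startswith("[") else None


samples = ["[C] [O]", "[C][O]", "[C] [=O]", "not-selfies"]
validity, uniqueness = validity_and_uniqueness(samples, toy_converter)
print(validity, uniqueness)
```

For real samples, the converter could chain `selfies_to_smiles` from the snippets above with RDKit's `Chem.MolFromSmiles` and `Chem.MolToSmiles`, returning None whenever either step fails, so that identical molecules written differently are counted once.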

Limitations and Biases

  • May generate unrealistic or synthetically inaccessible molecules
  • Performance on complex, highly branched, and ring-containing molecules has yet to be evaluated

Disclaimer & Ethical Considerations

  • This model is at an early stage of development and may not consistently generate valid outputs.
  • It is intended for personal exploration, academic, and research purposes only.
  • You should be aware of potential ethical concerns:
    • Possible generation of harmful substances if misused
    • Potential biases inherent in the training data
  • The accuracy, completeness, and reliability of the model's outputs are not guaranteed.
  • This model should not be used for any commercial or legal purposes.
  • The information and model provided are for educational and research use only.

Additional Information

  • Part of the experimental chemfie-gpt/T5 project
  • Serves as a baseline for future experiments with further curated datasets, training improvements, and architectural modifications

Citation

BibTeX

COCONUTDB

@article{sorokina2021coconut,
  title={COCONUT online: Collection of Open Natural Products database},
  author={Sorokina, Maria and Merseburger, Peter and Rajan, Kohulan and Yirik, Mehmet Aziz and Steinbeck, Christoph},
  journal={Journal of Cheminformatics},
  volume={13},
  number={1},
  pages={2},
  year={2021},
  doi={10.1186/s13321-020-00478-9}
}

ChEMBL34

@article{zdrazil2023chembl,
  title={The ChEMBL Database in 2023: a drug discovery platform spanning multiple bioactivity data types and time periods},
  author={Zdrazil, Barbara and Felix, Eloy and Hunter, Fiona and Manners, Emma J and Blackshaw, James and Corbett, Sybilla and de Veij, Marleen and Ioannidis, Harris and Lopez, David Mendez and Mosquera, Juan F and Magarinos, Maria Paula and Bosc, Nicolas and Arcila, Ricardo and Kizil{\"o}ren, Tevfik and Gaulton, Anna and Bento, A Patr{\'i}cia and Adasme, Melissa F and Monecke, Peter and Landrum, Gregory A and Leach, Andrew R},
  journal={Nucleic Acids Research},
  year={2023},
  doi={10.1093/nar/gkad1004}
}

@misc{chembl34,
  title={ChEMBL34},
  year={2023},
  doi={10.6019/CHEMBL.database.34}
}

SuperNatural3

@article{Gallo2023,
  author = {Gallo, K and Kemmler, E and Goede, A and Becker, F and Dunkel, M and Preissner, R and Banerjee, P},
  title = {{SuperNatural 3.0-a database of natural products and natural product-based derivatives}},
  journal = {Nucleic Acids Research},
  year = {2023},
  month = jan,
  day = {6},
  volume = {51},
  number = {D1},
  pages = {D654-D659},
  doi = {10.1093/nar/gkac1008}
}