import os
import json
import importlib
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from peft import PeftModel
from datasets import Dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm
import spaces

from llama_customized_models import LlamaForCausalLMWithNumericalEmbedding

from rdkit import RDLogger, Chem

# Silence RDKit logging; invalid SMILES are expected during sampling.
RDLogger.DisableLog('rdApp.*')

DEFAULT_PAD_TOKEN = "[PAD]"
device_map = "cuda"

# Per-property means/stds used to standardize the conditioning values.
means = {"qed": 0.5559003125710424, "logp": 3.497542110420217, "sas": 2.889429694406497, "tpsa": 80.19717097706841}
stds = {"qed": 0.21339854620824716, "logp": 1.7923582437824368, "sas": 0.8081188219568571, "tpsa": 38.212259443049554}

def parse_df(df):
    """Measure the properties of each generated SMILES and return a dataframe
    with `{prop}_condition` and `{prop}_measured` columns per property."""
    metric_calculator = importlib.import_module("metric_calculator")

    new_df = []

    for i in range(len(df)):
        sub_df = dict()

        smiles = df.iloc[i]['SMILES']
        property_names = df.iloc[i]['property_names']
        non_normalized_properties = df.iloc[i]['non_normalized_properties']

        sub_df['SMILES'] = smiles

        for j in range(len(property_names)):
            property_name = property_names[j]
            non_normalized_property = non_normalized_properties[j]

            sub_df[f'{property_name}_condition'] = non_normalized_property

            if smiles == "":
                # Invalid generation: nothing to measure.
                sub_df[f'{property_name}_measured'] = np.nan
            else:
                property_eval_func_name = f"compute_{property_name}"
                property_eval_func = getattr(metric_calculator, property_eval_func_name)
                sub_df[f'{property_name}_measured'] = property_eval_func(Chem.MolFromSmiles(smiles))

        new_df.append(sub_df)

    new_df = pd.DataFrame(new_df)
    return new_df

@dataclass
class DataCollatorForCausalLMEval(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    molecule_target_aug_prob: float
    molecule_start_str: str
    scaffold_aug_prob: float
    scaffold_start_str: str
    property_start_str: str
    property_inner_sep: str
    property_inter_sep: str
    end_str: str
    ignore_index: int
    has_scaffold: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        prop_token_map = {
            'qed': '<qed>',
            'logp': '<logp>',
            'sas': '<SAS>',
            'tpsa': '<TPSA>'
        }

        sources = []
        props_list = []
        non_normalized_props_list = []
        prop_names_list = []
        props_index_list = []
        temperature_list = []
        for example in instances:
            prop_names = example['property_name']
            prop_values = example['property_value']
            non_normalized_prop_values = example['non_normalized_property_value']
            temperature = example['temperature']

            scaffold_str = ""
            props = []
            non_normalized_props = []
            props_index = []

            if self.has_scaffold:
                scaffold = example['scaffold_smiles'].strip()
                scaffold_str = f"{self.scaffold_start_str}{scaffold}{self.end_str}"

            props_str = f"{self.property_start_str}"
            for i, prop in enumerate(prop_names):
                prop = prop.lower()
                props_str += f"{prop_token_map[prop]}{self.property_inner_sep}{self.molecule_start_str}{self.property_inter_sep}"
                props.append(prop_values[i])
                non_normalized_props.append(non_normalized_prop_values[i])
                # Index of the numeric-value slot for property i in the tokenized
                # prompt (fixed offsets matching the prompt template above).
                props_index.append(3 + 4 * i)
            props_str += f"{self.end_str}"

            source = props_str + scaffold_str + "<->>" + self.molecule_start_str

            sources.append(source)
            props_list.append(props)
            non_normalized_props_list.append(non_normalized_props)
            props_index_list.append(props_index)
            prop_names_list.append(prop_names)
            temperature_list.append(temperature)

        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )

        input_ids = [torch.tensor(tokenized_source)
                     for tokenized_source in tokenized_sources_with_prompt['input_ids']]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            'properties': props_list,
            'non_normalized_properties': non_normalized_props_list,
            'property_names': prop_names_list,
            'properties_index': props_index_list,
            'temperature': temperature_list,
        }

        return data_dict

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens=None,
):
    """Resize tokenizer and embedding.

    Note: this is the unoptimized version; the resulting embedding size may not
    be divisible by 64.
    """
    tokenizer.add_special_tokens(special_tokens_dict)
    if non_special_tokens is not None:
        tokenizer.add_tokens(non_special_tokens)

    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        # Initialize the new embeddings with the mean of the existing ones.
        input_embeddings_data = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")

class MolecularGenerationModel:
    def __init__(self):
        model_id = "ChemFM/molecular_cond_generation_guacamol"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            padding_side="right",
            use_fast=True,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        config = AutoConfig.from_pretrained(
            model_id,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        self.model = LlamaForCausalLMWithNumericalEmbedding.from_pretrained(
            model_id,
            config=config,
            device_map=device_map,
            trust_remote_code=True,
            token=os.environ.get("TOKEN")
        )

        # Add the pad token and resize the embeddings accordingly.
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.model
        )
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.model.eval()

        # Load the prompt-template strings shipped with the model repo.
        string_template_path = hf_hub_download(model_id, filename="string_template.json", token=os.environ.get("TOKEN"))
        with open(string_template_path, 'r') as f:
            string_template = json.load(f)
        molecule_start_str = string_template['MOLECULE_START_STRING']
        scaffold_start_str = string_template['SCAFFOLD_MOLECULE_START_STRING']
        property_start_str = string_template['PROPERTY_START_STRING']
        property_inner_sep = string_template['PROPERTY_INNER_SEP']
        property_inter_sep = string_template['PROPERTY_INTER_SEP']
        end_str = string_template['END_STRING']

        self.data_collator = DataCollatorForCausalLMEval(
            tokenizer=self.tokenizer,
            source_max_len=512,
            target_max_len=512,
            molecule_target_aug_prob=1.0,
            scaffold_aug_prob=0.0,
            molecule_start_str=molecule_start_str,
            scaffold_start_str=scaffold_start_str,
            property_start_str=property_start_str,
            property_inner_sep=property_inner_sep,
            property_inter_sep=property_inter_sep,
            end_str=end_str,
            ignore_index=-100,
            has_scaffold=False
        )

    def generate(self, loader):
        df = []
        pbar = tqdm(loader, desc="Evaluating...", leave=False)
        for batch in pbar:
            sub_df = dict()

            batch_size = batch['input_ids'].shape[0]
            assert batch_size == 1, "The batch size should be 1"

            temperature = batch['temperature'][0]
            property_names = batch['property_names'][0]
            non_normalized_properties = batch['non_normalized_properties'][0]

            num_generations = 1
            del batch['temperature']
            del batch['property_names']
            del batch['non_normalized_properties']

            batch['input_ids'] = batch['input_ids'].to(self.model.device)
            batch['attention_mask'] = batch['attention_mask'].to(self.model.device)

            input_length = batch['input_ids'].shape[1]
            steps = 1024 - input_length

            # Sample token by token until EOS or the 1024-token budget is hit.
            with torch.no_grad():
                early_stop_flags = torch.zeros(num_generations, dtype=torch.bool).to(self.model.device)
                for k in range(steps):
                    logits = self.model(**batch)['logits']
                    logits = logits[:, -1, :] / temperature
                    probs = F.softmax(logits, dim=-1)
                    ix = torch.multinomial(probs, num_samples=num_generations)

                    # Sequences that already emitted EOS keep emitting EOS.
                    ix[early_stop_flags] = self.tokenizer.eos_token_id

                    batch['input_ids'] = torch.cat([batch['input_ids'], ix], dim=-1)
                    # Keep the attention mask in sync with the growing sequence.
                    batch['attention_mask'] = torch.cat(
                        [batch['attention_mask'], torch.ones_like(ix, dtype=torch.bool)], dim=-1
                    )
                    early_stop_flags |= (ix.squeeze() == self.tokenizer.eos_token_id)

                    if torch.all(early_stop_flags):
                        break

            generations = self.tokenizer.batch_decode(batch['input_ids'][:, input_length:], skip_special_tokens=True)
            generations = map(lambda x: x.replace(" ", ""), generations)

            # Canonicalize each generation; invalid SMILES become "".
            predictions = []
            for generation in generations:
                try:
                    predictions.append(Chem.MolToSmiles(Chem.MolFromSmiles(generation)))
                except Exception:
                    predictions.append("")

            sub_df['SMILES'] = predictions[0]
            sub_df['property_names'] = property_names
            sub_df['property'] = batch['properties'][0]
            sub_df['non_normalized_properties'] = non_normalized_properties

            df.append(sub_df)

        df = pd.DataFrame(df)
        return df

    def predict_single_smiles(self, input_dict: Dict):
        # Normalize the property keys to lowercase.
        input_dict = {key.lower(): value for key, value in input_dict.items()}

        properties = list(input_dict.keys())
        property_means = [means[prop] for prop in properties]
        property_stds = [stds[prop] for prop in properties]

        # Standardize the requested property values with the dataset statistics.
        sample_point = [input_dict[prop] for prop in properties]
        non_normalized_sample_point = np.array(sample_point).reshape(-1)
        sample_point = (np.array(sample_point) - np.array(property_means)) / np.array(property_stds)
        sub_df = {
            "property_name": properties,
            "property_value": sample_point.tolist(),
            "temperature": 1.0,
            "non_normalized_property_value": non_normalized_sample_point.tolist()
        }

        # Sample 10 molecules for the same condition, one per batch.
        test_dataset = [sub_df] * 10
        test_dataset = pd.DataFrame(test_dataset)
        test_dataset = Dataset.from_pandas(test_dataset)

        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=self.data_collator)
        df = self.generate(test_loader)
        new_df = parse_df(df)

        new_df = new_df.drop(columns=[col for col in new_df.columns if "condition" in col])

        # Drop invalid generations (empty SMILES).
        new_df = new_df[new_df["SMILES"] != ""]

        new_df = new_df.round(2)

        return new_df
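
# Minimal usage sketch (not part of the original script): it assumes a CUDA
# device, the `metric_calculator` module, network access to the model repo,
# and a valid Hugging Face token in the TOKEN environment variable.
if __name__ == "__main__":
    model = MolecularGenerationModel()
    # Condition on two properties; keys are case-insensitive.
    result = model.predict_single_smiles({"QED": 0.8, "logP": 2.5})
    print(result)  # columns: SMILES, qed_measured, logp_measured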