import json
import math
import os
import re
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
from torch.utils.data import DataLoader

from ..extras.misc import get_logits_processor
from ..hparams import get_train_args
from ..model import GraphLLMForCausalMLM, load_tokenizer
from .dataset import MolQADataset
|
if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import (
        DataArguments,
        FinetuningArguments,
        GeneratingArguments,
        ModelArguments,
    )
|
|
def remove_extra_spaces(text: str) -> str:
    """Collapse runs of whitespace into single spaces and strip the ends."""
    cleaned_text = re.sub(r"\s+", " ", text)
    return cleaned_text.strip()
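
# A quick illustration of the helper above:
#   remove_extra_spaces("  Hello \n world\t! ")  ->  "Hello world !"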
|
|
def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
    """Parse evaluation arguments and dispatch to the dataset-specific runner."""
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )

    if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
        run_molqa(
            model_args, data_args, training_args, finetuning_args, generating_args
        )
    else:
        raise ValueError("Unknown dataset: {}.".format(data_args.dataset))
|
|
def run_molqa(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
):
    """Run molecular design followed by retrosynthesis planning on a MolQA split."""
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]
|
    # Resolve the dataset file from dataset_info.json and load the raw records.
    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
    with open(data_info_path, "r") as f:
        dataset_info = json.load(f)

    tokenizer.pad_token = tokenizer.eos_token
    dataset_name = data_args.dataset.strip()
    try:
        filename = dataset_info[dataset_name]["file_name"]
    except KeyError:
        raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
    data_path = os.path.join(data_args.dataset_dir, filename)
    with open(data_path, "r") as f:
        original_data = json.load(f)

    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
    dataloader = DataLoader(
        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
    )
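
    # Each record in the dataset file is expected to provide at least
    # "instruction" and "input" fields plus a property vector; this is
    # inferred from how the records are consumed below.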
|
    # Decoding settings: also stop on the tokenizer's additional special tokens.
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [
        tokenizer.eos_token_id
    ] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
    )

    all_results = []
    # Property columns expected in each record's property vector, in this order.
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]
|
    # Pass 1: molecular design. Generate a SMILES candidate (plus the free-text
    # response) for every record, keeping results aligned with the input order.
    global_idx = 0
    all_smiles = []
    model.eval()
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)
        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=False,
                rollback=True,
                **gen_kwargs,
            )

        batch_results = []
        for i in range(len(all_info_dict["smiles_list"])):
            original_data_idx = global_idx + i
            original_item = original_data[original_data_idx]

            llm_response = "".join(all_info_dict["text_lists"][i])
            result = {
                "qa_idx": original_data_idx,
                "instruction": original_item["instruction"],
                "input": original_item["input"],
                "llm_response": llm_response,
                "response_design": remove_extra_spaces(llm_response),
                "llm_smiles": all_info_dict["smiles_list"][i],
                "property": {},
            }

            # Keep only the properties actually specified for this record;
            # NaN marks an unconstrained property.
            for j, prop_name in enumerate(property_names):
                prop_value = property_data[i][j].item()
                if not math.isnan(prop_value):
                    result["property"][prop_name] = prop_value

            batch_results.append(result)

        all_results.extend(batch_results)
        all_smiles.extend([result["llm_smiles"] for result in batch_results])
        global_idx += len(batch_results)
|
    # Pass 2: retrosynthesis. Iterate over the same batches again and plan
    # synthesis routes for the SMILES produced in pass 1.
    retro_batch_start = 0
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        batch_size = input_ids.shape[0]
        batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]

        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_molecular_design=False,
                do_retrosynthesis=True,
                input_smiles_list=batch_smiles,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                **gen_kwargs,
            )

        for i in range(batch_size):
            result = all_results[retro_batch_start + i]
            retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
            result["llm_reactions"] = []
            if retro_plan["success"]:
                for reaction, template, cost in zip(
                    retro_plan["reaction_list"],
                    retro_plan["templates"],
                    retro_plan["cost"],
                ):
                    result["llm_reactions"].append(
                        {"reaction": reaction, "template": template, "cost": cost}
                    )

            # Generation can emit None fragments; warn and drop them before joining.
            text_list = all_info_dict["text_lists"][i]
            if None in text_list:
                print(f"List contains None: {text_list}")
            new_text = "".join(item for item in text_list if item is not None)

            result["llm_response"] = remove_extra_spaces(result["llm_response"] + new_text)
            result["response_retro"] = remove_extra_spaces(new_text)

        retro_batch_start += batch_size
|
    # Print a short summary of the first few results.
    print("\nSummary of results:")
    print_len = min(5, len(all_results))
    for result in all_results[:print_len]:
        print(f"\nData point {result['qa_idx']}:")
        print(f"  Instruction: {result['instruction']}")
        print(f"  Input: {result['input']}")
        print(f"  LLM Response: {result['llm_response']}")
        print(f"  LLM SMILES: {result['llm_smiles']}")
        print(f"  Number of reactions: {len(result['llm_reactions'])}")
        for prop_name, prop_value in result["property"].items():
            print(f"  {prop_name}: {prop_value}")

    print("\nAll data processed successfully.")
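

# Minimal usage sketch. Assumptions: running this module as a script is
# supported, and get_train_args() accepts a plain dict of arguments as the
# run_eval() signature suggests; a real run will need the full set of model,
# data, and training arguments (model_name_or_path, output_dir, etc.).
if __name__ == "__main__":
    run_eval(
        {
            "dataset": "molqa",  # one of: molqa, molqa_drug, molqa_material
            "dataset_dir": "data",  # directory holding dataset_info.json
            "per_device_eval_batch_size": 4,
        }
    )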