|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
from functools import partial |
|
from typing import TYPE_CHECKING, Any, Dict, List, Union, Tuple |
|
|
|
from datasets import Features |
|
|
|
from ..extras.logging import get_logger |
|
from .data_utils import Role |
|
|
|
if TYPE_CHECKING: |
|
from datasets import Dataset, IterableDataset |
|
from transformers import Seq2SeqTrainingArguments |
|
|
|
from ..hparams import DataArguments |
|
from .parser import DatasetAttr |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
def extract_all_smiles(text):
    """Collect the text found between each ``<mol_start>`` / ``<mol_end>`` pair.

    Args:
        text: String to scan.

    Returns:
        A list with one entry per tag pair, in order of appearance. Each
        entry is the (possibly empty) text enclosed by the pair. The match
        is non-greedy and does not cross line breaks.
    """
    return [hit.group(1) for hit in re.finditer(r'<mol_start>(.*?)<mol_end>', text)]
|
|
|
def replace_all_smiles(text):
    """Swap every ``<mol_start>...<mol_end>`` span for a ``<molecule>`` token.

    Args:
        text: String to rewrite.

    Returns:
        A copy of ``text`` in which each tagged span (tags included) has been
        replaced by the literal ``<molecule>``.
    """
    tagged_span = re.compile(r'<mol_start>.*?<mol_end>')
    return tagged_span.sub('<molecule>', text)
|
|
|
def replace_smiles_with_callback(text):
    """Replace tagged SMILES spans with ``<molecule>`` placeholders.

    Two passes are applied in order:

    1. A span that directly follows ``<design_start><design_end>`` becomes
       ``<molecule>`` followed by the original SMILES wrapped in
       ``<rollback_start>`` / ``<rollback_end>``; the design tags are kept.
    2. Every remaining ``<mol_start>...<mol_end>`` span becomes a bare
       ``<molecule>``.

    Args:
        text: String to rewrite.

    Returns:
        The rewritten string.
    """
    # Group 1 is the design tag pair, group 2 the SMILES; both are re-inserted
    # verbatim through the group references in the replacement template.
    designed = re.compile(r'(<design_start><design_end>)<mol_start>(.*?)<mol_end>')
    after_design_pass = designed.sub(
        r'\g<1><molecule><rollback_start>\g<2><rollback_end>', text
    )

    leftover = re.compile(r'<mol_start>.*?<mol_end>')
    return leftover.sub('<molecule>', after_design_pass)
|
|
|
def dict_to_list(data_dict, mol_properties):
    """Look up each property name in ``data_dict`` and return the values as a list.

    Args:
        data_dict: Mapping from property name to value.
        mol_properties: Property names, in the order the output should follow.

    Returns:
        A list aligned with ``mol_properties``; names missing from
        ``data_dict`` yield ``None``.
    """
    values = []
    for name in mol_properties:
        values.append(data_dict.get(name))
    return values
|
|
|
def insert_bodies(text, num_insertions, retro_labels):
    """Fill design and retrosynthesis tag pairs with repeated body tokens.

    Every ``<design_start>...<design_end>`` span is rewritten so that it
    contains ``num_insertions`` copies of ``<design_body>``.

    The text is then scanned for retrosynthesis steps (a sentence starting
    with "This is step N in the retrosynthesis process." that is followed by
    a ``<retro_start>...<retro_end>`` span). The i-th step found is paired
    with ``retro_labels[i]``:

    * if that label exists and is not ``None``, the step's retro span is
      rewritten to hold ``num_insertions`` copies of ``<retro_body>`` and the
      text between ``<retro_end>`` and the next step (or the end of the
      string) is not emitted;
      NOTE(review): dropping that trailing text looks intentional -- confirm.
    * otherwise the step is copied through unchanged.

    Args:
        text: String to rewrite.
        num_insertions: Number of body tokens placed inside each rewritten
            tag pair.
        retro_labels: Sequence of per-step labels, indexed by step order.

    Returns:
        The rewritten string.
    """
    retro_span = r'<retro_start>.*?<retro_end>'
    step_regex = r'(This is step \d+ in the retrosynthesis process\..*?<retro_start>.*?<retro_end>)(.*?)(?=This is step \d+|$)'

    def filled(opening, body, closing):
        # Tag pair wrapped around `num_insertions` copies of the body token.
        return opening + body * num_insertions + closing

    def rewrite_step(step, labelled):
        head, tail = step.group(1), step.group(2)
        if labelled and re.search(retro_span, head):
            # Only the head is returned here, so `tail` is left out of the result.
            return re.sub(retro_span, filled('<retro_start>', '<retro_body>', '<retro_end>'), head)
        return head + tail

    text = re.sub(
        r'<design_start>(.*?)<design_end>',
        lambda _match: filled('<design_start>', '<design_body>', '<design_end>'),
        text,
    )

    pieces = []
    cursor = 0
    for index, step in enumerate(re.finditer(step_regex, text)):
        labelled = index < len(retro_labels) and retro_labels[index] is not None
        pieces.append(text[cursor:step.start()])
        pieces.append(rewrite_step(step, labelled))
        cursor = step.end()
    pieces.append(text[cursor:])

    return "".join(pieces)
|
|
|
def extract_retro_products(text):
    """Collect the text found between each ``<retro_end>`` and the next ``>>``.

    Args:
        text: String to scan.

    Returns:
        A list of the captured fragments, in order of appearance, each with
        surrounding whitespace stripped.
    """
    products = []
    for hit in re.finditer(r'<retro_end>(.*?)>>', text):
        products.append(hit.group(1).strip())
    return products
|
|
|
def convert_molqa(
    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    r"""
    Converts a batch of molqa format examples to the standard format.

    For every example the function builds the chat prompt (history turns
    followed by the current user turn), rewrites the response so that tagged
    SMILES and design/retro spans are replaced by placeholder tokens, and
    collects the molecule-related side information.

    Args:
        examples: Batch of columns, keyed by column name. The column names
            are taken from ``dataset_attr`` (``prompt``, ``query``,
            ``history``, ``response``, ``system``, ``retro``, ``property``).
        dataset_attr: Describes which columns of ``examples`` to read.
        data_args: Supplies ``learned_query_size``, the number of body tokens
            inserted into each design/retro span.

    Returns:
        A dict of equally long lists with the keys ``prompt``, ``response``,
        ``system``, ``molecules``, ``property``, ``retro_labels`` and
        ``retro_products``. For an example without a string response,
        ``response``, ``molecules``, ``retro_labels`` and ``retro_products``
        are empty lists.
    """
    outputs = {"prompt": [], "response": [], "system": [], "molecules": [], "property": [], "retro_labels": [], "retro_products": []}

    mol_properties = ['BBBP', 'HIV', 'BACE', 'CO2', 'N2', 'O2', 'FFV', 'TC', 'SC', 'SA']
    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})

        # Reset the per-example molecule data on every iteration. Without
        # this, an example that has no string response would either hit an
        # unbound local (first example) or silently inherit the values
        # computed for the previous example.
        response = []
        smiles_list = []
        retro_labels = []
        retro_products = []

        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
            current_response = examples[dataset_attr.response][i]
            # SMILES and retro products are read from the raw response,
            # before the tagged spans are replaced by placeholders.
            smiles_list = extract_all_smiles(current_response)
            retro_products = extract_retro_products(current_response)
            retro_labels = examples[dataset_attr.retro][i] if dataset_attr.retro else []

            modified_response = replace_smiles_with_callback(current_response)
            modified_response = insert_bodies(modified_response, data_args.learned_query_size, retro_labels)

            response = [{"role": Role.ASSISTANT.value, "content": modified_response}]

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["molecules"].append(smiles_list)
        outputs["property"].append(dict_to_list(examples[dataset_attr.property][i], mol_properties))
        outputs["retro_labels"].append(retro_labels)
        outputs["retro_products"].append(retro_products)

    return outputs
|
|
|
def map_smiles_to_id(example, smiles_to_id):
    """Replace the SMILES strings in ``example['molecules']`` with their IDs.

    Args:
        example: Mapping with a ``'molecules'`` entry holding SMILES strings.
            It is modified in place.
        smiles_to_id: Lookup table from SMILES string to ID.

    Returns:
        The same ``example`` object, with ``'molecules'`` now holding IDs.

    Raises:
        KeyError: If a SMILES string is not present in ``smiles_to_id``.
    """
    molecule_ids = []
    for smiles in example['molecules']:
        molecule_ids.append(smiles_to_id[smiles])
    example['molecules'] = molecule_ids
    return example
|
|
|
def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
    r"""
    Converts a molqa dataset to the aligned format and replaces SMILES strings by integer IDs.

    Steps performed:
        1. Map ``convert_molqa`` over the dataset in batches, dropping the
           columns 'instruction', 'input', 'output', 'property' and 'retro'.
        2. Gather every distinct string found in the 'molecules' and
           'retro_products' fields of the converted dataset.
        3. Number those strings 0, 1, 2, ... in sorted order.
        4. Map over the dataset again, replacing the strings in 'molecules'
           and 'retro_products' by their numbers.

    Fields of the dataset after step 1:
        prompt: [{"role": ..., "content": ...}]
        response: [{"role": ..., "content": ...}]
        system: string
        molecules: list of strings
        property: list of floats
        retro_labels: list of int32
        retro_products: list of strings

    Args:
        dataset (Union["Dataset", "IterableDataset"]): The input dataset.
        dataset_attr (DatasetAttr): Attributes of the dataset; its
            ``formatting`` must be "molqa".
        data_args (DataArguments): Provides ``preprocessing_num_workers`` and
            ``overwrite_cache`` for the first mapping pass, and is forwarded
            to ``convert_molqa``.
        training_args (Seq2SeqTrainingArguments): Provides
            ``local_process_index``; cached results are bypassed only when
            ``overwrite_cache`` is set and this index is 0.

    Returns:
        Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
            - The converted dataset with IDs in 'molecules' and 'retro_products'.
            - A dictionary mapping each ID back to its SMILES string.
    """
    assert dataset_attr.formatting == "molqa"

    def value_spec(dtype):
        # Built fresh on every call so no two schema fields share a dict object.
        return {"dtype": dtype, "_type": "Value"}

    def message_list_spec():
        return [{"role": value_spec("string"), "content": value_spec("string")}]

    schema = {
        "prompt": message_list_spec(),
        "response": message_list_spec(),
        "system": value_spec("string"),
        "molecules": [value_spec("string")],
        "property": [value_spec("float")],
        "retro_labels": [value_spec("int32")],
        "retro_products": [value_spec("string")],
    }
    features = Features.from_dict(schema)

    reuse_cache = (not data_args.overwrite_cache) or (training_args.local_process_index != 0)
    conversion_options = {
        "batched": True,
        "remove_columns": ['instruction', 'input', 'output', 'property', 'retro'],
        "features": features,
        "num_proc": data_args.preprocessing_num_workers,
        "load_from_cache_file": reuse_cache,
        "desc": "Converting molqa format of dataset",
    }
    aligned = dataset.map(
        partial(convert_molqa, dataset_attr=dataset_attr, data_args=data_args),
        **conversion_options,
    )

    smiles_columns = ("molecules", "retro_products")

    # One pass over the converted dataset to gather the SMILES vocabulary.
    unique_smiles = {
        smiles
        for row in aligned
        for column in smiles_columns
        for smiles in row[column]
    }

    # Sorting makes the ID assignment independent of set iteration order.
    ordered_smiles = sorted(unique_smiles)
    smiles_to_id = {smiles: idx for idx, smiles in enumerate(ordered_smiles)}
    id_to_smiles = dict(enumerate(ordered_smiles))

    def replace_smiles_by_ids(example):
        for column in smiles_columns:
            example[column] = [smiles_to_id[smiles] for smiles in example[column]]
        return example

    aligned = aligned.map(replace_smiles_by_ids, desc="Mapping SMILES to molecule IDs")

    return aligned, id_to_smiles