File size: 2,480 Bytes
13362e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright 2024 Llamole Team
#
# This code is inspired by HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.utils.data import Dataset
from ..extras.constants import BOND_INDEX
def dict_to_list(data_dict, mol_properties):
    """Collect the values of ``data_dict`` in the order given by ``mol_properties``.

    Args:
        data_dict: Mapping looked up with ``.get`` for each property name.
        mol_properties: Iterable of property names; defines the output order.

    Returns:
        A list with one entry per name in ``mol_properties``. Names that are
        not present in ``data_dict`` are filled with ``float("nan")``.
    """
    missing = float("nan")
    values = []
    for name in mol_properties:
        values.append(data_dict.get(name, missing))
    return values
class MolQADataset(Dataset):
    """Dataset that turns molecule QA records into tokenized chat prompts.

    Each record in ``data`` is read by key: ``"instruction"``, ``"input"`` and
    ``"property"``. The instruction and input are joined into a single user
    message, rendered through the tokenizer's chat template, and tokenized to
    a fixed length of ``max_len``.
    """

    # Property names, in the order they appear in the "property" tensor.
    # Kept at class level so the sequence is built once, not on every lookup.
    MOL_PROPERTIES = (
        "BBBP",
        "HIV",
        "BACE",
        "CO2",
        "N2",
        "O2",
        "FFV",
        "TC",
        "SC",
        "SA",
    )

    def __init__(self, data, tokenizer, max_len):
        """Store the records and tokenization settings.

        Args:
            data: Indexable collection of records supporting ``len()``.
            tokenizer: Object providing ``apply_chat_template`` and callable
                with the keyword arguments used in ``__getitem__``
                (presumably a HuggingFace tokenizer -- verify against caller).
            max_len: Length every sample is padded or truncated to.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the model inputs for the record at ``idx``.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` (the
            tokenizer outputs with the leading batch dimension removed) and
            ``"property"`` (one value per entry of ``MOL_PROPERTIES``, NaN
            where the record has no value for that property).
        """
        item = self.data[idx]
        instruction = item["instruction"]
        input_text = item["input"]

        # Pin the dtype: without it torch infers int64 for a record whose
        # properties are all present and all integers, while records with any
        # NaN or float value come out as floats. The default float dtype is
        # what torch already picks for float input, so those are unchanged.
        property_data = torch.tensor(
            dict_to_list(item["property"], self.MOL_PROPERTIES),
            dtype=torch.get_default_dtype(),
        )

        # Combine instruction and input
        combined_input = f"{instruction}\n{input_text}"

        # Create messages for chat template
        messages = [
            {"role": "user", "content": combined_input}
        ]

        # Apply chat template
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the chat text
        encoding = self.tokenizer(
            chat_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse the sequence axis when max_len == 1 and return a 0-d tensor.
        # Assumes the tokenizer returns (1, seq_len) tensors for one string.
        return {
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "property": property_data,
        }