File size: 1,448 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch_geometric.data import Dataset
import os
from torch_geometric.data import InMemoryDataset
import random
import json
from .data_utils import reformat_smiles

class ChEBI_dataset(Dataset):
    """Molecule/text pairs read from a tab-separated ChEBI split file.

    Each split lives in ``{root}/{mode}.txt``. The first line of that file is
    skipped (treated as a header); every remaining non-blank line is split on
    its first two tabs into ``(cid, smiles, text)``, so the text field may
    itself contain tabs.

    Args:
        root: Directory holding the split files and, when ``use_graph`` is
            true, ``cid_graph_map.pt``.
        mode: Split name; selects the file ``{root}/{mode}.txt``.
        smi_max_len: Maximum number of characters of the reformatted SMILES
            string kept inside the prompt.
        use_graph: When true, load ``cid_graph_map.pt`` and return the graph
            mapped to each sample's cid.
        disable_graph_cache: Accepted for interface compatibility; not used by
            this class.
        smiles_type: Forwarded to ``reformat_smiles`` as ``smiles_type``.
    """

    def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
        super(ChEBI_dataset, self).__init__(root)
        # Store the caller's ``root`` as given, after the base constructor ran.
        self.root = root
        self.file_path = os.path.join(root, f'{mode}.txt')
        self.smi_max_len = smi_max_len
        # Not used by the methods below; presumably assigned by callers — verify.
        self.tokenizer = None
        self.use_graph = use_graph
        self.smiles_type = smiles_type
        if self.use_graph:
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # trusted files. Looked up by the cid column in __getitem__.
            self.idx_graph_map = torch.load(os.path.join(root, 'cid_graph_map.pt'))
        # Explicit UTF-8 so description text does not depend on the locale.
        with open(self.file_path, encoding='utf-8') as f:
            next(f, None)  # skip the header line
            # Blank lines (e.g. a trailing newline) cannot be unpacked into
            # (cid, smiles, text) and would crash __getitem__, so drop them.
            self.data = [line.split('\t', maxsplit=2) for line in f if line.strip()]

    def get(self, index):
        """PyG ``Dataset`` hook; delegates to ``__getitem__``."""
        return self.__getitem__(index)

    def len(self):
        """PyG ``Dataset`` hook; number of loaded samples."""
        return len(self.data)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(index, graph_list, text, smiles_prompt)`` for one sample.

        ``graph_list`` holds the single graph mapped to the sample's cid when
        ``use_graph`` is true, and is an empty list otherwise. ``text`` is the
        stripped description followed by a newline. ``smiles_prompt`` wraps
        the reformatted SMILES, truncated to ``smi_max_len`` characters, in
        ``[START_I_SMILES]`` / ``[END_I_SMILES]`` markers.
        """
        cid, smiles, text = self.data[index]
        smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
        smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
        text = text.strip() + '\n'
        # Always bind graph_list so use_graph=False yields an empty list
        # instead of an UnboundLocalError.
        graph_list = [self.idx_graph_map[cid]] if self.use_graph else []
        return index, graph_list, text, smiles_prompt