import logging
import os
import os.path as pt
import re
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AdamW,
    DataCollatorWithPadding,
    HfArgumentParser,
    T5Config,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

from .iupac_dataset_class import IUPACDataset

|
class T5Collator:
    """Pads variable-length token sequences so they can be stacked into a batch."""

    def __init__(self, pad_token_id):
        super().__init__()
        self.pad_token_id = pad_token_id

    def __call__(self, records):
        batch = {}
        # Keys that need per-key padding values; everything else is stacked as-is.
        padvals = {"input_ids": self.pad_token_id, "attention_mask": 0}
        for k in records[0]:
            if k in padvals:
                batch[k] = pad_sequence([torch.tensor(r[k]) for r in records],
                                        batch_first=True,
                                        padding_value=padvals[k])
            else:
                batch[k] = torch.LongTensor([r[k] for r in records])
        return batch
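

# Illustrative sketch (not part of the original pipeline): how T5Collator pads a toy
# batch. The pad id 0 and the record layout below are assumptions for demonstration;
# in practice the records come from IUPACDataset and the pad id from the tokenizer.
def _collator_usage_sketch():
    collator = T5Collator(pad_token_id=0)
    records = [
        {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": 2},
        {"input_ids": [8, 9], "attention_mask": [1, 1], "labels": 1},
    ]
    batch = collator(records)
    # input_ids / attention_mask are padded to the longest sequence in the batch;
    # any other key (here "labels") is simply stacked into a LongTensor.
    return batch["input_ids"].shape, batch["labels"]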
|

class T5IUPACTokenizer(T5Tokenizer):
    """T5 tokenizer specialised for IUPAC names: spaces are encoded as underscores."""

    def prepare_for_tokenization(self, text, is_split_into_words=False,
                                 **kwargs):
        # Encode spaces as underscores before SentencePiece sees the text.
        return re.sub(" ", "_", text), kwargs

    def _decode(self, *args, **kwargs):
        text = super()._decode(*args, **kwargs)
        # Protect the underscores inside sentinel names (extra_id_N) while
        # converting the remaining underscores back into spaces.
        text = re.sub("extra_id_", "extraAidA", text)
        text = re.sub("_", " ", text)
        text = re.sub("extraAidA", "extra_id_", text)
        return text

    def sentinels(self, sentinel_ids):
        # Sentinel i lives at vocab id (vocab_size - i - 1), at the top of the vocab.
        return self.vocab_size - sentinel_ids - 1

    def sentinel_mask(self, ids):
        # True for ids that fall inside the reserved sentinel range.
        return ((self.vocab_size - self._extra_ids <= ids) &
                (ids < self.vocab_size))

    def _tokenize(self, text, sample=False):
        pieces = super()._tokenize(text)
        # Drop the leading piece SentencePiece inserts at the start of the text.
        return pieces[1:]
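

# Illustrative sketch (assumptions: `tok` is an already-constructed T5IUPACTokenizer
# whose vocab reserves at least three sentinel / extra-id slots). It only exercises
# the two sentinel helpers defined above.
def _sentinel_sketch(tok):
    ids = tok.sentinels(torch.arange(3))   # vocab ids reserved for sentinels 0..2
    mask = tok.sentinel_mask(ids)          # -> tensor([True, True, True])
    return ids, mask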
|

def prepare_input_class(data, device):
    """Recursively move tensors in a (possibly nested) batch to `device` as int64."""
    if isinstance(data, Mapping):
        return type(data)({k: prepare_input_class(v, device) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        return type(data)(prepare_input_class(v, device) for v in data)
    elif isinstance(data, torch.Tensor):
        kwargs = dict(device=device)
        if data.dtype != torch.int64:
            # The model expects LongTensor token ids, so cast anything else to int64.
            kwargs.update(dict(dtype=torch.int64))
        return data.to(**kwargs)
    return data
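

# Illustrative sketch: the helper walks nested dicts/lists, moves every tensor to the
# requested device, and up-casts non-int64 tensors to int64. The toy batch below is
# an assumption for demonstration only.
def _prepare_input_sketch():
    batch = {"input_ids": torch.tensor([[1, 2]], dtype=torch.int32),
             "attention_mask": torch.tensor([[1, 1]], dtype=torch.int64)}
    out = prepare_input_class(batch, torch.device("cpu"))
    return out["input_ids"].dtype  # torch.int64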
|

def get_data_loader_class(is_train=1):
    full_path = '/root/autodl-tmp/wjm/iupac-gpt/iupac_gpt/'

    # Build the IUPAC tokenizer from the SentencePiece model shipped with the data.
    iupac_tokenizer = T5IUPACTokenizer(vocab_file=full_path + 'iupac_spm.model')
    iupac_vocab_size = iupac_tokenizer.vocab_size
    print('iupac_vocab_size:', iupac_vocab_size)

    if is_train:
        # Training run: persist the tokenizer so fine-tuning can reload the same vocab.
        torch.save(iupac_tokenizer, pt.join(full_path, "real_iupac_tokenizer.pt"))
        print("training...", len(iupac_tokenizer))
    else:
        # Fine-tuning run: reload the tokenizer saved during training.
        iupac_tokenizer = torch.load(pt.join(full_path, "real_iupac_tokenizer.pt"),
                                     map_location="cpu")
        print('fine_tune...', len(iupac_tokenizer))

    dataset_filename = 'iupacs_logp.csv'
    target_col = "LogP"
    iupac_name_col = 'iupac'
    MAXLEN = 1024
    dataset_kwargs = {
        "dataset_dir": full_path,
        "dataset_filename": dataset_filename,
        "tokenizer": iupac_tokenizer,
        "max_length": MAXLEN,
        "target_col": target_col,
        "dataset_size": None,
        "iupac_name_col": iupac_name_col,
    }
    train_dataset = IUPACDataset(**dataset_kwargs)

    collator = T5Collator(iupac_tokenizer.pad_token_id)
    train_dataloader = DataLoader(train_dataset, batch_size=64,
                                  collate_fn=collator, shuffle=True)

    return train_dataloader, iupac_tokenizer
|

if __name__ == "__main__":
    train_dataloader, iupac_tokenizer = get_data_loader_class(is_train=1)
    pbar = tqdm(train_dataloader)
    device = 'cpu'

    # Iterate one epoch just to sanity-check the shapes of the collated batches.
    for inputs in pbar:
        src_label = inputs["labels"].to(device)
        inputs = prepare_input_class(inputs, device)
        src = inputs["input_ids"].to(device)

        print(src[:, :].shape, src_label)
|
|