'''
This module contains our Dataset classes and functions that load the three datasets
for training and evaluating multitask BERT.

Feel free to edit code in this file if you wish to modify the way in which the data
examples are preprocessed.
'''

import csv

import torch
from torch.utils.data import Dataset
from tokenizer import BertTokenizer


def preprocess_string(s):
    return ' '.join(s.lower()
                    .replace('.', ' .')
                    .replace('?', ' ?')
                    .replace(',', ' ,')
                    .replace('\'', ' \'')
                    .split())
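
# Illustrative example (not part of the original preprocessing code):
# preprocess_string("Hello, world. How are you?") returns
# "hello , world . how are you ?" -- the text is lowercased, punctuation is
# split off into separate tokens, and whitespace is normalized.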


class SentenceClassificationDataset(Dataset):
    """Dataset of (sentence, label, sent_id) examples for single-sentence classification."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        labels = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])
        labels = torch.LongTensor(labels)

        return token_ids, attention_mask, labels, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data
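
# A minimal usage sketch (assumed, not part of the original file): wrap the dataset
# in a torch DataLoader and pass the dataset's own collate_fn. `sst_train_data`,
# `args`, and the batch size are illustrative placeholders.
#
#     from torch.utils.data import DataLoader
#     sst_train_dataset = SentenceClassificationDataset(sst_train_data, args)
#     sst_train_dataloader = DataLoader(sst_train_dataset, shuffle=True, batch_size=8,
#                                       collate_fn=sst_train_dataset.collate_fn)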


class SentenceClassificationTestDataset(Dataset):
    """Dataset of unlabeled (sentence, sent_id) examples for single-sentence classification."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = self.tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sents, sent_ids

    def collate_fn(self, all_data):
        token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sents': sents,
            'sent_ids': sent_ids
        }

        return batched_data


class SentencePairDataset(Dataset):
    """Dataset of (sentence1, sentence2, label, sent_id) examples for sentence-pair tasks.

    Labels are treated as floats when isRegression is True and as class indices otherwise.
    """

    def __init__(self, dataset, args, isRegression=False):
        self.dataset = dataset
        self.p = args
        self.isRegression = isRegression
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        labels = [x[2] for x in data]
        sent_ids = [x[3] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        if self.isRegression:
            labels = torch.DoubleTensor(labels)
        else:
            labels = torch.LongTensor(labels)

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                labels, sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         labels, sent_ids) = self.pad_data(all_data)

        batched_data = {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'labels': labels,
            'sent_ids': sent_ids
        }

        return batched_data
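
# Illustrative usage (assumed, not part of the original file): the regression flag
# controls the label dtype, e.g. for real-valued similarity scores
#
#     sts_train_dataset = SentencePairDataset(sts_train_data, args, isRegression=True)
#
# which makes pad_data return labels as a DoubleTensor instead of a LongTensor.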


class SentencePairTestDataset(Dataset):
    """Dataset of unlabeled (sentence1, sentence2, sent_id) examples for sentence-pair tasks."""

    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sent1 = [x[0] for x in data]
        sent2 = [x[1] for x in data]
        sent_ids = [x[2] for x in data]

        encoding1 = self.tokenizer(sent1, return_tensors='pt', padding=True, truncation=True)
        encoding2 = self.tokenizer(sent2, return_tensors='pt', padding=True, truncation=True)

        token_ids = torch.LongTensor(encoding1['input_ids'])
        attention_mask = torch.LongTensor(encoding1['attention_mask'])
        token_type_ids = torch.LongTensor(encoding1['token_type_ids'])

        token_ids2 = torch.LongTensor(encoding2['input_ids'])
        attention_mask2 = torch.LongTensor(encoding2['attention_mask'])
        token_type_ids2 = torch.LongTensor(encoding2['token_type_ids'])

        return (token_ids, token_type_ids, attention_mask,
                token_ids2, token_type_ids2, attention_mask2,
                sent_ids)

    def collate_fn(self, all_data):
        (token_ids, token_type_ids, attention_mask,
         token_ids2, token_type_ids2, attention_mask2,
         sent_ids) = self.pad_data(all_data)

        batched_data = {
            'token_ids_1': token_ids,
            'token_type_ids_1': token_type_ids,
            'attention_mask_1': attention_mask,
            'token_ids_2': token_ids2,
            'token_type_ids_2': token_type_ids2,
            'attention_mask_2': attention_mask2,
            'sent_ids': sent_ids
        }

        return batched_data


def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_filename, split='train'):
    sentiment_data = []
    num_labels = {}
    if split == 'test':
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                sentiment_data.append((sent, sent_id))
    else:
        with open(sentiment_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                if label not in num_labels:
                    num_labels[label] = len(num_labels)
                sentiment_data.append((sent, label, sent_id))

    print(f"Loaded {len(sentiment_data)} {split} examples from {sentiment_filename}")

    paraphrase_data = []
    if split == 'test':
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                paraphrase_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))
    else:
        with open(paraphrase_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                try:
                    sent_id = record['id'].lower().strip()
                    paraphrase_data.append((preprocess_string(record['sentence1']),
                                            preprocess_string(record['sentence2']),
                                            int(float(record['is_duplicate'])), sent_id))
                except:
                    # Skip malformed records that are missing fields or have unparsable labels.
                    pass

    print(f"Loaded {len(paraphrase_data)} {split} examples from {paraphrase_filename}")

    similarity_data = []
    if split == 'test':
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        sent_id))
    else:
        with open(similarity_filename, 'r') as fp:
            for record in csv.DictReader(fp, delimiter='\t'):
                sent_id = record['id'].lower().strip()
                similarity_data.append((preprocess_string(record['sentence1']),
                                        preprocess_string(record['sentence2']),
                                        float(record['similarity']), sent_id))

    print(f"Loaded {len(similarity_data)} {split} examples from {similarity_filename}")

    return sentiment_data, num_labels, paraphrase_data, similarity_data
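
# A usage sketch (assumed; the file paths below are illustrative placeholders):
#
#     sst_data, num_labels, para_data, sts_data = load_multitask_data(
#         'data/sst-train.csv', 'data/quora-train.csv', 'data/sts-train.csv',
#         split='train')
#
# For the train/dev splits, sentiment examples are (sent, label, sent_id),
# paraphrase examples are (sent1, sent2, is_duplicate, sent_id), and similarity
# examples are (sent1, sent2, similarity, sent_id).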