import csv
import torch
import torch.nn.functional as F  # used by the SimCSE and cross-entropy losses below
import random
import argparse
import numpy as np

from tqdm import tqdm
from types import SimpleNamespace
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score

from bert import BertModel
from optimizer import AdamW
from classifier import seed_everything, tokenizer
from classifier import SentimentDataset, BertSentimentClassifier


TQDM_DISABLE = False


class TwitterDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, sents):
        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask

    def collate_fn(self, sents):
        token_ids, attention_mask = self.pad_data(sents)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
        }

        return batched_data


def load_data(filename, flag='train'):
    '''
    Load data from a CSV file. Returns:
    - for the Twitter dataset: a list of sentences
    - for the SST/CFIMDB datasets: a list of (sent, label, sent_id) tuples,
      or (sent, sent_id) when flag='test' since test labels are withheld
    - when flag='train', additionally the number of distinct labels
    '''
    num_labels = set()
    data = []
    with open(filename, 'r') as fp:
        for record in csv.DictReader(fp, delimiter=','):
            if flag == 'twitter':
                sent = record['clean_text'].lower().strip()
                data.append(sent)
            elif flag == 'test':
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                data.append((sent, sent_id))
            else:
                sent = record['sentence'].lower().strip()
                sent_id = record['id'].lower().strip()
                label = int(record['sentiment'].strip())
                num_labels.add(label)
                data.append((sent, label, sent_id))
    print(f"Loaded {len(data)} examples from {filename}")

    if flag == 'train':
        return data, len(num_labels)
    else:
        return data
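

# Illustrative (hypothetical) values returned by load_data:
#   flag='twitter' -> 'just landed, what a view'
#   flag='test'    -> ('a thrilling ride', 'sst-test-17')
#   flag='train'   -> [('a thrilling ride', 3, 'sst-42'), ...], 5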


def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")


def train(args):
    '''
    Training pipeline
    -----------------
    1. Load the Twitter sentiment and SST datasets.
    2. Build DataLoaders with the configured batch sizes.
    3. Initialize the BertSentimentClassifier (including BERT).
    4. Loop over args.epochs epochs.
    5. Fine-tune minBERT with the unsupervised SimCSE loss.
    6. Fine-tune the classifier with the cross-entropy loss.
    7. Backpropagate with the AdamW optimizer in both stages.
    8. Evaluate the model on the dev set.
    9. If dev_acc > best_dev_acc: save_model(...).
    '''
    twitter_data = load_data(args.train_bert, 'twitter')
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    twitter_dataset = TwitterDataset(twitter_data, args)
    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    twitter_dataloader = DataLoader(twitter_dataset, shuffle=True, batch_size=args.batch_size_cse,
                                    num_workers=args.num_cpu_cores, collate_fn=twitter_dataset.collate_fn)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
                                  num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
                                num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)

    config = SimpleNamespace(
        hidden_dropout_prob=args.hidden_dropout_prob,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode='full-model'
    )

    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    model = BertSentimentClassifier(config)
    model = model.to(device)

    optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
    optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
    best_dev_acc = 0

    for epoch in range(args.epochs):
        # Step 5: fine-tune minBERT on unlabeled tweets with unsupervised SimCSE.
        model.bert.train()
        train_loss = num_batches = 0
        for batch in tqdm(twitter_dataloader, f'train-twitter-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids, b_mask = batch['token_ids'], batch['attention_mask']
            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            optimizer_cse.zero_grad()
            # The source breaks off after the encode() call; the rest of this
            # loop is a sketch of the unsupervised SimCSE step. Encoding the
            # same batch twice under dropout yields two different "views" of
            # each sentence; taking the first ([CLS]) position as the sentence
            # embedding is an assumption about encode()'s output shape.
            hidden_1 = model.bert.encode(model.bert.embed(b_ids), b_mask)[:, 0, :]
            hidden_2 = model.bert.encode(model.bert.embed(b_ids), b_mask)[:, 0, :]

            # InfoNCE loss over in-batch negatives; the 0.05 temperature is
            # the value from the SimCSE paper, not one fixed by this codebase.
            sim = F.cosine_similarity(hidden_1.unsqueeze(1), hidden_2.unsqueeze(0), dim=-1) / 0.05
            labels = torch.arange(sim.size(0), device=device)
            loss = F.cross_entropy(sim, labels)

            loss.backward()
            optimizer_cse.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches
        print(f"Epoch {epoch}: SimCSE train loss :: {train_loss:.3f}")
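        # Steps 6-9 of the docstring (classifier fine-tuning, dev evaluation,
        # checkpointing) are missing from the source; this sketch fills them
        # in. The 'labels' batch key and the model(b_ids, b_mask) call are
        # assumptions about SentimentDataset / BertSentimentClassifier in
        # classifier.py.
        model.train()
        for batch in tqdm(train_dataloader, f'train-sst-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)
            b_labels = batch['labels'].to(device)

            optimizer_classifier.zero_grad()
            logits = model(b_ids, b_mask)
            loss = F.cross_entropy(logits, b_labels.view(-1))
            loss.backward()
            optimizer_classifier.step()

        # Evaluate on the dev set and keep only the best checkpoint.
        model.eval()
        y_true, y_pred = [], []
        with torch.no_grad():
            for batch in tqdm(dev_dataloader, f'eval-dev-{epoch}', leave=False, disable=TQDM_DISABLE):
                b_ids = batch['token_ids'].to(device)
                b_mask = batch['attention_mask'].to(device)
                logits = model(b_ids, b_mask)
                y_pred.extend(logits.argmax(dim=-1).cpu().tolist())
                y_true.extend(batch['labels'].tolist())
        dev_acc = accuracy_score(y_true, y_pred)
        dev_f1 = f1_score(y_true, y_pred, average='macro')
        print(f"Epoch {epoch}: dev acc :: {dev_acc:.3f}, dev macro-F1 :: {dev_f1:.3f}")

        if dev_acc > best_dev_acc:
            best_dev_acc = dev_acc
            save_model(model, optimizer_classifier, args, config, args.filepath)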


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--num-cpu-cores", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--use_gpu", action='store_true')
    # Batch-size defaults follow the 'unsup' / 'sst' hints in the help strings.
    parser.add_argument("--batch_size_cse", help="'unsup': 64, 'sup': 512", type=int, default=64)
    parser.add_argument("--batch_size_classifier", help="'sst': 64, 'cfimdb': 8", type=int, default=64)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    # Without type=float, values passed on the command line would stay strings.
    parser.add_argument("--lr_cse", type=float, default=2e-5)
    parser.add_argument("--lr_classifier", type=float, default=1e-5)

    args = parser.parse_args()
    return args
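

# Example invocation (the script name is hypothetical; flags are defined in get_args):
#   python simcse_sst.py --use_gpu --batch_size_cse 64 --batch_size_classifier 64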


if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)
    torch.set_num_threads(args.num_cpu_cores)

    print('Finetuning minBERT with Unsupervised SimCSE...')
    config = SimpleNamespace(
        filepath='contrastive-nli.pt',
        lr_cse=args.lr_cse,
        lr_classifier=args.lr_classifier,
        hidden_dropout_prob=args.hidden_dropout_prob,
        num_cpu_cores=args.num_cpu_cores,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size_cse=args.batch_size_cse,
        batch_size_classifier=args.batch_size_classifier,
        train_bert='data/twitter-unsup.csv',
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv',
        # 'full-model' matches the fine_tune_mode hardcoded in train(); the
        # original referenced an undefined args.fine_tune_mode here.
        dev_out='predictions/full-model-sst-dev-out.csv',
        test_out='predictions/full-model-sst-test-out.csv'
    )

    train(config)