|
import csv |
|
import torch |
|
import random |
|
import argparse |
|
import numpy as np |
|
import pandas as pd |
|
import torch.nn.functional as F |
|
|
|
from tqdm import tqdm |
|
from torch import Tensor |
|
from types import SimpleNamespace |
|
from torch.utils.data import Dataset, DataLoader |
|
from sklearn.metrics import f1_score, accuracy_score |
|
|
|
from bert import BertModel |
|
from optimizer import AdamW |
|
from classifier import seed_everything, tokenizer |
|
from classifier import SentimentDataset, BertSentimentClassifier |
|
|
|
|
|
TQDM_DISABLE = False |
|
|
|
|
|
class AmazonDataset(Dataset): |
|
def __init__(self, dataset, args): |
|
self.dataset = dataset |
|
self.p = args |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
sent_ids = [x[1] for x in data] |
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
        attention_mask = torch.LongTensor(encoding['attention_mask'])
|
|
|
        return token_ids, attention_mask, sent_ids
|
|
|
def collate_fn(self, data): |
|
token_ids, attention_mask, sent_ids = self.pad_data(data) |
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
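
    # Note: items are kept as raw (sentence, sent_id) tuples until collate time;
    # collate_fn tokenizes and pads each batch dynamically, so no fixed maximum
    # length is imposed when the dataset is built.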
|
|
|
|
|
def load_data(filename, flag='train'): |
|
''' |
|
- for amazon dataset: list of (sent, sent_id) |
|
- for test dataset: list of (sent, sent_id) |
|
- for train dataset: list of (sent, label, sent_id) |
|
''' |
|
|
|
if flag == 'amazon': |
|
df = pd.read_parquet(filename) |
|
data = list(zip(df['content'], df.index)) |
|
else: |
|
data, num_labels = [], set() |
|
|
|
with open(filename, 'r') as fp: |
|
if flag == 'test': |
|
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    data.append((sent, sent_id))
|
else: |
|
                for record in csv.DictReader(fp, delimiter='\t'):
|
sent = record['sentence'].lower().strip() |
|
sent_id = record['id'].lower().strip() |
|
label = int(record['sentiment'].strip()) |
|
num_labels.add(label) |
|
data.append((sent, label, sent_id)) |
|
|
|
print(f"load {len(data)} data from {filename}") |
|
if flag in ['test', 'amazon']: |
|
return data |
|
else: |
|
return data, len(num_labels) |
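
# load_data assumes tab-separated files with a header row containing 'id',
# 'sentence' and (for labelled splits) 'sentiment' columns, e.g. (illustrative values):
#
#   id        sentence                                   sentiment
#   7f3a91    a gorgeous , witty , seductive movie .     4
#
# while the 'amazon' flag expects a parquet file exposing a 'content' column.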
|
|
|
|
|
def save_model(model, optimizer, args, config, filepath): |
|
save_info = { |
|
'model': model.state_dict(), |
|
'optim': optimizer.state_dict(), |
|
'args': args, |
|
'model_config': config, |
|
'system_rng': random.getstate(), |
|
'numpy_rng': np.random.get_state(), |
|
'torch_rng': torch.random.get_rng_state(), |
|
} |
|
|
|
torch.save(save_info, filepath) |
|
print(f"save the model to {filepath}") |
|
|
|
|
|
def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05): |
|
''' |
|
embeds_1: [batch_size, hidden_size] |
|
embeds_2: [batch_size, hidden_size] |
|
''' |
|
|
|
|
|
sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp |
|
|
|
|
|
positive_sim = torch.diagonal(sim_matrix) |
|
|
|
|
|
nume = torch.exp(positive_sim) |
|
|
|
|
|
deno = torch.exp(sim_matrix).sum(1) |
|
|
|
|
|
loss_per_batch = -torch.log(nume / deno) |
|
|
|
return loss_per_batch.mean() |
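
# For reference, contrastive_loss implements the unsupervised SimCSE (InfoNCE)
# objective: for the i-th sentence,
#   loss_i = -log( exp(sim(e1_i, e2_i) / temp) / sum_j exp(sim(e1_i, e2_j) / temp) )
# where sim is cosine similarity and the two embeddings of the same sentence come
# from two dropout-perturbed forward passes. A minimal sanity check with random
# tensors (illustrative only, not real BERT pooler outputs):
#
#   e1, e2 = torch.randn(4, 768), torch.randn(4, 768)
#   assert contrastive_loss(e1, e2).dim() == 0   # scalar loss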
|
|
|
|
|
def train(args): |
|
    '''
    Training Pipeline
    -----------------
    1. Load the Amazon Polarity and SST datasets.
    2. Build DataLoaders with the configured batch sizes.
    3. Initialize BertSentimentClassifier (including minBERT).
    4. Loop over args.epochs epochs.
    5. Fine-tune minBERT with the unsupervised SimCSE loss.
    6. Fine-tune the classifier with cross-entropy loss.
    7. Backpropagate through both with the AdamW optimizer.
    8. Evaluate the model on the dev set.
    9. If dev_acc > best_dev_acc: save_model(...).
    '''
|
|
|
amazon_data = load_data(args.train_bert, 'amazon') |
|
train_data, num_labels = load_data(args.train, 'train') |
|
    # load_data returns (data, num_labels) for labelled splits, so unpack and
    # discard the duplicate label count here.
    dev_data, _ = load_data(args.dev, 'valid')
|
|
|
amazon_dataset = AmazonDataset(amazon_data, args) |
|
train_dataset = SentimentDataset(train_data, args) |
|
dev_dataset = SentimentDataset(dev_data, args) |
|
|
|
amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse, |
|
num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn) |
|
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier, |
|
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn) |
|
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier, |
|
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn) |
|
|
|
config = SimpleNamespace( |
|
hidden_dropout_prob=args.hidden_dropout_prob, |
|
num_labels=num_labels, |
|
hidden_size=768, |
|
data_dir='.', |
|
fine_tune_mode='full-model' |
|
) |
|
|
|
device = torch.device('cuda') if args.use_gpu else torch.device('cpu') |
|
model = BertSentimentClassifier(config) |
|
model = model.to(device) |
|
|
|
optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse) |
|
optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier) |
|
best_dev_acc = 0 |
|
|
|
|
|
for epoch in range(args.epochs): |
|
model.bert.train() |
|
train_loss = num_batches = 0 |
|
for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE): |
|
b_ids, b_mask = batch['token_ids'], batch['attention_mask'] |
|
b_ids = b_ids.to(device) |
|
b_mask = b_mask.to(device) |
|
|
|
|
|
            # Two dropout-perturbed forward passes over the same batch give the
            # positive pairs for unsupervised SimCSE.
            embeds_1 = model.bert(b_ids, b_mask)['pooler_output']
            embeds_2 = model.bert(b_ids, b_mask)['pooler_output']

            loss = contrastive_loss(embeds_1, embeds_2)
|
|
|
|
|
optimizer_cse.zero_grad() |
|
loss.backward() |
|
optimizer_cse.step() |
|
|
|
train_loss += loss.item() |
|
num_batches += 1 |
|
|
|
train_loss = train_loss / num_batches |
|
print(f"Epoch {epoch}: train loss :: {train_loss :.3f}") |
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--seed", type=int, default=11711) |
|
parser.add_argument("--num-cpu-cores", type=int, default=8) |
|
parser.add_argument("--epochs", type=int, default=10) |
|
parser.add_argument("--use_gpu", action='store_true') |
|
parser.add_argument("--batch_size_cse", type=int, default=8) |
|
parser.add_argument("--batch_size_sst", type=int, default=64) |
|
parser.add_argument("--batch_size_cfimdb", type=int, default=8) |
|
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3) |
|
parser.add_argument("--lr_cse", type=float, default=1e-5) |
|
parser.add_argument("--lr_classifier", type=float, default=1e-5) |
|
|
|
args = parser.parse_args() |
|
return args |
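
# Example invocation (illustrative; substitute this file's actual name and adjust
# flags/paths to your setup):
#   python simcse_finetune.py --use_gpu --epochs 10 --batch_size_cse 64 --lr_cse 1e-5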
|
|
|
|
|
if __name__ == "__main__": |
|
args = get_args() |
|
seed_everything(args.seed) |
|
torch.set_num_threads(args.num_cpu_cores) |
|
|
|
    print('Fine-tuning minBERT with unsupervised SimCSE...')
|
config = SimpleNamespace( |
|
filepath='contrastive-nli.pt', |
|
lr_cse=args.lr_cse, |
|
lr_classifier=args.lr_classifier, |
|
num_cpu_cores=args.num_cpu_cores, |
|
use_gpu=args.use_gpu, |
|
epochs=args.epochs, |
|
batch_size_cse=args.batch_size_cse, |
|
batch_size_classifier=args.batch_size_sst, |
|
hidden_dropout_prob=args.hidden_dropout_prob, |
|
train_bert='data/amazon-polarity.parquet', |
|
train='data/ids-sst-train.csv', |
|
dev='data/ids-sst-dev.csv', |
|
test='data/ids-sst-test-student.csv' |
|
) |
|
|
|
train(config) |
|