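"""unsup_simcse.py

Unsupervised SimCSE pre-training for minBERT: each Amazon Polarity sentence is
encoded twice with different dropout masks, and the two pooler outputs are
pulled together against in-batch negatives with a contrastive loss. The SST
sentiment data and classifier objects are also set up in train() for the
classifier fine-tuning stage described in its docstring.
"""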
import csv
import torch
import random
import argparse
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from torch import Tensor
from types import SimpleNamespace
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score
from bert import BertModel
from optimizer import AdamW
from classifier import seed_everything, tokenizer
from classifier import SentimentDataset, BertSentimentClassifier
TQDM_DISABLE = False

class AmazonDataset(Dataset):
    def __init__(self, dataset, args):
        self.dataset = dataset
        self.p = args

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def pad_data(self, data):
        sents = [x[0] for x in data]
        sent_ids = [x[1] for x in data]

        encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
        token_ids = torch.LongTensor(encoding['input_ids'])
        attention_mask = torch.LongTensor(encoding['attention_mask'])

        return token_ids, attention_mask, sent_ids

    def collate_fn(self, data):
        token_ids, attention_mask, sent_ids = self.pad_data(data)

        batched_data = {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'sent_ids': sent_ids
        }

        return batched_data
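
# A batch produced by AmazonDataset.collate_fn looks like this (shapes are
# illustrative only):
#   {'token_ids': LongTensor[batch_size, seq_len],
#    'attention_mask': LongTensor[batch_size, seq_len],
#    'sent_ids': [id_0, ..., id_{batch_size - 1}]}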

def load_data(filename, flag='train'):
    '''
    - for amazon dataset: list of (sent, sent_id)
    - for test dataset: list of (sent, sent_id)
    - for train dataset: list of (sent, label, sent_id)
    '''
    if flag == 'amazon':
        df = pd.read_parquet(filename)
        data = list(zip(df['content'], df.index))
    else:
        data, num_labels = [], set()
        with open(filename, 'r') as fp:
            if flag == 'test':
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    data.append((sent, sent_id))
            else:
                for record in csv.DictReader(fp, delimiter='\t'):
                    sent = record['sentence'].lower().strip()
                    sent_id = record['id'].lower().strip()
                    label = int(record['sentiment'].strip())
                    num_labels.add(label)
                    data.append((sent, label, sent_id))
    print(f"Loaded {len(data)} examples from {filename}")

    if flag in ['test', 'amazon']:
        return data
    else:
        return data, len(num_labels)
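
# Expected file layouts, inferred from the readers above:
#   - Amazon parquet: a 'content' column; the row index is used as sent_id.
#   - SST train/dev TSV: tab-separated columns 'id', 'sentence', 'sentiment'.
#   - SST test TSV: tab-separated columns 'id', 'sentence' (no labels).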

def save_model(model, optimizer, args, config, filepath):
    save_info = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'args': args,
        'model_config': config,
        'system_rng': random.getstate(),
        'numpy_rng': np.random.get_state(),
        'torch_rng': torch.random.get_rng_state(),
    }

    torch.save(save_info, filepath)
    print(f"Saved the model to {filepath}")

def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
    '''
    embeds_1: [batch_size, hidden_size]
    embeds_2: [batch_size, hidden_size]
    '''
    # [batch_size, batch_size]
    sim_matrix = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp

    # [batch_size]
    positive_sim = torch.diagonal(sim_matrix)

    # [batch_size]
    nume = torch.exp(positive_sim)
    # [batch_size]
    deno = torch.exp(sim_matrix).sum(1)

    # [batch_size]
    loss_per_batch = -torch.log(nume / deno)
    return loss_per_batch.mean()
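
# contrastive_loss is the unsupervised SimCSE objective: with h_i = embeds_1[i]
# and h'_j = embeds_2[j] (two dropout-noised views of the same sentences),
#   loss_i = -log( exp(cos(h_i, h'_i) / temp) / sum_j exp(cos(h_i, h'_j) / temp) )
# and the returned value is the mean of loss_i over the batch. Quick sanity
# check (illustrative shapes only):
#   e1, e2 = torch.randn(4, 768), torch.randn(4, 768)
#   contrastive_loss(e1, e2)   # -> scalar tensor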

def train(args):
    '''
    Training Pipeline
    -----------------
    1. Load the Amazon Polarity and SST datasets.
    2. Determine batch_size (64) and number of batches (?).
    3. Initialize the SentimentClassifier (including BERT).
    4. Loop through 10 epochs.
    5. Finetune minBERT with the SimCSE loss function.
    6. Finetune the classifier with the cross-entropy loss.
    7. Backpropagate with the AdamW optimizer for both.
    8. Evaluate the model on the dev dataset.
    9. If dev_acc > best_dev_acc: save_model(...)

    Only the SimCSE stage (steps 1-5 and 7) is implemented in the loop below;
    a sketch of the classifier stage follows after this function.
    '''
    amazon_data = load_data(args.train_bert, 'amazon')
    train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')

    amazon_dataset = AmazonDataset(amazon_data, args)
    train_dataset = SentimentDataset(train_data, args)
    dev_dataset = SentimentDataset(dev_data, args)

    amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
                                   num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
                                  num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
    dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
                                num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn)

    config = SimpleNamespace(
        hidden_dropout_prob=args.hidden_dropout_prob,
        num_labels=num_labels,
        hidden_size=768,
        data_dir='.',
        fine_tune_mode='full-model'
    )

    device = torch.device('cuda') if args.use_gpu else torch.device('cpu')
    model = BertSentimentClassifier(config)
    model = model.to(device)

    optimizer_cse = AdamW(model.bert.parameters(), lr=args.lr_cse)
    optimizer_classifier = AdamW(model.parameters(), lr=args.lr_classifier)
    best_dev_acc = 0

    # ---- Training minBERT using SimCSE ---- #
    for epoch in range(args.epochs):
        model.bert.train()
        train_loss = num_batches = 0

        for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids, b_mask = batch['token_ids'], batch['attention_mask']
            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)

            # Get different embeddings of the same sentences with different dropout masks
            embeds_1 = model.bert(b_ids, b_mask)['pooler_output']
            embeds_2 = model.bert(b_ids, b_mask)['pooler_output']

            # Calculate the mean SimCSE loss
            loss = contrastive_loss(embeds_1, embeds_2)

            # Backpropagation
            optimizer_cse.zero_grad()
            loss.backward()
            optimizer_cse.step()

            train_loss += loss.item()
            num_batches += 1

        train_loss = train_loss / num_batches
        print(f"Epoch {epoch}: train loss :: {train_loss:.3f}")

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=11711)
    parser.add_argument("--num-cpu-cores", type=int, default=8)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--use_gpu", action='store_true')
    parser.add_argument("--batch_size_cse", type=int, default=8)
    parser.add_argument("--batch_size_sst", type=int, default=64)
    parser.add_argument("--batch_size_cfimdb", type=int, default=8)
    parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
    parser.add_argument("--lr_cse", type=float, default=1e-5)
    parser.add_argument("--lr_classifier", type=float, default=1e-5)

    args = parser.parse_args()
    return args
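
# Example invocation (a sketch; the dataset paths are hard-coded in the
# SimpleNamespace config built below):
#   python unsup_simcse.py --use_gpu --epochs 10 --batch_size_cse 64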

if __name__ == "__main__":
    args = get_args()
    seed_everything(args.seed)
    torch.set_num_threads(args.num_cpu_cores)

    print('Finetuning minBERT with Unsupervised SimCSE...')
    config = SimpleNamespace(
        filepath='contrastive-nli.pt',
        lr_cse=args.lr_cse,
        lr_classifier=args.lr_classifier,
        num_cpu_cores=args.num_cpu_cores,
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size_cse=args.batch_size_cse,
        batch_size_classifier=args.batch_size_sst,
        hidden_dropout_prob=args.hidden_dropout_prob,
        train_bert='data/amazon-polarity.parquet',
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv'
    )

    train(config)

# model = BertModel.from_pretrained('bert-base-uncased')
# model.eval()
# s = set()
# for param in model.parameters():
# s.add(param.requires_grad)
# print(s)