|
import random, numpy as np, argparse |
|
from types import SimpleNamespace |
|
import csv |
|
|
|
import torch |
|
from tqdm import tqdm |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset, DataLoader |
|
from sklearn.metrics import f1_score, accuracy_score |
|
|
|
from tokenizer import BertTokenizer |
|
from bert import BertModel |
|
from optimizer import AdamW |
|
|
|
|
|
TQDM_DISABLE=False |
|
|
|
|
|
|
|
def seed_everything(seed=11711): |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.benchmark = False |
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
class BertSentimentClassifier(torch.nn.Module): |
|
''' |
|
This module performs sentiment classification using BERT embeddings on the SST dataset. |
|
|
|
In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive"). |
|
Thus, your forward() should return one logit for each of the 5 classes. |
|
''' |
|
def __init__(self, config, bert_model = None): |
|
super(BertSentimentClassifier, self).__init__() |
|
self.num_labels = config.num_labels |
|
self.bert: BertModel = bert_model or BertModel.from_pretrained('bert-base-uncased') |
|
|
|
|
|
assert config.fine_tune_mode in ["last-linear-layer", "full-model"] |
|
for param in self.bert.parameters(): |
|
if config.fine_tune_mode == 'last-linear-layer': |
|
param.requires_grad = False |
|
elif config.fine_tune_mode == 'full-model': |
|
param.requires_grad = True |
|
|
|
|
|
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) |
|
self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels) |
|
|
|
|
|
def forward(self, input_ids, attention_mask): |
|
'''Takes a batch of sentences and returns logits for sentiment classes''' |
|
|
|
|
|
|
|
|
|
|
|
outputs = self.bert(input_ids, attention_mask) |
|
pooler_output = outputs['pooler_output'] |
|
|
|
|
|
logits = self.classifier(self.dropout(pooler_output)) |
|
|
|
return logits |
|
|
|
|
|
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') |
|
|
|
class SentimentDataset(Dataset): |
|
def __init__(self, dataset, args): |
|
self.dataset = dataset |
|
self.p = args |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
labels = [x[1] for x in data] |
|
sent_ids = [x[2] for x in data] |
|
|
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
attention_mask = torch.LongTensor(encoding['attention_mask']) |
|
labels = torch.LongTensor(labels) |
|
|
|
return token_ids, attention_mask, labels, sents, sent_ids |
|
|
|
def collate_fn(self, all_data): |
|
token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data) |
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'labels': labels, |
|
'sents': sents, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
|
|
|
|
|
class SentimentTestDataset(Dataset): |
|
def __init__(self, dataset, args): |
|
self.dataset = dataset |
|
self.p = args |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
return self.dataset[idx] |
|
|
|
def pad_data(self, data): |
|
sents = [x[0] for x in data] |
|
sent_ids = [x[1] for x in data] |
|
|
|
encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True) |
|
token_ids = torch.LongTensor(encoding['input_ids']) |
|
attention_mask = torch.LongTensor(encoding['attention_mask']) |
|
|
|
return token_ids, attention_mask, sents, sent_ids |
|
|
|
def collate_fn(self, all_data): |
|
token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data) |
|
|
|
batched_data = { |
|
'token_ids': token_ids, |
|
'attention_mask': attention_mask, |
|
'sents': sents, |
|
'sent_ids': sent_ids |
|
} |
|
|
|
return batched_data |
|
|
|
|
|
|
|
def load_data(filename, flag='train'): |
|
num_labels = set() |
|
data = [] |
|
with open(filename, 'r') as fp: |
|
for record in csv.DictReader(fp, delimiter = '\t'): |
|
if flag == 'test': |
|
sent = record['sentence'].lower().strip() |
|
sent_id = record['id'].lower().strip() |
|
data.append((sent,sent_id)) |
|
else: |
|
sent = record['sentence'].lower().strip() |
|
sent_id = record['id'].lower().strip() |
|
label = int(record['sentiment'].strip()) |
|
num_labels.add(label) |
|
data.append((sent, label, sent_id)) |
|
print(f"load {len(data)} data from {filename}") |
|
|
|
if flag == 'train': |
|
return data, len(num_labels) |
|
else: |
|
return data |
|
|
|
|
|
|
|
def model_eval(dataloader, model, device): |
|
model.eval() |
|
y_true = [] |
|
y_pred = [] |
|
sents = [] |
|
sent_ids = [] |
|
for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)): |
|
b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'],batch['attention_mask'], \ |
|
batch['labels'], batch['sents'], batch['sent_ids'] |
|
|
|
b_ids = b_ids.to(device) |
|
b_mask = b_mask.to(device) |
|
|
|
logits = model(b_ids, b_mask) |
|
logits = logits.detach().cpu().numpy() |
|
preds = np.argmax(logits, axis=1).flatten() |
|
|
|
b_labels = b_labels.flatten() |
|
y_true.extend(b_labels) |
|
y_pred.extend(preds) |
|
sents.extend(b_sents) |
|
sent_ids.extend(b_sent_ids) |
|
|
|
f1 = f1_score(y_true, y_pred, average='macro') |
|
acc = accuracy_score(y_true, y_pred) |
|
|
|
return acc, f1, y_pred, y_true, sents, sent_ids |
|
|
|
|
|
|
|
def model_test_eval(dataloader, model, device): |
|
model.eval() |
|
y_pred = [] |
|
sents = [] |
|
sent_ids = [] |
|
for step, batch in enumerate(tqdm(dataloader, desc=f'eval', leave=False, disable=TQDM_DISABLE)): |
|
b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'],batch['attention_mask'], \ |
|
batch['sents'], batch['sent_ids'] |
|
|
|
b_ids = b_ids.to(device) |
|
b_mask = b_mask.to(device) |
|
|
|
logits = model(b_ids, b_mask) |
|
logits = logits.detach().cpu().numpy() |
|
preds = np.argmax(logits, axis=1).flatten() |
|
|
|
y_pred.extend(preds) |
|
sents.extend(b_sents) |
|
sent_ids.extend(b_sent_ids) |
|
|
|
return y_pred, sents, sent_ids |
|
|
|
|
|
def save_model(model, optimizer, args, config, filepath): |
|
save_info = { |
|
'model': model.state_dict(), |
|
'optim': optimizer.state_dict(), |
|
'args': args, |
|
'model_config': config, |
|
'system_rng': random.getstate(), |
|
'numpy_rng': np.random.get_state(), |
|
'torch_rng': torch.random.get_rng_state(), |
|
} |
|
|
|
torch.save(save_info, filepath) |
|
print(f"save the model to {filepath}") |
|
|
|
|
|
def train(args): |
|
device = torch.device('cuda') if args.use_gpu else torch.device('cpu') |
|
|
|
train_data, num_labels = load_data(args.train, 'train') |
|
dev_data = load_data(args.dev, 'valid') |
|
|
|
train_dataset = SentimentDataset(train_data, args) |
|
dev_dataset = SentimentDataset(dev_data, args) |
|
|
|
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, |
|
num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn) |
|
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, |
|
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn) |
|
|
|
|
|
config = {'hidden_dropout_prob': args.hidden_dropout_prob, |
|
'num_labels': num_labels, |
|
'hidden_size': 768, |
|
'data_dir': '.', |
|
'fine_tune_mode': args.fine_tune_mode} |
|
|
|
config = SimpleNamespace(**config) |
|
|
|
model = BertSentimentClassifier(config) |
|
model = model.to(device) |
|
|
|
lr = args.lr |
|
optimizer = AdamW(model.parameters(), lr=lr) |
|
best_dev_acc = 0 |
|
|
|
|
|
for epoch in range(args.epochs): |
|
model.train() |
|
train_loss = 0 |
|
num_batches = 0 |
|
for batch in tqdm(train_dataloader, desc=f'train-{epoch}', leave=False, disable=TQDM_DISABLE): |
|
b_ids, b_mask, b_labels = (batch['token_ids'], |
|
batch['attention_mask'], batch['labels']) |
|
|
|
b_ids = b_ids.to(device) |
|
b_mask = b_mask.to(device) |
|
b_labels = b_labels.to(device) |
|
|
|
optimizer.zero_grad() |
|
logits = model(b_ids, b_mask) |
|
loss = F.cross_entropy(logits, b_labels.view(-1), reduction='sum') / args.batch_size |
|
|
|
loss.backward() |
|
optimizer.step() |
|
|
|
train_loss += loss.item() |
|
num_batches += 1 |
|
|
|
train_loss = train_loss / (num_batches) |
|
|
|
train_acc, train_f1, *_ = model_eval(train_dataloader, model, device) |
|
dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device) |
|
|
|
if dev_acc > best_dev_acc: |
|
best_dev_acc = dev_acc |
|
save_model(model, optimizer, args, config, args.filepath) |
|
|
|
print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}") |
|
|
|
|
|
def test(args): |
|
with torch.no_grad(): |
|
device = torch.device('cuda') if args.use_gpu else torch.device('cpu') |
|
saved = torch.load(args.filepath, weights_only=False) |
|
config = saved['model_config'] |
|
model = BertSentimentClassifier(config) |
|
model.load_state_dict(saved['model']) |
|
model = model.to(device) |
|
print(f"load model from {args.filepath}") |
|
|
|
dev_data = load_data(args.dev, 'valid') |
|
dev_dataset = SentimentDataset(dev_data, args) |
|
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, |
|
num_workers=args.num_cpu_cores, collate_fn=dev_dataset.collate_fn) |
|
|
|
dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device) |
|
print('DONE DEV') |
|
print(f"dev acc :: {dev_acc :.3f}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--seed", type=int, default=11711) |
|
parser.add_argument("--num-cpu-cores", type=int, default=4) |
|
parser.add_argument("--epochs", type=int, default=10) |
|
parser.add_argument("--fine-tune-mode", type=str, |
|
help='last-linear-layer: the BERT parameters are frozen and the task specific head parameters are updated; full-model: BERT parameters are updated as well', |
|
choices=('last-linear-layer', 'full-model'), default="last-linear-layer") |
|
parser.add_argument("--use_gpu", action='store_true') |
|
|
|
parser.add_argument("--batch_size_sst", help='64 can fit a 12GB GPU', type=int, default=8) |
|
parser.add_argument("--batch_size_cfimdb", help='8 can fit a 12GB GPU', type=int, default=8) |
|
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3) |
|
parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5", |
|
default=1e-3) |
|
|
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def main(): |
|
args = get_args() |
|
seed_everything(args.seed) |
|
torch.set_num_threads(args.num_cpu_cores) |
|
|
|
print(torch.get_num_threads()) |
|
|
|
print('Training Sentiment Classifier on SST...') |
|
config = SimpleNamespace( |
|
filepath='sst-classifier.pt', |
|
lr=args.lr, |
|
num_cpu_cores=args.num_cpu_cores, |
|
use_gpu=args.use_gpu, |
|
epochs=args.epochs, |
|
batch_size=args.batch_size_sst, |
|
hidden_dropout_prob=args.hidden_dropout_prob, |
|
train='data/ids-sst-train.csv', |
|
dev='data/ids-sst-dev.csv', |
|
test='data/ids-sst-test-student.csv', |
|
fine_tune_mode=args.fine_tune_mode, |
|
dev_out = 'predictions/' + args.fine_tune_mode + '-sst-dev-out.csv', |
|
test_out = 'predictions/' + args.fine_tune_mode + '-sst-test-out.csv' |
|
) |
|
|
|
train(config) |
|
|
|
print('Evaluating on SST...') |
|
test(config) |
|
|
|
print('Training Sentiment Classifier on cfimdb...') |
|
config = SimpleNamespace( |
|
filepath='cfimdb-classifier.pt', |
|
lr=args.lr, |
|
num_cpu_cores=args.num_cpu_cores, |
|
use_gpu=args.use_gpu, |
|
epochs=args.epochs, |
|
batch_size=args.batch_size_cfimdb, |
|
hidden_dropout_prob=args.hidden_dropout_prob, |
|
train='data/ids-cfimdb-train.csv', |
|
dev='data/ids-cfimdb-dev.csv', |
|
test='data/ids-cfimdb-test-student.csv', |
|
fine_tune_mode=args.fine_tune_mode, |
|
dev_out = 'predictions/' + args.fine_tune_mode + '-cfimdb-dev-out.csv', |
|
test_out = 'predictions/' + args.fine_tune_mode + '-cfimdb-test-out.csv' |
|
) |
|
|
|
train(config) |
|
|
|
print('Evaluating on cfimdb...') |
|
test(config) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |