Commit 897eef4
Parent(s): 02b98f5

contrastive commit 3
data/{twitter-unsup.csv → amazon-polarity.parquet} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e
+size 870289
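Note: the renamed file is a Git LFS pointer, so the diff only shows the new object id and size; the ~870 KB Parquet payload itself lives in LFS storage. A minimal sketch of fetching and peeking at it, assuming Git LFS is installed and the file stays at this path:

    git lfs pull --include="data/amazon-polarity.parquet"
    python -c "import pandas as pd; print(pd.read_parquet('data/amazon-polarity.parquet').head())"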
unsup_simcse.py CHANGED
@@ -3,6 +3,7 @@ import torch
 import random
 import argparse
 import numpy as np
+import pandas as pd
 import torch.nn.functional as F

 from tqdm import tqdm
@@ -20,7 +21,7 @@ from classifier import SentimentDataset, BertSentimentClassifier
 TQDM_DISABLE = False


-class TwitterDataset(Dataset):
+class AmazonDataset(Dataset):
     def __init__(self, dataset, args):
         self.dataset = dataset
         self.p = args
@@ -31,19 +32,22 @@ class TwitterDataset(Dataset):
     def __getitem__(self, idx):
         return self.dataset[idx]

-    def pad_data(self,
+    def pad_data(self, data):
+        sents = [x[0] for x in data]
+        sent_ids = [x[1] for x in data]
         encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
         token_ids = torch.LongTensor(encoding['input_ids'])
         attension_mask = torch.LongTensor(encoding['attention_mask'])

-        return token_ids, attension_mask
+        return token_ids, attension_mask, sent_ids

-    def collate_fn(self,
-        token_ids, attention_mask = self.pad_data(
+    def collate_fn(self, data):
+        token_ids, attention_mask, sent_ids = self.pad_data(data)

         batched_data = {
             'token_ids': token_ids,
             'attention_mask': attention_mask,
+            'sent_ids': sent_ids
         }

         return batched_data
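Note: collate_fn now carries the sentence ids through the batch alongside the token ids and attention mask. A minimal usage sketch, assuming AmazonDataset and the module-level BERT tokenizer are importable from unsup_simcse.py and that args can be any namespace-like object (the dataset only stores it):

    from argparse import Namespace
    from unsup_simcse import AmazonDataset

    pairs = [("great battery life", 0), ("arrived broken", 1)]   # (sent, sent_id)
    ds = AmazonDataset(pairs, Namespace())
    batch = ds.collate_fn(pairs)
    print(batch['token_ids'].shape, batch['attention_mask'].shape, batch['sent_ids'])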
@@ -51,36 +55,36 @@ class TwitterDataset(Dataset):

 def load_data(filename, flag='train'):
     '''
-    - for
-    - for
+    - for amazon dataset: list of (sent, sent_id)
+    - for test dataset: list of (sent, sent_id)
+    - for train dataset: list of (sent, label, sent_id)
     '''

-
-
-
-    if flag == 'twitter':
-        for cnt, record in enumerate(csv.DictReader(fp, delimiter = ',')):
-            sent = record['clean_text'].lower().strip()
-            data.append(sent)
-            if cnt == 10000: break
-    elif flag == 'test':
-        for record in csv.DictReader(fp, delimiter = '\t'):
-            sent = record['sentence'].lower().strip()
-            sent_id = record['id'].lower().strip()
-            data.append((sent,sent_id))
-    else:
-        for record in csv.DictReader(fp, delimiter = '\t'):
-            sent = record['sentence'].lower().strip()
-            sent_id = record['id'].lower().strip()
-            label = int(record['sentiment'].strip())
-            num_labels.add(label)
-            data.append((sent, label, sent_id))
-    print(f"load {len(data)} data from {filename}")
-
-    if flag == 'train':
-        return data, len(num_labels)
+    if flag == 'amazon':
+        df = pd.read_parquet(filename)
+        data = list(zip(df['content'], df.index))
     else:
+        data, num_labels = [], set()
+
+        with open(filename, 'r') as fp:
+            if flag == 'test':
+                for record in csv.DictReader(fp, delimiter = '\t'):
+                    sent = record['sentence'].lower().strip()
+                    sent_id = record['id'].lower().strip()
+                    data.append((sent,sent_id))
+            else:
+                for record in csv.DictReader(fp, delimiter = '\t'):
+                    sent = record['sentence'].lower().strip()
+                    sent_id = record['id'].lower().strip()
+                    label = int(record['sentiment'].strip())
+                    num_labels.add(label)
+                    data.append((sent, label, sent_id))
+
+    print(f"load {len(data)} data from {filename}")
+    if flag in ['test', 'amazon']:
         return data
+    else:
+        return data, len(num_labels)


 def save_model(model, optimizer, args, config, filepath):
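Note: the new 'amazon' branch of load_data reads the Parquet file with pandas and pairs each review in its `content` column with the DataFrame index; unlike the CSV branches, it does not lowercase or strip the text. A small sketch of the expected return shape, using the path and flag that train() passes later:

    amazon_data = load_data('data/amazon-polarity.parquet', 'amazon')
    sent, sent_id = amazon_data[0]          # raw review text, row index
    print(len(amazon_data), sent_id, sent[:60])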
@@ -98,11 +102,6 @@ def save_model(model, optimizer, args, config, filepath):
     print(f"save the model to {filepath}")


-# def model_eval(dataloader, model, device):
-#     model.eval()
-
-
-
 def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
     '''
     embeds_1: [batch_size, hidden_size]
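Note: the body of contrastive_loss is unchanged and not shown in this hunk. For orientation, a minimal sketch of the standard unsupervised SimCSE objective its signature suggests (cosine similarity over in-batch pairs, scaled by temp, cross-entropy against the diagonal); this is the common formulation, not necessarily the exact code in this file:

    import torch
    import torch.nn.functional as F
    from torch import Tensor

    def simcse_loss_sketch(embeds_1: Tensor, embeds_2: Tensor, temp: float = 0.05) -> Tensor:
        # embeds_1, embeds_2: [batch_size, hidden_size], two dropout-noised
        # encodings of the same sentences.
        sim = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
        # Row i's positive is column i; every other column is an in-batch negative.
        labels = torch.arange(sim.size(0), device=sim.device)
        return F.cross_entropy(sim, labels)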
@@ -131,7 +130,7 @@ def train(args):
     '''
     Training Pipeline
     -----------------
-    1. Load the
+    1. Load the Amazon Polarity and SST Dataset.
     2. Determine batch_size (64) and number of batches (?).
     3. Initialize SentimentClassifier (including bert).
     4. Looping through 10 epoches.
@@ -142,16 +141,16 @@ def train(args):
     9. If dev_acc > best_dev_acc: save_model(...)
     '''

-
+    amazon_data = load_data(args.train_bert, 'amazon')
     train_data, num_labels = load_data(args.train, 'train')
     dev_data = load_data(args.dev, 'valid')

-
+    amazon_dataset = AmazonDataset(amazon_data, args)
     train_dataset = SentimentDataset(train_data, args)
     dev_dataset = SentimentDataset(dev_data, args)

-
-        num_workers=args.num_cpu_cores, collate_fn=
+    amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
+                                   num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
     train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
                                   num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
     dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
@@ -177,7 +176,7 @@ def train(args):
     for epoch in range(args.epochs):
         model.bert.train()
        train_loss = num_batches = 0
-        for batch in tqdm(
+        for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
            b_ids, b_mask = batch['token_ids'], batch['attention_mask']
            b_ids = b_ids.to(device)
            b_mask = b_mask.to(device)
@@ -189,11 +188,13 @@ def train(args):
            # Calculate mean SimCSE loss function
            loss = contrastive_loss(logits_1, logits_2)

+            # Back propagation
+            optimizer_cse.zero_grad()
            loss.backward()
            optimizer_cse.step()

            train_loss += loss.item()
-            num_batches +=
+            num_batches += 1

        train_loss = train_loss / num_batches
        print(f"Epoch {epoch}: train loss :: {train_loss :.3f}")
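Note: logits_1 and logits_2 are produced just above this hunk (those lines are unchanged and not shown). In unsupervised SimCSE they typically come from passing the same batch through the encoder twice with dropout active, so the two dropout masks act as the augmentation. A hedged sketch of that step; the 'pooler_output' key is an assumption about what model.bert returns (HuggingFace-style), not something this diff shows:

    # Two forward passes over the same tokens; dropout stays on (model.bert.train()),
    # so the encodings differ and form a positive pair for contrastive_loss.
    logits_1 = model.bert(b_ids, b_mask)['pooler_output']
    logits_2 = model.bert(b_ids, b_mask)['pooler_output']
    loss = contrastive_loss(logits_1, logits_2)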
@@ -205,11 +206,12 @@ def get_args():
     parser.add_argument("--num-cpu-cores", type=int, default=4)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--use_gpu", action='store_true')
-    parser.add_argument("--batch_size_cse",
-    parser.add_argument("--
+    parser.add_argument("--batch_size_cse", type=int, default=8)
+    parser.add_argument("--batch_size_sst", type=int, default=64)
+    parser.add_argument("--batch_size_cfimdb", type=int, default=8)
     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-    parser.add_argument("--lr_cse", default=2e-5)
-    parser.add_argument("--lr_classifier", default=1e-5)
+    parser.add_argument("--lr_cse", type=float, default=2e-5)
+    parser.add_argument("--lr_classifier", type=float, default=1e-5)

     args = parser.parse_args()
     return args
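For reference, an example invocation using the flags registered above (the values shown are just the defaults):

    python unsup_simcse.py --use_gpu --epochs 10 --batch_size_cse 8 --lr_cse 2e-5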
@@ -229,9 +231,9 @@ if __name__ == "__main__":
        use_gpu=args.use_gpu,
        epochs=args.epochs,
        batch_size_cse=args.batch_size_cse,
-        batch_size_classifier=args.
+        batch_size_classifier=args.batch_size_sst,
        hidden_dropout_prob=args.hidden_dropout_prob,
-        train_bert='data/
+        train_bert='data/amazon-polarity.parquet',
        train='data/ids-sst-train.csv',
        dev='data/ids-sst-dev.csv',
        test='data/ids-sst-test-student.csv'