GlowCheese committed
Commit 897eef4 · Parent(s): 02b98f5

contrastive commit 3

data/{twitter-unsup.csv → amazon-polarity.parquet} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5a7af1ec5fc749ec8e5ea13c574aeb5c06254aa1c081e3421868079d5356b3f4
-size 20895533
+oid sha256:dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e
+size 870289
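
The dataset swap replaces the ~20 MB Twitter CSV with a much smaller Parquet file, both stored as Git LFS pointers. As a quick sanity check, a sketch like the following can confirm that a locally materialized copy matches the new pointer (the path data/amazon-polarity.parquet after `git lfs pull` is an assumption):

import hashlib
import os

# Values copied from the updated LFS pointer above.
EXPECTED_OID = "dbe4770cfa6be45add6c9a322044bd4c1901520dde5a2707eca402a74fbe854e"
EXPECTED_SIZE = 870289

# Assumed local path once the LFS object has been pulled.
path = "data/amazon-polarity.parquet"

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch: pointer not resolved?"
assert sha.hexdigest() == EXPECTED_OID, "sha256 mismatch"
print("amazon-polarity.parquet matches its LFS pointer")
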
unsup_simcse.py CHANGED
@@ -3,6 +3,7 @@ import torch
 import random
 import argparse
 import numpy as np
+import pandas as pd
 import torch.nn.functional as F
 
 from tqdm import tqdm
@@ -20,7 +21,7 @@ from classifier import SentimentDataset, BertSentimentClassifier
 TQDM_DISABLE = False
 
 
-class TwitterDataset(Dataset):
+class AmazonDataset(Dataset):
     def __init__(self, dataset, args):
         self.dataset = dataset
         self.p = args
@@ -31,19 +32,22 @@ class TwitterDataset(Dataset):
     def __getitem__(self, idx):
         return self.dataset[idx]
 
-    def pad_data(self, sents):
+    def pad_data(self, data):
+        sents = [x[0] for x in data]
+        sent_ids = [x[1] for x in data]
         encoding = tokenizer(sents, return_tensors='pt', padding=True, truncation=True)
         token_ids = torch.LongTensor(encoding['input_ids'])
         attension_mask = torch.LongTensor(encoding['attention_mask'])
 
-        return token_ids, attension_mask
+        return token_ids, attension_mask, sent_ids
 
-    def collate_fn(self, sents):
-        token_ids, attention_mask = self.pad_data(sents)
+    def collate_fn(self, data):
+        token_ids, attention_mask, sent_ids = self.pad_data(data)
 
         batched_data = {
             'token_ids': token_ids,
             'attention_mask': attention_mask,
+            'sent_ids': sent_ids
         }
 
         return batched_data
@@ -51,36 +55,36 @@ class TwitterDataset(Dataset):
 
 def load_data(filename, flag='train'):
     '''
-    - for Twitter dataset: list of sentences
-    - for SST/CFIMDB dataset: list of (sent, [label], sent_id)
+    - for amazon dataset: list of (sent, sent_id)
+    - for test dataset: list of (sent, sent_id)
+    - for train dataset: list of (sent, label, sent_id)
     '''
 
-    num_labels = set()
-    data = []
-    with open(filename, 'r') as fp:
-        if flag == 'twitter':
-            for cnt, record in enumerate(csv.DictReader(fp, delimiter = ',')):
-                sent = record['clean_text'].lower().strip()
-                data.append(sent)
-                if cnt == 10000: break
-        elif flag == 'test':
-            for record in csv.DictReader(fp, delimiter = '\t'):
-                sent = record['sentence'].lower().strip()
-                sent_id = record['id'].lower().strip()
-                data.append((sent,sent_id))
-        else:
-            for record in csv.DictReader(fp, delimiter = '\t'):
-                sent = record['sentence'].lower().strip()
-                sent_id = record['id'].lower().strip()
-                label = int(record['sentiment'].strip())
-                num_labels.add(label)
-                data.append((sent, label, sent_id))
-    print(f"load {len(data)} data from {filename}")
-
-    if flag == 'train':
-        return data, len(num_labels)
+    if flag == 'amazon':
+        df = pd.read_parquet(filename)
+        data = list(zip(df['content'], df.index))
     else:
+        data, num_labels = [], set()
+
+        with open(filename, 'r') as fp:
+            if flag == 'test':
+                for record in csv.DictReader(fp, delimiter = '\t'):
+                    sent = record['sentence'].lower().strip()
+                    sent_id = record['id'].lower().strip()
+                    data.append((sent,sent_id))
+            else:
+                for record in csv.DictReader(fp, delimiter = '\t'):
+                    sent = record['sentence'].lower().strip()
+                    sent_id = record['id'].lower().strip()
+                    label = int(record['sentiment'].strip())
+                    num_labels.add(label)
+                    data.append((sent, label, sent_id))
+
+    print(f"load {len(data)} data from {filename}")
+    if flag in ['test', 'amazon']:
         return data
+    else:
+        return data, len(num_labels)
 
 
 def save_model(model, optimizer, args, config, filepath):
@@ -98,11 +102,6 @@ def save_model(model, optimizer, args, config, filepath):
     print(f"save the model to {filepath}")
 
 
-# def model_eval(dataloader, model, device):
-#     model.eval()
-
-
-
 def contrastive_loss(embeds_1: Tensor, embeds_2: Tensor, temp=0.05):
     '''
     embeds_1: [batch_size, hidden_size]
@@ -131,7 +130,7 @@ def train(args):
     '''
     Training Pipeline
     -----------------
-    1. Load the Twitter Sentiment and SST Dataset.
+    1. Load the Amazon Polarity and SST Dataset.
     2. Determine batch_size (64) and number of batches (?).
     3. Initialize SentimentClassifier (including bert).
     4. Looping through 10 epoches.
@@ -142,16 +141,16 @@ def train(args):
     9. If dev_acc > best_dev_acc: save_model(...)
     '''
 
-    twitter_data = load_data(args.train_bert, 'twitter')
+    amazon_data = load_data(args.train_bert, 'amazon')
     train_data, num_labels = load_data(args.train, 'train')
    dev_data = load_data(args.dev, 'valid')
 
-    twitter_dataset = TwitterDataset(twitter_data, args)
+    amazon_dataset = AmazonDataset(amazon_data, args)
     train_dataset = SentimentDataset(train_data, args)
     dev_dataset = SentimentDataset(dev_data, args)
 
-    twitter_dataloader = DataLoader(twitter_dataset, shuffle=True, batch_size=args.batch_size_cse,
-                                    num_workers=args.num_cpu_cores, collate_fn=twitter_dataset.collate_fn)
+    amazon_dataloader = DataLoader(amazon_dataset, shuffle=True, batch_size=args.batch_size_cse,
+                                   num_workers=args.num_cpu_cores, collate_fn=amazon_dataset.collate_fn)
     train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size_classifier,
                                   num_workers=args.num_cpu_cores, collate_fn=train_dataset.collate_fn)
     dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size_classifier,
@@ -177,7 +176,7 @@ def train(args):
     for epoch in range(args.epochs):
         model.bert.train()
         train_loss = num_batches = 0
-        for batch in tqdm(twitter_dataloader, f'train-twitter-{epoch}', leave=False, disable=TQDM_DISABLE):
+        for batch in tqdm(amazon_dataloader, f'train-amazon-{epoch}', leave=False, disable=TQDM_DISABLE):
             b_ids, b_mask = batch['token_ids'], batch['attention_mask']
             b_ids = b_ids.to(device)
             b_mask = b_mask.to(device)
@@ -189,11 +188,13 @@ def train(args):
             # Calculate mean SimCSE loss function
             loss = contrastive_loss(logits_1, logits_2)
 
+            # Back propagation
+            optimizer_cse.zero_grad()
             loss.backward()
             optimizer_cse.step()
 
             train_loss += loss.item()
-            num_batches += 0
+            num_batches += 1
 
         train_loss = train_loss / num_batches
         print(f"Epoch {epoch}: train loss :: {train_loss :.3f}")
@@ -205,11 +206,12 @@ def get_args():
     parser.add_argument("--num-cpu-cores", type=int, default=4)
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument("--use_gpu", action='store_true')
-    parser.add_argument("--batch_size_cse", help="'unsup': 64, 'sup': 512", type=int)
-    parser.add_argument("--batch_size_classifier", help="'sst': 64, 'cfimdb': 8", type=int)
+    parser.add_argument("--batch_size_cse", type=int, default=8)
+    parser.add_argument("--batch_size_sst", type=int, default=64)
+    parser.add_argument("--batch_size_cfimdb", type=int, default=8)
     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-    parser.add_argument("--lr_cse", default=2e-5)
-    parser.add_argument("--lr_classifier", default=1e-5)
+    parser.add_argument("--lr_cse", type=float, default=2e-5)
+    parser.add_argument("--lr_classifier", type=float, default=1e-5)
 
     args = parser.parse_args()
     return args
@@ -229,9 +231,9 @@ if __name__ == "__main__":
         use_gpu=args.use_gpu,
         epochs=args.epochs,
         batch_size_cse=args.batch_size_cse,
-        batch_size_classifier=args.batch_size_classifier,
+        batch_size_classifier=args.batch_size_sst,
         hidden_dropout_prob=args.hidden_dropout_prob,
-        train_bert='data/twitter-unsup.csv',
+        train_bert='data/amazon-polarity.parquet',
         train='data/ids-sst-train.csv',
         dev='data/ids-sst-dev.csv',
         test='data/ids-sst-test-student.csv'
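
The body of contrastive_loss is not shown in this diff, only its signature and call site. For reference, a common unsupervised SimCSE formulation with in-batch negatives looks like the sketch below; this is a plausible reading of the signature, not necessarily this repo's exact implementation, and the batch size 8 and hidden size 768 in the usage example are assumptions:

import torch
import torch.nn.functional as F
from torch import Tensor

def contrastive_loss_sketch(embeds_1: Tensor, embeds_2: Tensor, temp: float = 0.05) -> Tensor:
    # embeds_1 / embeds_2: [batch_size, hidden_size], two dropout-perturbed
    # encodings of the same sentences; row i of each tensor forms a positive
    # pair, every other row in the batch serves as an in-batch negative.
    sim = F.cosine_similarity(embeds_1.unsqueeze(1), embeds_2.unsqueeze(0), dim=-1) / temp
    labels = torch.arange(sim.size(0), device=sim.device)  # positive of row i is column i
    return F.cross_entropy(sim, labels)  # mean loss over the batch

# Usage example with random embeddings (shapes assumed).
e1, e2 = torch.randn(8, 768), torch.randn(8, 768)
print(contrastive_loss_sketch(e1, e2).item())
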