abrar0503 committed (verified)
Commit 6f7f9b0 · 1 Parent(s): e0ffa77

Delete data_config.json

Files changed (1):
  1. data_config.json +0 -344
data_config.json DELETED
@@ -1,344 +0,0 @@
- """
- Train script for a single file
-
- Need to set the TPU address first:
- export XRT_TPU_CONFIG="localservice;0;localhost:51011"
- """
-
- import torch.multiprocessing as mp
- import threading
- import time
- import random
- import sys
- import argparse
- import gzip
- import json
- import logging
- import tqdm
- import torch
- from torch import nn
- from torch.utils.data import DataLoader
- import torch
- import torch_xla
- import torch_xla.core
- import torch_xla.core.functions
- import torch_xla.core.xla_model as xm
- import torch_xla.distributed.xla_multiprocessing as xmp
- import torch_xla.distributed.parallel_loader as pl
- import os
- from shutil import copyfile
-
-
- from transformers import (
-     AdamW,
-     AutoModel,
-     AutoTokenizer,
-     get_linear_schedule_with_warmup,
-     set_seed,
- )
-
- class AutoModelForSentenceEmbedding(nn.Module):
-     def __init__(self, model_name, tokenizer, normalize=True):
-         super(AutoModelForSentenceEmbedding, self).__init__()
-
-         self.model = AutoModel.from_pretrained(model_name)
-         self.normalize = normalize
-         self.tokenizer = tokenizer
-
-     def forward(self, **kwargs):
-         model_output = self.model(**kwargs)
-         embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
-         if self.normalize:
-             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-         return embeddings
-
-     def mean_pooling(self, model_output, attention_mask):
-         token_embeddings = model_output[0] # First element of model_output contains all token embeddings
-         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
-     def save_pretrained(self, output_path):
-         if xm.is_master_ordinal():
-             self.tokenizer.save_pretrained(output_path)
-             self.model.config.save_pretrained(output_path)
-
-         xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
-
-
-
-
- def train_function(index, args, queue):
-     tokenizer = AutoTokenizer.from_pretrained(args.model)
-     model = AutoModelForSentenceEmbedding(args.model, tokenizer)
-
-
-     ### Train Loop
-     device = xm.xla_device()
-     model = model.to(device)
-
-     # Instantiate optimizer
-     optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
-
-     lr_scheduler = get_linear_schedule_with_warmup(
-         optimizer=optimizer,
-         num_warmup_steps=500,
-         num_training_steps=args.steps,
-     )
-
-     # Now we train the model
-     cross_entropy_loss = nn.CrossEntropyLoss()
-     max_grad_norm = 1
-
-     model.train()
-
-     for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
-         #### Get the batch data
-         batch = queue.get()
-         #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
-
-
-         if len(batch[0]) == 2: #(anchor, positive)
-             text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
-             text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
-
-             ### Compute embeddings
-             embeddings_a = model(**text1.to(device))
-             embeddings_b = model(**text2.to(device))
-
-             ### Gather all embeddings
-             embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
-             embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
-
-             ### Compute similarity scores 512 x 512
-             scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
-
-             ### Compute cross-entropy loss
-             labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
-
-             ## Symmetric loss as in CLIP
-             loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
-
-         else: #(anchor, positive, negative)
-             text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
-             text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
-             text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
-
-             embeddings_a = model(**text1.to(device))
-             embeddings_b1 = model(**text2.to(device))
-             embeddings_b2 = model(**text3.to(device))
-
-             embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
-             embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
-             embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
-
-             embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
-
-             ### Compute similarity scores 512 x 1024
-             scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
-
-             ### Compute cross-entropy loss
-             labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
-
-             ## One-way loss
-             loss = cross_entropy_loss(scores, labels)
-
-
-         # Backward pass
-         optimizer.zero_grad()
-         loss.backward()
-         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
-
-         xm.optimizer_step(optimizer, barrier=True)
-         lr_scheduler.step()
-
-
-         #Save model
-         if (global_step+1) % args.save_steps == 0:
-             output_path = os.path.join(args.output, str(global_step+1))
-             xm.master_print("save model: "+output_path)
-             model.save_pretrained(output_path)
-
-
-     output_path = os.path.join(args.output, "final")
-     xm.master_print("save model final: "+ output_path)
-     model.save_pretrained(output_path)
-
-
- def produce_data(args, queue, filepaths, dataset_indices):
-     global_batch_size = args.batch_size*args.nprocs #Global batch size
-     size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
-     num_same_dataset = int(size_per_dataset / args.batch_size)
-     print("producer", "global_batch_size", global_batch_size)
-     print("producer", "size_per_dataset", size_per_dataset)
-     print("producer", "num_same_dataset", num_same_dataset)
-
-     datasets = []
-     for filepath in filepaths:
-         if "reddit_" in filepath: #Special dataset class for Reddit files
-             data_obj = RedditDataset(filepath)
-         else:
-             data_obj = Dataset(filepath)
-         datasets.append(iter(data_obj))
-
-     # Store if dataset is in a 2 col or 3 col format
-     num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
-
-     while True:
-         texts_in_batch = set()
-         batch_format = None #2 vs 3 col format for this batch
-
-         #Add data from several sub datasets
-         for _ in range(args.datasets_per_batch):
-             valid_dataset = False #Check that datasets have the same 2/3 col format
-             while not valid_dataset:
-                 data_idx = random.choice(dataset_indices)
-                 if batch_format is None:
-                     batch_format = num_cols[data_idx]
-                     valid_dataset = True
-                 else: #Check that this dataset has the same format
-                     valid_dataset = (batch_format == num_cols[data_idx])
-
-             #Get data from this dataset
-             dataset = datasets[data_idx]
-             for _ in range(num_same_dataset):
-                 for _ in range(args.nprocs):
-                     batch_device = [] #A batch for one device
-                     while len(batch_device) < args.batch_size:
-                         sample = next(dataset)
-                         in_batch = False
-                         for text in sample:
-                             if text in texts_in_batch:
-                                 in_batch = True
-                                 break
-
-                         if not in_batch:
-                             for text in sample:
-                                 texts_in_batch.add(text)
-                             batch_device.append(sample)
-
-                     queue.put(batch_device)
-
-
- class RedditDataset:
-     """
-     A class that handles the reddit data files
-     """
-     def __init__(self, filepath):
-         self.filepath = filepath
-
-     def __iter__(self):
-         while True:
-             with gzip.open(self.filepath, "rt") as fIn:
-                 for line in fIn:
-                     data = json.loads(line)
-
-                     if "response" in data and "context" in data:
-                         yield [data["response"], data["context"]]
-
- class Dataset:
-     """
-     A class that handles one dataset
-     """
-     def __init__(self, filepath):
-         self.filepath = filepath
-
-     def __iter__(self):
-         max_dataset_size = 10*1000*1000 #Cache small datasets in memory
-         dataset = []
-         data_format = None
-
-         while dataset is None or len(dataset) == 0:
-             with gzip.open(self.filepath, "rt") as fIn:
-                 for line in fIn:
-                     data = json.loads(line)
-                     if isinstance(data, dict):
-                         data = data['texts']
-
-                     if data_format is None:
-                         data_format = len(data)
-
-                     #Ensure that all entries are of the same 2/3 col format
-                     assert len(data) == data_format
-
-                     if dataset is not None:
-                         dataset.append(data)
-                         if len(dataset) >= max_dataset_size:
-                             dataset = None
-
-                     yield data
-
-         # Data loaded. Now stream to the queue
-         # Shuffle for each epoch
-         while True:
-             random.shuffle(dataset)
-             for data in dataset:
-                 yield data
-
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
-     parser.add_argument('--steps', type=int, default=2000)
-     parser.add_argument('--save_steps', type=int, default=10000)
-     parser.add_argument('--batch_size', type=int, default=64)
-     parser.add_argument('--max_length', type=int, default=128)
-     parser.add_argument('--nprocs', type=int, default=8)
-     parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
-     parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
-     parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
-     parser.add_argument('data_config', help="A data_config.json file")
-     parser.add_argument('output')
-     args = parser.parse_args()
-
-     # Ensure the global batch size is divisible by datasets_per_batch
-     assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
-
-     logging.info("Output: "+args.output)
-     if os.path.exists(args.output):
-         print("Output folder already exists.")
-         input("Continue?")
-
-     # Write train script to output path
-     os.makedirs(args.output, exist_ok=True)
-
-     data_config_path = os.path.join(args.output, 'data_config.json')
-     copyfile(args.data_config, data_config_path)
-
-     train_script_path = os.path.join(args.output, 'train_script.py')
-     copyfile(__file__, train_script_path)
-     with open(train_script_path, 'a') as fOut:
-         fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
-
-
-
-     #Load data config
-     with open(args.data_config) as fIn:
-         data_config = json.load(fIn)
-
-     queue = mp.Queue(maxsize=100*args.nprocs)
-
-     filepaths = []
-     dataset_indices = []
-     for idx, data in enumerate(data_config):
-         filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
-         dataset_indices.extend([idx]*data['weight'])
-
-     # Start producer
-     p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
-     p.start()
-
-     # Run training
-     print("Start processes:", args.nprocs)
-     xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
-     print("Training done")
-     print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
-     print("With 'pkill python' you can kill all remaining python processes")
-     p.kill()
-     exit()
-
-
-
- # Script was called via:
- #python train_many_data_files_v2.py --steps 1000000 --batch_size 128 --model nreimers/MiniLM-L6-H384-uncased train_data_configs/all_datasets_v4.json output/all_datasets_v4_MiniLM-L6-H384-uncased-batch128
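
Note: despite its filename, the deleted data_config.json held the Python training script shown above rather than a data config. The config format the script itself reads through its positional data_config argument can be inferred from the loader at the end of __main__: a JSON list of entries, each with a "name" (a gzipped JSON-lines dataset file inside --data_folder) and an integer "weight" that controls how often that dataset is sampled into batches. Below is a minimal sketch of such a file, written from Python; the dataset file names are hypothetical.

import json

# Hypothetical example of the config consumed via the `data_config` argument:
# each entry names a .jsonl.gz file inside --data_folder, and "weight" is the
# integer sampling weight used to build dataset_indices. Files whose names
# contain "reddit_" are routed to RedditDataset, everything else to Dataset.
example_config = [
    {"name": "stackexchange_title_body.jsonl.gz", "weight": 2},
    {"name": "reddit_title_text_2021.jsonl.gz", "weight": 1},
]

with open("data_config.json", "w") as fOut:
    json.dump(example_config, fOut, indent=2)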