andreslu committed on
Commit
e7fe4e6
·
1 Parent(s): ab42fe0

Update inductor.py

Browse files
Files changed (1) hide show
  1. inductor.py +0 -97
inductor.py CHANGED
@@ -312,100 +312,3 @@ class BartInductor(object):
312
  return ret
313
 
314
 
315
- class CometInductor(object):
316
- def __init__(self):
317
- self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval().float() # .half()->float
318
- self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
319
- self.task = "summarization"
320
- self.use_task_specific_params()
321
- self.decoder_start_token_id = None
322
-
323
- def drop_repeat(self, old_list):
324
- new_list = []
325
- for item in old_list:
326
- if item not in new_list:
327
- new_list.append(item)
328
-
329
- return new_list
330
-
331
- def chunks(self, lst, n):
332
- """Yield successive n-sized chunks from lst."""
333
- for i in range(0, len(lst), n):
334
- yield lst[i : i + n]
335
-
336
- def use_task_specific_params(self):
337
- """Update config with summarization specific params."""
338
- task_specific_params = self.model.config.task_specific_params
339
-
340
- if task_specific_params is not None:
341
- pars = task_specific_params.get(self.task, {})
342
- self.model.config.update(pars)
343
-
344
- def trim_batch(
345
- self, input_ids, pad_token_id, attention_mask=None,
346
- ):
347
- """Remove columns that are populated exclusively by pad_token_id"""
348
- keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
349
- if attention_mask is None:
350
- return input_ids[:, keep_column_mask]
351
- else:
352
- return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
353
-
354
- def generate(self, inputs, k, topk):
355
- outputs = []
356
- words = ['PersonX', 'PersonY']
357
- for i, _ in enumerate(re.findall("<mask>", inputs)):
358
- index = inputs.index('<mask>')
359
- inputs = inputs[:index] + words[i] + inputs[index + len('<mask>'):]
360
-
361
- for relation in RELATIONS:
362
- inputs = "{} {} [GEN]".format(inputs[:-1], relation)
363
- gen = self.generate_(inputs, num_generate=10)
364
- switch = 0
365
- for output in gen[0]:
366
- output = output.strip()
367
- if re.search("PersonX|X", output) and re.search("PersonY|Y", output):
368
- temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
369
- if temp.endswith("."):
370
- outputs.append(temp)
371
- else:
372
- outputs.append(temp + ".")
373
- switch = 1
374
- break
375
-
376
- if switch == 0:
377
- output = gen[0][0]
378
- temp = re.sub("PersonX|X|PersonY|Y", "<mask>", output.strip())
379
- if temp.endswith("."):
380
- outputs.append(temp)
381
- else:
382
- outputs.append(temp + ".")
383
-
384
- outputs = [output.replace('PersonX', '<mask>').replace('PersonY', '<mask>') for output in outputs]
385
- return outputs
386
-
387
- def generate_(
388
- self,
389
- queries,
390
- decode_method="beam",
391
- num_generate=5,
392
- ):
393
-
394
- with torch.no_grad():
395
- decs = []
396
- batch = self.tokenizer(queries, return_tensors="pt", padding="longest")
397
- input_ids, attention_mask = self.trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)
398
-
399
- summaries = self.model.generate(
400
- input_ids=input_ids.to(device),
401
- attention_mask=attention_mask.to(device),
402
- decoder_start_token_id=self.decoder_start_token_id,
403
- num_beams=num_generate,
404
- num_return_sequences=num_generate,
405
- )
406
-
407
- dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
408
- decs.append(dec)
409
-
410
- return decs
411
-
 
312
  return ret
313
 
314