habdine committed
Commit bc973d1
1 Parent(s): 09fb203

Update modeling_prot2text.py

Files changed (1):
  modeling_prot2text.py +46 -18
modeling_prot2text.py CHANGED
@@ -240,7 +240,15 @@ class Prot2TextModel(PreTrainedModel):
          x: Optional[torch.FloatTensor] = None,
          edge_type: Optional[torch.LongTensor] = None,
          tokenizer=None,
-         device='cpu'
+         device='cpu',
+         streamer=None,
+         max_new_tokens=None,
+         do_sample=None,
+         top_p=None,
+         top_k=None,
+         temperature=None,
+         num_beams=1,
+         repetition_penalty=None
          ):
 
         if self.config.esm and not self.config.rgcn and protein_sequence==None:
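With the widened signature, callers can pass the usual transformers generation knobs straight through. A minimal sketch of a plain (non-streaming) call, assuming the edited method is the model's generate_protein_description (its name sits above this hunk and is not shown) and that model, tokenizer, and the sequence are placeholders; which inputs are actually required depends on the checkpoint's configuration:

    # Non-streaming call: streamer defaults to None, so the second hunk's
    # original greedy, single-beam path runs and returns the cleaned string.
    description = model.generate_protein_description(
        protein_sequence="MSKGEELFTG",  # placeholder sequence
        tokenizer=tokenizer,
        device="cpu",
    )

Note that, as the second hunk below shows, the new sampling parameters (max_new_tokens, do_sample, top_p, top_k, temperature, repetition_penalty) are only forwarded to decoder.generate() when a streamer is supplied.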
@@ -326,25 +334,45 @@
             encoder_state['attentions'] = inputs['attention_mask']
             for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
                 inputs.pop(key)
-            tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
-                                            encoder_outputs=encoder_state,
-                                            use_cache=True,
-                                            output_attentions=False,
-                                            output_scores=False,
-                                            return_dict_in_generate=True,
-                                            encoder_attention_mask=inputs['attention_mask'],
-                                            length_penalty=1.0,
-                                            no_repeat_ngram_size=None,
-                                            early_stopping=False,
-                                            num_beams=1)
-
-            generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
-
-            os.remove(structure_filename)
-            os.remove(graph_filename)
-            os.remove(process_filename)
-
-            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+            if streamer is None:
+                tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
+                                                encoder_outputs=encoder_state,
+                                                use_cache=True,
+                                                output_attentions=False,
+                                                output_scores=False,
+                                                return_dict_in_generate=True,
+                                                encoder_attention_mask=inputs['attention_mask'],
+                                                length_penalty=1.0,
+                                                no_repeat_ngram_size=None,
+                                                early_stopping=False,
+                                                num_beams=1)
+
+                generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
+
+                os.remove(structure_filename)
+                os.remove(graph_filename)
+                os.remove(process_filename)
+
+                return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+            else:
+                os.remove(structure_filename)
+                os.remove(graph_filename)
+                os.remove(process_filename)
+                return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
+                                             encoder_outputs=encoder_state,
+                                             use_cache=True,
+                                             encoder_attention_mask=inputs['attention_mask'],
+                                             length_penalty=1.0,
+                                             no_repeat_ngram_size=None,
+                                             early_stopping=False,
+                                             streamer=streamer,
+                                             max_new_tokens=max_new_tokens,
+                                             do_sample=do_sample,
+                                             top_p=top_p,
+                                             top_k=top_k,
+                                             temperature=temperature,
+                                             num_beams=1,
+                                             repetition_penalty=repetition_penalty)
 
         else:
             seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")
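For the streaming path, a hedged sketch of how a caller might drain the new streamer hook, using transformers.TextIteratorStreamer (which yields decoded text chunks as generate() emits tokens); the method name, model, tokenizer, and sequence are the same illustrative assumptions as above:

    from threading import Thread

    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks while it feeds tokens into the streamer, so run it in a
    # background thread and consume the text chunks on the main thread.
    thread = Thread(
        target=model.generate_protein_description,
        kwargs=dict(
            protein_sequence="MSKGEELFTG",  # placeholder sequence
            tokenizer=tokenizer,
            device="cpu",
            streamer=streamer,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
        ),
    )
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()

The streaming branch hard-codes num_beams=1 in its generate() call, consistent with the transformers constraint that streamers only support non-beam decoding.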