Update modeling_prot2text.py

modeling_prot2text.py  CHANGED  (+46 −18)
@@ -240,7 +240,15 @@ class Prot2TextModel(PreTrainedModel):
         x: Optional[torch.FloatTensor] = None,
         edge_type: Optional[torch.LongTensor] = None,
         tokenizer=None,
-        device='cpu'
+        device='cpu',
+        streamer=None,
+        max_new_tokens=None,
+        do_sample=None,
+        top_p=None,
+        top_k=None,
+        temperature=None,
+        num_beams=1,
+        repetition_penalty=None
         ):
 
         if self.config.esm and not self.config.rgcn and protein_sequence==None:
@@ -326,25 +334,45 @@ class Prot2TextModel(PreTrainedModel):
             encoder_state['attentions'] = inputs['attention_mask']
             for key in ['edge_index', 'edge_type', 'x', 'encoder_input_ids']:
                 inputs.pop(key)
-            tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
-                                            encoder_outputs=encoder_state,
-                                            use_cache=True,
-                                            output_attentions=False,
-                                            output_scores=False,
-                                            return_dict_in_generate=True,
-                                            encoder_attention_mask=inputs['attention_mask'],
-                                            length_penalty=1.0,
-                                            no_repeat_ngram_size=None,
-                                            early_stopping=False,
-                                            num_beams=1)
+            if streamer is None:
+                tok_ids = self.decoder.generate(input_ids=inputs['decoder_input_ids'],
+                                                encoder_outputs=encoder_state,
+                                                use_cache=True,
+                                                output_attentions=False,
+                                                output_scores=False,
+                                                return_dict_in_generate=True,
+                                                encoder_attention_mask=inputs['attention_mask'],
+                                                length_penalty=1.0,
+                                                no_repeat_ngram_size=None,
+                                                early_stopping=False,
+                                                num_beams=1)
 
-            generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
+                generated = tokenizer.batch_decode(tok_ids.get('sequences'), skip_special_tokens=True)
 
-            os.remove(structure_filename)
-            os.remove(graph_filename)
-            os.remove(process_filename)
-
-            return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+                os.remove(structure_filename)
+                os.remove(graph_filename)
+                os.remove(process_filename)
+
+                return generated[0].replace('<|stop_token|>', '').replace('<|graph_token|>', '')
+            else:
+                os.remove(structure_filename)
+                os.remove(graph_filename)
+                os.remove(process_filename)
+                return self.decoder.generate(input_ids=inputs['decoder_input_ids'],
+                                             encoder_outputs=encoder_state,
+                                             use_cache=True,
+                                             encoder_attention_mask=inputs['attention_mask'],
+                                             length_penalty=1.0,
+                                             no_repeat_ngram_size=None,
+                                             early_stopping=False,
+                                             streamer=streamer,
+                                             max_new_tokens=max_new_tokens,
+                                             do_sample=do_sample,
+                                             top_p=top_p,
+                                             top_k=top_k,
+                                             temperature=temperature,
+                                             num_beams=1,
+                                             repetition_penalty=repetition_penalty)
 
         else:
             seq = esmtokenizer([protein_sequence], add_special_tokens=True, truncation=True, max_length=1021, padding='max_length', return_tensors="pt")