BeveledCube committed on
Commit 5e95677 · 1 Parent(s): 7774227

Attention mask shi

Files changed (1)
  1. models/gpt2.py +3 -3
models/gpt2.py CHANGED
@@ -1,6 +1,5 @@
 from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
-
-# https://www.youtube.com/watch?v=irjYqV6EebU
+import tensorflow as tf
 
 model_name = "gpt2"
 
@@ -13,7 +12,8 @@ def load():
 
 def generate(input_text):
     # Tokenize the input text
-    input_ids = tokenizer.encode(input_text, return_tensors="pt")
+    input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True, padding=True)
+    attention_mask = tf.ones_like(input_ids)
 
     # Generate output using the model
     output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
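
Note on the new lines: attention_mask is built with tf.ones_like but never passed to model.generate, and return_tensors="pt" yields PyTorch tensors even though TFGPT2LMHeadModel is the TensorFlow head class. The following is a minimal sketch, not code from this repo, of how the mask could be wired through generate with TF tensors; the standalone tokenizer/model setup and the pad-token choice here are illustrative assumptions.

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = TFGPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 ships without a pad token; reusing EOS lets padding=True work (assumption).
tokenizer.pad_token = tokenizer.eos_token

def generate(input_text):
    # Tokenize to TensorFlow tensors; the tokenizer returns the attention mask itself.
    inputs = tokenizer(input_text, return_tensors="tf", truncation=True, padding=True)

    # Pass the mask explicitly so padded positions are ignored during beam search.
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

For a single unpadded prompt, tf.ones_like(input_ids) and the tokenizer's own attention mask coincide; passing the mask only changes behavior once inputs are batched or padded.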