Armen Gabrielyan committed on
Commit
4820fa1
·
1 Parent(s): d80771b

change to beam search strategy

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +6 -1
app.py CHANGED
@@ -9,7 +9,7 @@ from inference import Inference
9
  import utils
10
 
11
  encoder_model_name = 'google/vit-large-patch32-224-in21k'
12
- decoder_model_name = 'gpt2'
13
  frame_step = 300
14
 
15
  inference = Inference(
 
9
  import utils
10
 
11
  encoder_model_name = 'google/vit-large-patch32-224-in21k'
12
+ decoder_model_name = 'gpt2-large'
13
  frame_step = 300
14
 
15
  inference = Inference(
inference.py CHANGED
@@ -23,7 +23,12 @@ class Inference:
23
  self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
24
  self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
25
 
26
- generated_ids = self.encoder_decoder_model.generate(pixel_values.unsqueeze(0).to(self.device), max_length=self.max_length)
 
 
 
 
 
27
  generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
28
 
29
  return generated_text
 
23
  self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
24
  self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))
25
 
26
+ generated_ids = self.encoder_decoder_model.generate(
27
+ pixel_values.unsqueeze(0).to(self.device),
28
+ max_length=self.max_length,
29
+ num_beams=4,
30
+ no_repeat_ngram_size=2,
31
+ )
32
  generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
33
 
34
  return generated_text