sooooner commited on
Commit
ce2c619
Β·
1 Parent(s): df36716
Files changed (1) hide show
  1. utils.py +3 -2
utils.py CHANGED
@@ -71,10 +71,11 @@ def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
71
 
72
 
73
  class Image2Text:
74
- def __init__(self, model_path, hf_token, device):
75
  self.device = device
76
  self.hf_token = hf_token
77
  self.model_path = model_path
 
78
  self.model, self.processor = self.load_model(self.model_path)
79
  self.decoder_input_ids = torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)
80
 
@@ -97,7 +98,7 @@ class Image2Text:
97
  outputs = self.model.generate(
98
  pixel_values,
99
  decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
100
- max_length=1024,
101
  early_stopping=True,
102
  pad_token_id=self.processor.tokenizer.pad_token_id,
103
  eos_token_id=self.processor.tokenizer.eos_token_id,
 
71
 
72
 
73
  class Image2Text:
74
+ def __init__(self, model_path, hf_token, device, max_length=1024):
75
  self.device = device
76
  self.hf_token = hf_token
77
  self.model_path = model_path
78
+ self.max_length = max_length
79
  self.model, self.processor = self.load_model(self.model_path)
80
  self.decoder_input_ids = torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)
81
 
 
98
  outputs = self.model.generate(
99
  pixel_values,
100
  decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
101
+ max_length=self.max_length,
102
  early_stopping=True,
103
  pad_token_id=self.processor.tokenizer.pad_token_id,
104
  eos_token_id=self.processor.tokenizer.eos_token_id,