Spaces: Runtime error
utils.py — CHANGED
@@ -71,10 +71,11 @@ def aspect_ratio_preserving_resize_and_crop(image, target_width, target_height):
|
|
71 |
|
72 |
|
73 |
class Image2Text:
|
74 |
-
def __init__(self, model_path, hf_token, device):
|
75 |
self.device = device
|
76 |
self.hf_token = hf_token
|
77 |
self.model_path = model_path
|
|
|
78 |
self.model, self.processor = self.load_model(self.model_path)
|
79 |
self.decoder_input_ids = torch.tensor([[self.model.config.decoder_start_token_id]]).to(self.device)
|
80 |
|
@@ -97,7 +98,7 @@ class Image2Text:
|
|
97 |
outputs = self.model.generate(
|
98 |
pixel_values,
|
99 |
decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
|
100 |
-
max_length=1024,
|
101 |
early_stopping=True,
|
102 |
pad_token_id=self.processor.tokenizer.pad_token_id,
|
103 |
eos_token_id=self.processor.tokenizer.eos_token_id,
|
|
|
71 |
|
72 |
|
73 |
class Image2Text:
|
74 |
+
def __init__(self, model_path, hf_token, device, max_length=1024):
    """Set up the image-to-text wrapper.

    Parameters:
        model_path: model identifier/path handed to ``load_model``.
        hf_token: Hugging Face auth token, stored for later use.
        device: torch device the decoder prompt tensor is moved to.
        max_length: generation length cap used later by ``generate``
            (default 1024).
    """
    self.device = device
    self.hf_token = hf_token
    self.model_path = model_path
    self.max_length = max_length
    self.model, self.processor = self.load_model(self.model_path)
    # One decoder-start token; generation repeats it per batch element.
    start_id = self.model.config.decoder_start_token_id
    self.decoder_input_ids = torch.tensor([[start_id]]).to(self.device)
|
81 |
|
|
|
98 |
outputs = self.model.generate(
|
99 |
pixel_values,
|
100 |
decoder_input_ids=self.decoder_input_ids.repeat(pixel_values.shape[0], 1),
|
101 |
+
max_length=self.max_length,
|
102 |
early_stopping=True,
|
103 |
pad_token_id=self.processor.tokenizer.pad_token_id,
|
104 |
eos_token_id=self.processor.tokenizer.eos_token_id,
|