Update vitGPT.py
Browse files
vitGPT.py
CHANGED
@@ -278,7 +278,7 @@ class VisionGPT2Model(nn.Module):
|
|
278 |
lm_logits = self.lm_head(input_ids[:,[-1],:])
|
279 |
return lm_logits
|
280 |
|
281 |
-
def generate(self,image,sequence,tokenizer,max_tokens
|
282 |
for _ in range(max_tokens):
|
283 |
out = self(image,sequence)
|
284 |
out = out[:,-1,:] / temperature
|
@@ -337,8 +337,8 @@ def generate_caption(image,max_tokens,temperature,deterministic=True):
|
|
337 |
image,
|
338 |
sequence,
|
339 |
tokenizer,
|
340 |
-
max_tokens
|
341 |
-
temperature
|
342 |
deterministic=deterministic,
|
343 |
|
344 |
)
|
|
|
278 |
lm_logits = self.lm_head(input_ids[:,[-1],:])
|
279 |
return lm_logits
|
280 |
|
281 |
+
def generate(self,image,sequence,tokenizer,max_tokens,temperature,deterministic=False):
|
282 |
for _ in range(max_tokens):
|
283 |
out = self(image,sequence)
|
284 |
out = out[:,-1,:] / temperature
|
|
|
337 |
image,
|
338 |
sequence,
|
339 |
tokenizer,
|
340 |
+
max_tokens,
|
341 |
+
temperature,
|
342 |
deterministic=deterministic,
|
343 |
|
344 |
)
|