Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +3 -0
CXR_LLAVA_HF.py
CHANGED
@@ -615,6 +615,9 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
615 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
616 |
if self.device == 'cuda':
|
617 |
input_ids = input_ids.cuda()
|
|
|
|
|
|
|
618 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
619 |
|
620 |
image_args = {"images": images}
|
|
|
615 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
616 |
if self.device == 'cuda':
|
617 |
input_ids = input_ids.cuda()
|
618 |
+
print('using cuda')
|
619 |
+
else:
|
620 |
+
print('using cpu')
|
621 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
622 |
|
623 |
image_args = {"images": images}
|