Update CXR_LLAVA_HF.py
Browse files- CXR_LLAVA_HF.py +3 -0
CXR_LLAVA_HF.py
CHANGED
|
@@ -615,6 +615,9 @@ class CXRLLAVAModel(PreTrainedModel):
|
|
| 615 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
| 616 |
if self.device == 'cuda':
|
| 617 |
input_ids = input_ids.cuda()
|
|
|
|
|
|
|
|
|
|
| 618 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
| 619 |
|
| 620 |
image_args = {"images": images}
|
|
|
|
| 615 |
input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
|
| 616 |
if self.device == 'cuda':
|
| 617 |
input_ids = input_ids.cuda()
|
| 618 |
+
print('using cuda')
|
| 619 |
+
else:
|
| 620 |
+
print('using cpu')
|
| 621 |
stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
|
| 622 |
|
| 623 |
image_args = {"images": images}
|