jcsagar commited on
Commit
30ac77c
·
verified ·
1 Parent(s): 3da894e

Update CXR_LLAVA_HF.py

Browse files
Files changed (1) hide show
  1. CXR_LLAVA_HF.py +3 -0
CXR_LLAVA_HF.py CHANGED
@@ -615,6 +615,9 @@ class CXRLLAVAModel(PreTrainedModel):
615
  input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
616
  if self.device == 'cuda':
617
  input_ids = input_ids.cuda()
 
 
 
618
  stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
619
 
620
  image_args = {"images": images}
 
615
  input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
616
  if self.device == 'cuda':
617
  input_ids = input_ids.cuda()
618
+ print('using cuda')
619
+ else:
620
+ print('using cpu')
621
  stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)
622
 
623
  image_args = {"images": images}