bingwork commited on
Commit
f6d0975
·
verified ·
1 Parent(s): 56243ad

Upload mmalaya_arch.py

Browse files
Files changed (1) hide show
  1. mmalaya_arch.py +5 -5
mmalaya_arch.py CHANGED
@@ -299,15 +299,15 @@ class MMAlayaMetaForCausalLM(ABC):
299
  for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
300
  input_ids.extend(x[offset:])
301
 
302
- if return_tensors is not None:
303
- if return_tensors == 'pt':
304
- return torch.tensor(input_ids, dtype=torch.long)
305
- raise ValueError(f'Unsupported tensor type: {return_tensors}')
306
-
307
  # 加载generate stop条件
308
  stopping_criteria = KeywordsStoppingCriteria([conv.sep2], tokenizer, input_ids)
309
  # 加载图像
310
  image_processor = model.get_vision_tower().image_processor
311
  image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
312
 
 
 
 
 
 
313
  return input_ids, image_tensor, stopping_criteria
 
299
  for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
300
  input_ids.extend(x[offset:])
301
 
 
 
 
 
 
302
  # 加载generate stop条件
303
  stopping_criteria = KeywordsStoppingCriteria([conv.sep2], tokenizer, input_ids)
304
  # 加载图像
305
  image_processor = model.get_vision_tower().image_processor
306
  image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
307
 
308
+ if return_tensors is not None:
309
+ if return_tensors == 'pt':
310
+ return torch.tensor(input_ids, dtype=torch.long), image_tensor, stopping_criteria
311
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
312
+
313
  return input_ids, image_tensor, stopping_criteria