visheratin commited on
Commit
304a0a4
·
1 Parent(s): cde656c

Update model files

Browse files
Files changed (1) hide show
  1. processing_llava.py +7 -2
processing_llava.py CHANGED
@@ -30,6 +30,7 @@ from transformers.tokenization_utils_base import (
30
  from transformers.utils import TensorType
31
  import torch
32
  from open_clip.transform import PreprocessCfg, image_transform_v2
 
33
 
34
 
35
  class OpenCLIPImageProcessor:
@@ -67,6 +68,7 @@ class LlavaProcessor:
67
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
68
  ] = None,
69
  images: ImageInput = None,
 
70
  padding: Union[bool, str, PaddingStrategy] = False,
71
  truncation: Union[bool, str, TruncationStrategy] = None,
72
  max_length=None,
@@ -76,8 +78,11 @@ class LlavaProcessor:
76
  pixel_values = self.image_processor(images, return_tensors=return_tensors)[
77
  "pixel_values"
78
  ]
 
 
 
79
  else:
80
- pixel_values = None
81
  text_inputs = self.tokenizer(
82
  text,
83
  return_tensors=return_tensors,
@@ -86,7 +91,7 @@ class LlavaProcessor:
86
  max_length=max_length,
87
  )
88
 
89
- return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
90
 
91
  def batch_decode(self, *args, **kwargs):
92
  return self.tokenizer.batch_decode(*args, **kwargs)
 
30
  from transformers.utils import TensorType
31
  import torch
32
  from open_clip.transform import PreprocessCfg, image_transform_v2
33
+ from modeling_llava import LlavaForConditionalGeneration
34
 
35
 
36
  class OpenCLIPImageProcessor:
 
68
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
69
  ] = None,
70
  images: ImageInput = None,
71
+ model: LlavaForConditionalGeneration = None,
72
  padding: Union[bool, str, PaddingStrategy] = False,
73
  truncation: Union[bool, str, TruncationStrategy] = None,
74
  max_length=None,
 
78
  pixel_values = self.image_processor(images, return_tensors=return_tensors)[
79
  "pixel_values"
80
  ]
81
+ pixel_values = pixel_values.to(model.device)
82
+ image_outputs = model.vision_model(pixel_values)
83
+ image_features = model.multi_modal_projector(image_outputs)
84
  else:
85
+ image_features = None
86
  text_inputs = self.tokenizer(
87
  text,
88
  return_tensors=return_tensors,
 
91
  max_length=max_length,
92
  )
93
 
94
+ return BatchFeature(data={**text_inputs, "image_features": image_features})
95
 
96
  def batch_decode(self, *args, **kwargs):
97
  return self.tokenizer.batch_decode(*args, **kwargs)