Spaces: Running on Zero
Update modeling_llava_qwen2.py
Browse files: modeling_llava_qwen2.py (+5 -1)
modeling_llava_qwen2.py
CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
|
|
12 |
import torch.utils.checkpoint
|
13 |
from torch import nn
|
14 |
import torch
|
|
|
15 |
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
16 |
from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
|
17 |
from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
|
@@ -534,6 +535,7 @@ class SigLipVisionTower(nn.Module):
|
|
534 |
self.is_loaded = True
|
535 |
|
536 |
@torch.no_grad()
|
|
|
537 |
def forward(self, images):
|
538 |
if type(images) is list:
|
539 |
image_features = []
|
@@ -659,11 +661,13 @@ class LlavaMetaForCausalLM(ABC):
|
|
659 |
def get_vision_tower(self):
|
660 |
return self.get_model().get_vision_tower()
|
661 |
|
|
|
662 |
def encode_images(self, images):
|
663 |
image_features = self.get_model().get_vision_tower()(images)
|
664 |
image_features = self.get_model().mm_projector(image_features)
|
665 |
return image_features
|
666 |
-
|
|
|
667 |
def prepare_inputs_labels_for_multimodal(
|
668 |
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
|
669 |
):
|
|
|
12 |
import torch.utils.checkpoint
|
13 |
from torch import nn
|
14 |
import torch
|
15 |
+
import spaces
|
16 |
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
17 |
from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
|
18 |
from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
|
|
|
535 |
self.is_loaded = True
|
536 |
|
537 |
@torch.no_grad()
|
538 |
+
@spaces.GPU
|
539 |
def forward(self, images):
|
540 |
if type(images) is list:
|
541 |
image_features = []
|
|
|
661 |
def get_vision_tower(self):
|
662 |
return self.get_model().get_vision_tower()
|
663 |
|
664 |
+
@spaces.GPU
|
665 |
def encode_images(self, images):
|
666 |
image_features = self.get_model().get_vision_tower()(images)
|
667 |
image_features = self.get_model().mm_projector(image_features)
|
668 |
return image_features
|
669 |
+
|
670 |
+
@spaces.GPU
|
671 |
def prepare_inputs_labels_for_multimodal(
|
672 |
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
|
673 |
):
|