Upload folder using huggingface_hub
Browse files- configuration_internvl_chat.py +2 -1
- conversation.py +43 -0
- modeling_internvl_chat.py +197 -10
configuration_internvl_chat.py
CHANGED
@@ -12,7 +12,6 @@ from transformers.utils import logging
|
|
12 |
|
13 |
from .configuration_intern_vit import InternVisionConfig
|
14 |
|
15 |
-
|
16 |
logger = logging.get_logger(__name__)
|
17 |
|
18 |
|
@@ -52,6 +51,8 @@ class InternVLChatConfig(PretrainedConfig):
|
|
52 |
self.downsample_ratio = downsample_ratio
|
53 |
self.template = template
|
54 |
|
|
|
|
|
55 |
def to_dict(self):
|
56 |
"""
|
57 |
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
|
|
12 |
|
13 |
from .configuration_intern_vit import InternVisionConfig
|
14 |
|
|
|
15 |
logger = logging.get_logger(__name__)
|
16 |
|
17 |
|
|
|
51 |
self.downsample_ratio = downsample_ratio
|
52 |
self.template = template
|
53 |
|
54 |
+
logger.info(f'vision_select_layer: {self.select_layer}')
|
55 |
+
|
56 |
def to_dict(self):
|
57 |
"""
|
58 |
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
conversation.py
CHANGED
@@ -1211,3 +1211,46 @@ register_conv_template(
|
|
1211 |
sep2='</s>',
|
1212 |
)
|
1213 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1211 |
sep2='</s>',
|
1212 |
)
|
1213 |
)
|
1214 |
+
|
1215 |
+
|
1216 |
+
if __name__ == '__main__':
|
1217 |
+
from fastchat.conversation import get_conv_template
|
1218 |
+
|
1219 |
+
print('-- Vicuna template --')
|
1220 |
+
conv = get_conv_template('vicuna_v1.1')
|
1221 |
+
conv.append_message(conv.roles[0], 'Hello!')
|
1222 |
+
conv.append_message(conv.roles[1], 'Hi!')
|
1223 |
+
conv.append_message(conv.roles[0], 'How are you?')
|
1224 |
+
conv.append_message(conv.roles[1], None)
|
1225 |
+
print(conv.get_prompt())
|
1226 |
+
|
1227 |
+
print('\n')
|
1228 |
+
|
1229 |
+
print('-- Llama-2 template --')
|
1230 |
+
conv = get_conv_template('llama-2')
|
1231 |
+
conv.set_system_message('You are a helpful, respectful and honest assistant.')
|
1232 |
+
conv.append_message(conv.roles[0], 'Hello!')
|
1233 |
+
conv.append_message(conv.roles[1], 'Hi!')
|
1234 |
+
conv.append_message(conv.roles[0], 'How are you?')
|
1235 |
+
conv.append_message(conv.roles[1], None)
|
1236 |
+
print(conv.get_prompt())
|
1237 |
+
|
1238 |
+
print('\n')
|
1239 |
+
|
1240 |
+
print('-- ChatGPT template --')
|
1241 |
+
conv = get_conv_template('chatgpt')
|
1242 |
+
conv.append_message(conv.roles[0], 'Hello!')
|
1243 |
+
conv.append_message(conv.roles[1], 'Hi!')
|
1244 |
+
conv.append_message(conv.roles[0], 'How are you?')
|
1245 |
+
conv.append_message(conv.roles[1], None)
|
1246 |
+
print(conv.to_openai_api_messages())
|
1247 |
+
|
1248 |
+
print('\n')
|
1249 |
+
|
1250 |
+
print('-- Claude template --')
|
1251 |
+
conv = get_conv_template('claude')
|
1252 |
+
conv.append_message(conv.roles[0], 'Hello!')
|
1253 |
+
conv.append_message(conv.roles[1], 'Hi!')
|
1254 |
+
conv.append_message(conv.roles[0], 'How are you?')
|
1255 |
+
conv.append_message(conv.roles[1], None)
|
1256 |
+
print(conv.get_prompt())
|
modeling_internvl_chat.py
CHANGED
@@ -3,16 +3,21 @@
|
|
3 |
# Copyright (c) 2023 OpenGVLab
|
4 |
# Licensed under The MIT License [see LICENSE for details]
|
5 |
# --------------------------------------------------------
|
|
|
6 |
from typing import Any, List, Optional, Tuple, Union
|
7 |
-
|
8 |
import torch.utils.checkpoint
|
9 |
from peft import LoraConfig, get_peft_model
|
10 |
from torch import nn
|
11 |
from torch.nn import CrossEntropyLoss
|
12 |
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
|
|
|
|
|
|
13 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
14 |
from transformers.modeling_utils import PreTrainedModel
|
15 |
from transformers.utils import ModelOutput, logging
|
|
|
16 |
|
17 |
from .configuration_internvl_chat import InternVLChatConfig
|
18 |
from .modeling_intern_vit import InternVisionModel
|
@@ -20,10 +25,183 @@ from .modeling_intern_vit import InternVisionModel
|
|
20 |
logger = logging.get_logger(__name__)
|
21 |
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
class InternVLChatModel(PreTrainedModel):
|
24 |
config_class = InternVLChatConfig
|
25 |
main_input_name = 'pixel_values'
|
26 |
-
_no_split_modules = ['
|
27 |
|
28 |
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
|
29 |
super().__init__(config)
|
@@ -33,6 +211,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
33 |
self.select_layer = config.select_layer
|
34 |
self.template = config.template
|
35 |
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
|
|
|
36 |
logger.info(f'num_image_token: {self.num_image_token}')
|
37 |
if vision_model is not None:
|
38 |
self.vision_model = vision_model
|
@@ -41,7 +220,8 @@ class InternVLChatModel(PreTrainedModel):
|
|
41 |
if language_model is not None:
|
42 |
self.language_model = language_model
|
43 |
else:
|
44 |
-
self.language_model = LlamaForCausalLM(config.llm_config)
|
|
|
45 |
vit_hidden_size = config.vision_config.hidden_size
|
46 |
llm_hidden_size = config.llm_config.hidden_size
|
47 |
|
@@ -52,7 +232,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
52 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
53 |
)
|
54 |
|
55 |
-
if config.force_image_size:
|
56 |
self.vision_model.resize_pos_embeddings(
|
57 |
old_size=config.vision_config.image_size,
|
58 |
new_size=config.force_image_size,
|
@@ -173,16 +353,22 @@ class InternVLChatModel(PreTrainedModel):
|
|
173 |
return x
|
174 |
|
175 |
def extract_feature(self, pixel_values):
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
vit_embeds = vit_embeds[:, 1:, :]
|
181 |
# if torch.distributed.get_rank() == 0:
|
182 |
# print("before pixel shuffle:", vit_embeds.shape)
|
183 |
h = w = int(vit_embeds.shape[1] ** 0.5)
|
184 |
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
185 |
-
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=
|
186 |
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
187 |
# if torch.distributed.get_rank() == 0:
|
188 |
# print("after pixel shuffle:", vit_embeds.shape)
|
@@ -194,6 +380,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
194 |
|
195 |
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
196 |
self.img_context_token_id = img_context_token_id
|
|
|
197 |
from .conversation import get_conv_template
|
198 |
|
199 |
template = get_conv_template(self.template)
|
@@ -243,7 +430,7 @@ class InternVLChatModel(PreTrainedModel):
|
|
243 |
input_ids = input_ids.reshape(B * N)
|
244 |
selected = (input_ids == self.img_context_token_id)
|
245 |
assert selected.sum() != 0
|
246 |
-
input_embeds[selected] = vit_embeds.reshape(-1, C)
|
247 |
|
248 |
input_embeds = input_embeds.reshape(B, N, C)
|
249 |
else:
|
|
|
3 |
# Copyright (c) 2023 OpenGVLab
|
4 |
# Licensed under The MIT License [see LICENSE for details]
|
5 |
# --------------------------------------------------------
|
6 |
+
import warnings
|
7 |
from typing import Any, List, Optional, Tuple, Union
|
8 |
+
import torch.distributed as dist
|
9 |
import torch.utils.checkpoint
|
10 |
from peft import LoraConfig, get_peft_model
|
11 |
from torch import nn
|
12 |
from torch.nn import CrossEntropyLoss
|
13 |
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
|
14 |
+
from transformers.generation.logits_process import LogitsProcessorList
|
15 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
16 |
+
from transformers.generation.streamers import BaseStreamer
|
17 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
18 |
from transformers.modeling_utils import PreTrainedModel
|
19 |
from transformers.utils import ModelOutput, logging
|
20 |
+
from transformers.generation.utils import GreedySearchOutput, validate_stopping_criteria, GreedySearchDecoderOnlyOutput,GreedySearchEncoderDecoderOutput
|
21 |
|
22 |
from .configuration_internvl_chat import InternVLChatConfig
|
23 |
from .modeling_intern_vit import InternVisionModel
|
|
|
25 |
logger = logging.get_logger(__name__)
|
26 |
|
27 |
|
28 |
+
# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py
|
29 |
+
# Fix bug when using device_map='auto' for distributed inference
|
30 |
+
class MLlamaForCausalLM(LlamaForCausalLM):
|
31 |
+
|
32 |
+
def greedy_search(
|
33 |
+
self,
|
34 |
+
input_ids: torch.LongTensor,
|
35 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
36 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
37 |
+
max_length: Optional[int] = None,
|
38 |
+
pad_token_id: Optional[int] = None,
|
39 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
40 |
+
output_attentions: Optional[bool] = None,
|
41 |
+
output_hidden_states: Optional[bool] = None,
|
42 |
+
output_scores: Optional[bool] = None,
|
43 |
+
return_dict_in_generate: Optional[bool] = None,
|
44 |
+
synced_gpus: bool = False,
|
45 |
+
streamer: Optional["BaseStreamer"] = None,
|
46 |
+
**model_kwargs,
|
47 |
+
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
48 |
+
# init values
|
49 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
50 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
51 |
+
if max_length is not None:
|
52 |
+
warnings.warn(
|
53 |
+
"`max_length` is deprecated in this function, use"
|
54 |
+
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
55 |
+
UserWarning,
|
56 |
+
)
|
57 |
+
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
58 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
59 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
60 |
+
if isinstance(eos_token_id, int):
|
61 |
+
eos_token_id = [eos_token_id]
|
62 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
63 |
+
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
64 |
+
output_attentions = (
|
65 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
66 |
+
)
|
67 |
+
output_hidden_states = (
|
68 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
69 |
+
)
|
70 |
+
return_dict_in_generate = (
|
71 |
+
return_dict_in_generate
|
72 |
+
if return_dict_in_generate is not None
|
73 |
+
else self.generation_config.return_dict_in_generate
|
74 |
+
)
|
75 |
+
|
76 |
+
# init attention / hidden states / scores tuples
|
77 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
78 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
79 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
80 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
81 |
+
|
82 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
83 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
84 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
85 |
+
encoder_hidden_states = (
|
86 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
87 |
+
)
|
88 |
+
|
89 |
+
# keep track of which sequences are already finished
|
90 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
91 |
+
|
92 |
+
this_peer_finished = False # used by synced_gpus only
|
93 |
+
while True:
|
94 |
+
if synced_gpus:
|
95 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
96 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
97 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
98 |
+
# send 0.0 if we finished, 1.0 otherwise
|
99 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
100 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
101 |
+
if this_peer_finished_flag.item() == 0.0:
|
102 |
+
break
|
103 |
+
|
104 |
+
# prepare model inputs
|
105 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
106 |
+
|
107 |
+
# forward pass to get next token
|
108 |
+
outputs = self(
|
109 |
+
**model_inputs,
|
110 |
+
return_dict=True,
|
111 |
+
output_attentions=output_attentions,
|
112 |
+
output_hidden_states=output_hidden_states,
|
113 |
+
)
|
114 |
+
|
115 |
+
if synced_gpus and this_peer_finished:
|
116 |
+
continue # don't waste resources running the code we don't need
|
117 |
+
|
118 |
+
next_token_logits = outputs.logits[:, -1, :]
|
119 |
+
|
120 |
+
# pre-process distribution
|
121 |
+
next_tokens_scores = logits_processor(input_ids, next_token_logits)
|
122 |
+
|
123 |
+
# Store scores, attentions and hidden_states when required
|
124 |
+
if return_dict_in_generate:
|
125 |
+
if output_scores:
|
126 |
+
scores += (next_tokens_scores,)
|
127 |
+
if output_attentions:
|
128 |
+
decoder_attentions += (
|
129 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
130 |
+
)
|
131 |
+
if self.config.is_encoder_decoder:
|
132 |
+
cross_attentions += (outputs.cross_attentions,)
|
133 |
+
|
134 |
+
if output_hidden_states:
|
135 |
+
decoder_hidden_states += (
|
136 |
+
(outputs.decoder_hidden_states,)
|
137 |
+
if self.config.is_encoder_decoder
|
138 |
+
else (outputs.hidden_states,)
|
139 |
+
)
|
140 |
+
|
141 |
+
# argmax
|
142 |
+
next_tokens = torch.argmax(next_tokens_scores, dim=-1).to(device=input_ids.device)
|
143 |
+
# finished sentences should have their next token be a padding token
|
144 |
+
if eos_token_id is not None:
|
145 |
+
if pad_token_id is None:
|
146 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
147 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
148 |
+
|
149 |
+
# update generated ids, model inputs, and length for next step
|
150 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
151 |
+
if streamer is not None:
|
152 |
+
streamer.put(next_tokens.cpu())
|
153 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
154 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
155 |
+
)
|
156 |
+
|
157 |
+
# if eos_token was found in one sentence, set sentence to finished
|
158 |
+
if eos_token_id_tensor is not None:
|
159 |
+
unfinished_sequences = unfinished_sequences.mul(
|
160 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
161 |
+
)
|
162 |
+
|
163 |
+
# stop when each sentence is finished
|
164 |
+
if unfinished_sequences.max() == 0:
|
165 |
+
this_peer_finished = True
|
166 |
+
|
167 |
+
# stop if we exceed the maximum length
|
168 |
+
if stopping_criteria(input_ids, scores):
|
169 |
+
this_peer_finished = True
|
170 |
+
|
171 |
+
if this_peer_finished and not synced_gpus:
|
172 |
+
break
|
173 |
+
|
174 |
+
if streamer is not None:
|
175 |
+
streamer.end()
|
176 |
+
|
177 |
+
if return_dict_in_generate:
|
178 |
+
if self.config.is_encoder_decoder:
|
179 |
+
return GreedySearchEncoderDecoderOutput(
|
180 |
+
sequences=input_ids,
|
181 |
+
scores=scores,
|
182 |
+
encoder_attentions=encoder_attentions,
|
183 |
+
encoder_hidden_states=encoder_hidden_states,
|
184 |
+
decoder_attentions=decoder_attentions,
|
185 |
+
cross_attentions=cross_attentions,
|
186 |
+
decoder_hidden_states=decoder_hidden_states,
|
187 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
188 |
+
)
|
189 |
+
else:
|
190 |
+
return GreedySearchDecoderOnlyOutput(
|
191 |
+
sequences=input_ids,
|
192 |
+
scores=scores,
|
193 |
+
attentions=decoder_attentions,
|
194 |
+
hidden_states=decoder_hidden_states,
|
195 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
return input_ids
|
199 |
+
|
200 |
+
|
201 |
class InternVLChatModel(PreTrainedModel):
|
202 |
config_class = InternVLChatConfig
|
203 |
main_input_name = 'pixel_values'
|
204 |
+
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
|
205 |
|
206 |
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
|
207 |
super().__init__(config)
|
|
|
211 |
self.select_layer = config.select_layer
|
212 |
self.template = config.template
|
213 |
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
|
214 |
+
self.downsample_ratio = config.downsample_ratio
|
215 |
logger.info(f'num_image_token: {self.num_image_token}')
|
216 |
if vision_model is not None:
|
217 |
self.vision_model = vision_model
|
|
|
220 |
if language_model is not None:
|
221 |
self.language_model = language_model
|
222 |
else:
|
223 |
+
# self.language_model = LlamaForCausalLM(config.llm_config)
|
224 |
+
self.language_model = MLlamaForCausalLM(config.llm_config)
|
225 |
vit_hidden_size = config.vision_config.hidden_size
|
226 |
llm_hidden_size = config.llm_config.hidden_size
|
227 |
|
|
|
232 |
nn.Linear(llm_hidden_size, llm_hidden_size)
|
233 |
)
|
234 |
|
235 |
+
if config.force_image_size != config.vision_config.image_size:
|
236 |
self.vision_model.resize_pos_embeddings(
|
237 |
old_size=config.vision_config.image_size,
|
238 |
new_size=config.force_image_size,
|
|
|
353 |
return x
|
354 |
|
355 |
def extract_feature(self, pixel_values):
|
356 |
+
if self.select_layer == -1:
|
357 |
+
vit_embeds = self.vision_model(
|
358 |
+
pixel_values=pixel_values,
|
359 |
+
output_hidden_states=False,
|
360 |
+
return_dict=True).last_hidden_state
|
361 |
+
else:
|
362 |
+
vit_embeds = self.vision_model(
|
363 |
+
pixel_values=pixel_values,
|
364 |
+
output_hidden_states=True,
|
365 |
+
return_dict=True).hidden_states[self.select_layer]
|
366 |
vit_embeds = vit_embeds[:, 1:, :]
|
367 |
# if torch.distributed.get_rank() == 0:
|
368 |
# print("before pixel shuffle:", vit_embeds.shape)
|
369 |
h = w = int(vit_embeds.shape[1] ** 0.5)
|
370 |
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
371 |
+
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
372 |
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
373 |
# if torch.distributed.get_rank() == 0:
|
374 |
# print("after pixel shuffle:", vit_embeds.shape)
|
|
|
380 |
|
381 |
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
382 |
self.img_context_token_id = img_context_token_id
|
383 |
+
|
384 |
from .conversation import get_conv_template
|
385 |
|
386 |
template = get_conv_template(self.template)
|
|
|
430 |
input_ids = input_ids.reshape(B * N)
|
431 |
selected = (input_ids == self.img_context_token_id)
|
432 |
assert selected.sum() != 0
|
433 |
+
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
|
434 |
|
435 |
input_embeds = input_embeds.reshape(B, N, C)
|
436 |
else:
|