import math
import os
import re
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput
from transformers.utils import logging, strtobool

from .clip import CLIPVisionTower
from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

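# Placeholder ids used by the preprocessing pipeline: IMAGE_TOKEN_INDEX and
# DEFAULT_OBJECT_INDEX are negative sentinels that are spliced out of `input_ids`
# in `prepare_inputs_labels_for_multimodal` and replaced by image / box features,
# while IGNORE_INDEX masks those positions out of the loss.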
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"


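# FSDP with RAM-efficient loading is considered active only when torch.distributed
# is initialized and both Accelerate environment flags are set (this mirrors the
# helper of the same name in Hugging Face Transformers).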
def is_fsdp_enabled():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )


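# Two lightweight projector variants: IdentityMap passes features through
# unchanged, and SimpleResBlock adds a pre-norm residual MLP. Note that
# build_vision_projector below is hard-wired to the MLP path, and SimpleResBlock
# is not referenced elsewhere in this file.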
class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


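# build_vision_projector always constructs the "mlp2x_gelu" projector, i.e.
# nn.Sequential(Linear(start_hidden_size, hidden_size), GELU(), Linear(hidden_size, hidden_size));
# the `delay_load` argument is accepted for interface compatibility but unused.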
def build_vision_projector(config, start_hidden_size, delay_load=False, **kwargs):
    projector_type = "mlp2x_gelu"

    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(start_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")


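# get_token_slices splits a flat prompt into contiguous text / image / object
# spans. A hedged illustration: input_ids = tensor([1, 2, -200, 3, -300, 4])
# yields ([{"type": "text", "span": [0, 2]}, {"type": "image", "span": [2, 3]},
# {"type": "text", "span": [3, 4]}, {"type": "object", "span": [4, 5]},
# {"type": "text", "span": [5, 6]}], True).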
def get_token_slices(input_ids: torch.Tensor):
    """
    Get slices of tokens based on special markers in the input tensor.

    Args:
        input_ids (torch.Tensor): A 1-D tensor of token IDs where IMAGE_TOKEN_INDEX marks an image
            token, DEFAULT_OBJECT_INDEX marks an object token, and all other values are text tokens.

    Returns:
        Tuple[List[Dict[str, Any]], bool]: A list of dictionaries, each containing the type of the
            token slice ('text', 'image', 'object') and its span as [start, end) indices, together
            with a flag indicating whether any object token is present.
    """
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}

    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0

    special_indices = torch.cat((image_indices, object_indices))
    special_indices, _ = torch.sort(special_indices)
    special_tokens = input_ids[special_indices]

    slices = []
    start_idx = 0

    for i, idx in enumerate(special_indices):
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx.item()]})
        token_type = type_map[special_tokens[i].item()]
        slices.append({"type": token_type, "span": [idx.item(), idx.item() + 1]})
        start_idx = idx.item() + 1

    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})

    return slices, has_object


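# Stop-word handling for generation. A hedged usage sketch (the stop word below is
# a hypothetical example, not prescribed by this module):
#     stop_criteria = get_stop_criteria(tokenizer, stop_words=["<|im_end|>"])
#     model.generate(..., stopping_criteria=stop_criteria)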
class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace("\r", "").replace("\n", "")
        return cur_text[-self.length :] == self.stop_word


def get_stop_criteria(
    tokenizer,
    stop_words=[],
):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria


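# gen_sineembed_for_position returns a [batch, N, 2 * dim_of_pos_feats] embedding
# for (cx, cy) inputs and [batch, N, 4 * dim_of_pos_feats] for (cx, cy, w, h)
# inputs; with pos_embedding_dim=2880 used below, each coordinate gets a 720-dim
# sine/cosine code.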
def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine position embedding from a position tensor.

    Args:
        pos_tensor (torch.Tensor): shape [batch_size, N, 2] or [batch_size, N, 4]. The last
            dimension is [cx, cy] or [cx, cy, w, h] in normalized coordinates in the range [0, 1].
        dim_of_pos_feats (int): number of embedding dimensions used per coordinate.

    Returns:
        pos (torch.Tensor): shape [batch_size, N, 2 * dim_of_pos_feats] or
            [batch_size, N, 4 * dim_of_pos_feats], depending on the input.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)

        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(f"Unknown pos_tensor shape(-1): {pos_tensor.size(-1)}")
    return pos


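# MultiLevelROIVisualPrompt turns per-object boxes into single feature vectors:
# all feature levels are resized to the largest level, concatenated along the
# channel dimension (192 + 384 + 768 + 1536 = 2880 with the defaults), RoI-aligned,
# average-pooled, and optionally summed with a sine box embedding of the same width.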
class MultiLevelROIVisualPrompt(nn.Module):
    """Multi-level RoI visual prompt encoder.

    Args:
        output_size (Optional[int]): The output size of RoI Align. Default is None.
        channel_per_level (List[int]): Channels of each feature level. Default is [192, 384, 768, 1536].
        spatail_scale (Optional[float]): Scale factor mapping box coordinates to the feature
            resolution. Default is None.
        add_pos_embedding (bool): Whether to add a sine position embedding of the box. Default is False.
        pos_embedding_dim (int): The dimension of the position embedding. Default is 1024.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: List[int] = [192, 384, 768, 1536],
        spatail_scale: float = None,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super().__init__()
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatail_scale = spatail_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim

    def __call__(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Perform the RoI Align operator on multi-level features. All levels are resized to a
        common resolution, concatenated along the channel dimension, RoI-aligned, and average
        pooled into one feature vector per box.

        Args:
            multi_level_features (List[Tensor[N, C, H, W]]): Feature maps from different levels.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Box coordinates in (x1, y1, x2, y2)
                format from which the regions will be taken.
        Returns:
            Tensor[1, K, C]: The pooled box features, where K is the number of RoIs.
        """
        boxes[0] = boxes[0].float()
        concat_multi_level_feature = []
        max_height = max([feature.shape[2] for feature in multi_level_features])
        max_width = max([feature.shape[3] for feature in multi_level_features])

        # Resize every level to the largest spatial size before concatenation.
        for level, feature in enumerate(multi_level_features):
            if level != 0:
                concat_multi_level_feature.append(
                    F.interpolate(
                        feature.float(),
                        size=(max_height, max_width),
                        mode="bilinear",
                        align_corners=False,
                    )
                )
            else:
                concat_multi_level_feature.append(feature.float())
        concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1)

        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatail_scale,
        )

        # Average pool each RoI and reshape to [1, K, C].
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )
        if self.add_pos_embedding:
            # Normalize boxes to [0, 1], convert (x1, y1, x2, y2) -> (cx, cy, w, h),
            # and add a sine position embedding of the same width as the box feature.
            boxes = boxes[0]
            boxes = boxes.to(out_box_feat.dtype)
            original_img_width = max_width / self.spatail_scale
            original_img_height = max_height / self.spatail_scale
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / original_img_width
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / original_img_height

            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            pos_embed = gen_sineembed_for_position(
                boxes.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed

        return out_box_feat


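# RexSeekQwen couples a Qwen2 language model with two vision encoders: a CLIP
# tower producing low-resolution features and a ConvNeXt encoder producing
# high-resolution features plus a multi-level pyramid used for box encoding.
# Concatenated image features (2560-d) go through mm_projector, and per-box RoI
# features (2880-d) go through mm_object_projector, both mapping into the LM
# hidden size.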
class RexSeekQwenConfig(Qwen2Config):
    model_type = "rexseek_qwen"


class RexSeekQwenForCausalLM(Qwen2ForCausalLM):

    config_class = RexSeekQwenConfig

    def __init__(self, config):
        super().__init__(config)

        vision_tower = getattr(
            config,
            "mm_vision_tower",
            getattr(config, "vision_tower", None),
        )
        self.vision_tower = CLIPVisionTower(
            vision_tower,
            args=config,
        )

        self.vision_tower_aux = ConvNextVisionEncoder()

        self.mm_projector = build_vision_projector(
            config, start_hidden_size=2560
        )

        self.mm_object_projector = build_vision_projector(
            config, start_hidden_size=2880
        )

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.box_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],
            spatail_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )
        self.post_init()
        logger.info("model initialized")

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_vision_tower_aux(self):
        vision_tower_aux = getattr(self, "vision_tower_aux", None)
        if type(vision_tower_aux) is list:
            vision_tower_aux = vision_tower_aux[0]
        return vision_tower_aux

    def get_model(self):
        return self.model

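    # encode_images fuses the two towers: CLIP tokens (low resolution) are
    # concatenated channel-wise with the flattened ConvNeXt `last_feat`
    # (high resolution); the combined width must match the 2560 expected by
    # mm_projector. The intermediate ConvNeXt pyramid is returned separately
    # for box encoding.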
    def encode_images(self, images, images_aux):
        low_res_feat = self.get_vision_tower()(images)
        aux_output = self.get_vision_tower_aux()(images_aux)
        visual_outputs_aux = aux_output["image_features"]
        high_res_feat = aux_output["last_feat"]

        b, c, h, w = high_res_feat.shape
        _, _, d = low_res_feat.shape
        high_res_feat = high_res_feat.view(b, c, h * w).transpose(1, 2)
        image_features = torch.cat((low_res_feat, high_res_feat), dim=-1)
        image_features = self.mm_projector(image_features)
        return image_features, visual_outputs_aux

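    # encode_objects runs the box encoder per image: each image's boxes are
    # RoI-aligned against its ConvNeXt pyramid (2880 channels after level
    # concatenation), then projected by mm_object_projector. The result is a
    # list with one tensor of shape (num_boxes, hidden_size) per image, or
    # None for images without boxes.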
    def encode_objects(
        self, bboxes, visual_outputs_aux, dtype, num_gt_boxes_per_image=None
    ):
        """Encode object features from bounding boxes.

        Args:
            bboxes (List[torch.Tensor]): per-image bounding boxes, each of shape (N, 4).
            visual_outputs_aux (List[torch.Tensor]): multi-level ConvNeXt feature maps.
            dtype (torch.dtype): dtype to cast the box features to before projection.
            num_gt_boxes_per_image (Optional[List[int]]): number of valid boxes per image.

        Returns:
            List[Optional[torch.Tensor]]: per-image object features of shape (N, hidden_size),
                or None for images without boxes.
        """
        bbox_visual_outputs = []
        for batch_idx, boxes in enumerate(bboxes):
            num_box = (
                num_gt_boxes_per_image[batch_idx]
                if num_gt_boxes_per_image is not None
                else len(boxes)
            )
            boxes = boxes[:num_box]
            if len(boxes) == 0:
                bbox_visual_outputs.append(None)
                continue
            multi_level_aux_features = [
                visual_output_aux[batch_idx].unsqueeze(0)
                for visual_output_aux in visual_outputs_aux
            ]
            out_vp_feat = self.box_encoder(
                multi_level_aux_features,
                [boxes],
            ).squeeze(0)
            out_vp_feat = out_vp_feat.to(dtype)
            out_vp_feat = self.mm_object_projector(out_vp_feat)
            bbox_visual_outputs.append(out_vp_feat)

        return bbox_visual_outputs

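    # prepare_inputs_labels_for_multimodal splices visual features into the
    # token stream: every IMAGE_TOKEN_INDEX placeholder is replaced by the
    # image feature sequence, every DEFAULT_OBJECT_INDEX placeholder by one
    # box feature, and the corresponding label positions are set to
    # IGNORE_INDEX. The resulting embedding sequences are then right-padded
    # to the longest sequence in the batch.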
    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        pixel_values=None,
        pixel_values_aux=None,
        gt_boxes=None,
        num_gt_boxes_per_image=None,
    ):
        if pixel_values is None:
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )
        pixel_values, visual_outputs_aux = self.encode_images(
            pixel_values, pixel_values_aux
        )
        if gt_boxes is not None:
            bbox_feats = self.encode_objects(
                gt_boxes, visual_outputs_aux, pixel_values.dtype, num_gt_boxes_per_image
            )
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        cur_object_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = pixel_values[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0]], dim=0
                )
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                cur_object_idx += 1
                continue

            cur_labels = labels[batch_idx]
            token_slices, has_object = get_token_slices(cur_input_ids)
            result_input_embeddings = []
            result_output_labels = []
            cur_gt_box_idx = 0
            cur_object_features = None
            for token_slice in token_slices:
                slice_type = token_slice["type"]
                slice_span = token_slice["span"]
                if slice_type == "text":
                    cur_input_ids_noim = cur_input_ids[slice_span[0] : slice_span[1]]
                    cur_labels_noim = cur_labels[slice_span[0] : slice_span[1]]
                    cur_input_embeds = self.get_model().embed_tokens(cur_input_ids_noim)
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(cur_labels_noim)
                elif slice_type == "image":
                    cur_input_embeds = pixel_values[cur_image_idx]
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(
                        torch.full(
                            (cur_input_embeds.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                    cur_image_idx += 1
                elif slice_type == "object":
                    try:
                        result_input_embeddings.append(
                            bbox_feats[cur_object_idx][cur_gt_box_idx].unsqueeze(0)
                        )
                    except (IndexError, TypeError):
                        raise ValueError(
                            f"current bbox_feats.shape: {bbox_feats[cur_object_idx].shape}, "
                            f"cur_gt_box_idx: {cur_gt_box_idx}"
                        )
                    cur_gt_box_idx += 1
                    result_output_labels.append(
                        torch.full(
                            (1,),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
            cur_object_idx += 1
            result_input_embeddings = torch.cat(result_input_embeddings)
            result_output_labels = torch.cat(result_output_labels)
            assert len(result_output_labels) == len(result_input_embeddings)
            new_input_embeds.append(result_input_embeddings)
            new_labels.append(result_output_labels)

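        # Truncate to tokenizer_model_max_length (if configured), then right-pad
        # embeddings, labels, attention_mask, and position_ids to the longest
        # sequence in the batch.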
        tokenizer_model_max_length = getattr(
            self.config, "tokenizer_model_max_length", None
        )
        if tokenizer_model_max_length is not None:
            new_input_embeds = [
                x[:tokenizer_model_max_length] for x in new_input_embeds
            ]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros(
            (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
        )

        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels)
        ):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )

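    # generate: when pixel_values are supplied, the multimodal inputs are
    # pre-assembled into inputs_embeds (token ids are consumed here and not
    # forwarded); otherwise the prompt is embedded directly. Generation is
    # then delegated to Qwen2ForCausalLM.generate with inputs_embeds.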
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor],
        pixel_values: Optional[torch.Tensor],
        pixel_values_aux: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        if inputs_embeds is None:
            position_ids = kwargs.pop("position_ids", position_ids)
            attention_mask = kwargs.pop("attention_mask", attention_mask)
            gt_boxes = kwargs.pop("gt_boxes", None)
            num_gt_boxes_per_image = kwargs.pop("num_gt_boxes_per_image", None)

            if pixel_values is not None:
                (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
                    self.prepare_inputs_labels_for_multimodal(
                        inputs,
                        position_ids,
                        attention_mask,
                        past_key_values=None,
                        labels=None,
                        pixel_values=pixel_values,
                        pixel_values_aux=pixel_values_aux,
                        gt_boxes=gt_boxes,
                        num_gt_boxes_per_image=num_gt_boxes_per_image,
                    )
                )
            else:
                inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

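# Register the config/model pair so AutoConfig / AutoModelForCausalLM can resolve
# the "rexseek_qwen" model_type from checkpoints.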
AutoConfig.register("rexseek_qwen", RexSeekQwenConfig)
AutoModelForCausalLM.register(RexSeekQwenConfig, RexSeekQwenForCausalLM)