import torch
from typing import Optional


class DummyModel:
    spatial_merge_size = 2
    vision_start_token_id = 151652
    vision_end_token_id = 151653
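    # Note: the two marker ids above match Qwen2-VL's <|vision_start|> /
    # <|vision_end|> special tokens, and spatial_merge_size is its 2x2
    # patch-merge factor; treat them as placeholder config for this sketch.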

    def get_position_ids(
        self,
        input_ids: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
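        # Text-only prompt: every token gets the same plain 1D position on
        # all three rotary axes (temporal, height, width).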
        if image_grid_thw is None:
            return (
                torch.arange(input_ids.shape[0], device=input_ids.device)
                .unsqueeze(1)
                .repeat(1, 3)
            )

        spatial_merge_size = self.spatial_merge_size
        vision_start_token_id = self.vision_start_token_id
        vision_end_token_id = self.vision_end_token_id
        device = input_ids.device
        dtype = input_ids.dtype
        input_ids_len = input_ids.shape[0]
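
        # Locate every vision segment and compute the running offset at which
        # each interleaved text/vision segment starts. Marker tokens are folded
        # into the neighboring text spans, hence the +1 below.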
        vision_starts = torch.where(input_ids == vision_start_token_id)[0]
        vision_ends = torch.where(input_ids == vision_end_token_id)[0]
        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
        prev_vision_end = torch.cat(
            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
        )
        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
        vision_widths_max = torch.cat(
            [
                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
                image_grid_thw[:-1, 2] // spatial_merge_size,
            ]
        )
        vision_segment_lengths = vision_widths_max + text_lengths_between_vision
        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
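
        # Build 3D (t, h, w) position ids for every vision segment from its
        # grid; height and width are reduced by the spatial merge factor.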
        llm_pos_ids_list = []
        for i, _ in enumerate(vision_segments):
            t, h, w = (
                image_grid_thw[i][0],
                image_grid_thw[i][1] // spatial_merge_size,
                image_grid_thw[i][2] // spatial_merge_size,
            )
            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
            w_indices = torch.arange(w, device=device).repeat(t * h)
            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)

            # Shift by the cumulative length of everything before this segment.
            im = image_position_ids + vision_segment_lengths[i]
            llm_pos_ids_list.append(im)
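
        # Increasing 1D positions for every interleaved text segment,
        # replicated across the three axes and offset to the segment start.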
        text_ranges = [
            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
            + text_segment_lengths[i]
            for i, seq_len in enumerate(text_lengths_between_vision)
        ]
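
        # Interleave text and vision position blocks in input order, then
        # append plain increasing positions for any trailing text (including
        # the final end marker) after the last image.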
        full_llm_pos_ids_list = [
            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
        ]
        max_s = full_llm_pos_ids_list[-1].max() + 1
        final_text_len = input_ids_len - vision_ends[-1]
        if final_text_len > 0:
            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
            full_llm_pos_ids_list.append(m + max_s)

        position_ids = (
            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
        )
        return position_ids
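

# Minimal smoke test, a sketch that is not part of the original module: one
# single-frame image whose 4x4 patch grid merges into 2x2 = 4 image tokens
# (h = w = 4 // spatial_merge_size = 2). All non-marker token ids are
# arbitrary placeholders.
if __name__ == "__main__":
    model = DummyModel()

    # 3 text tokens, <|vision_start|>, 4 image tokens, <|vision_end|>, 2 text tokens.
    input_ids = torch.tensor([1, 2, 3, 151652, 0, 0, 0, 0, 151653, 4, 5])
    image_grid_thw = torch.tensor([[1, 4, 4]])

    position_ids = model.get_position_ids(input_ids, image_grid_thw)
    # One (t, h, w) triple per input token.
    assert position_ids.shape == (input_ids.shape[0], 3)
    print(position_ids)

    # Text-only fallback: identical positions on all three axes.
    print(model.get_position_ids(torch.tensor([1, 2, 3])))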