# coding=utf-8
# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for KOSMOS-2."""
import copy
import math
import re
from typing import List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, is_batched
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import AddedToken
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...utils import TensorType
BboxInput = Union[
    List[Tuple[int, int]],
    List[Tuple[float, float, float, float]],
    List[List[Tuple[int, int]]],
    List[List[Tuple[float, float, float, float]]],
]

class Kosmos2Processor(ProcessorMixin):
    r"""
    Constructs a KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single
    processor.

    [`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and some functionalities of
    [`XLMRobertaTokenizerFast`]. See the docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`]
    for more information.

    Args:
        image_processor (`CLIPImageProcessor`):
            An instance of [`CLIPImageProcessor`]. The image processor is a required input.
        tokenizer (`XLMRobertaTokenizerFast`):
            An instance of [`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
        num_patch_index_tokens (`int`, *optional*, defaults to 1024):
            The number of tokens that represent patch indices.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["num_patch_index_tokens"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")

    def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, **kwargs):
        tokenizer.return_token_type_ids = False

        self.eod_token = "</doc>"
        self.boi_token = "<image>"
        self.eoi_token = "</image>"
        self.eoc_token = "</chunk>"
        self.eol_token = "</line>"
        self.bop_token = "<phrase>"
        self.eop_token = "</phrase>"
        self.boo_token = "<object>"
        self.eoo_token = "</object>"
        self.dom_token = "</delimiter_of_multi_objects/>"
        self.grd_token = "<grounding>"

        self.tag_tokens = [
            self.eod_token,
            self.boi_token,
            self.eoi_token,
            self.eoc_token,
            self.eol_token,
            self.bop_token,
            self.eop_token,
            self.boo_token,
            self.eoo_token,
            self.dom_token,
            self.grd_token,
        ]

        self.num_patch_index_tokens = num_patch_index_tokens
        patch_index_tokens = [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.num_patch_index_tokens)]

        tokens_to_add = []
        for token in self.tag_tokens + patch_index_tokens:
            tokens_to_add.append(AddedToken(token, lstrip=True, rstrip=False, normalized=False))
        tokenizer.add_tokens(tokens_to_add)

        super().__init__(image_processor, tokenizer)
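    # Note on `__init__` above: it registers 11 tag tokens plus `num_patch_index_tokens` patch
    # index tokens (`<patch_index_0000>` ... `<patch_index_1023>`) on the tokenizer, i.e. 1035
    # added tokens in total with the default of 1024.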

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, List[TextInput]] = None,
        bboxes: BboxInput = None,
        num_image_tokens: Optional[int] = 64,
        first_image_token_id: Optional[int] = None,
        add_special_tokens: bool = True,
        add_eos_token: bool = False,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        This method uses [`CLIPImageProcessor.__call__`] to prepare image(s) for the model, and
        [`XLMRobertaTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information. The rest of this docstring
        describes the arguments specific to `Kosmos2Processor`.

        Args:
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
                The bounding boxes associated with `texts`.
            num_image_tokens (`int`, *optional*, defaults to 64):
                The number of (consecutive) places that are used to mark the placeholders to store image information.
                This should be the same as `latent_query_num` in the instance of `Kosmos2Config` you are using.
            first_image_token_id (`int`, *optional*):
                The token id that will be used for the first place of the subsequence that is reserved to store image
                information. If unset, will default to `self.tokenizer.unk_token_id + 1`.
            add_eos_token (`bool`, defaults to `False`):
                Whether or not to include the `EOS` token id in the encoding when `add_special_tokens=True`.
        """
        if images is None and text is None:
            raise ValueError("You have to specify either images or text.")

        encoding = BatchFeature()

        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            encoding.update(image_encoding)

        if text is not None:
            text = self.preprocess_examples(text, images, bboxes, num_image_tokens=num_image_tokens)

            if add_special_tokens and not add_eos_token:
                if isinstance(text, str):
                    text = f"{self.tokenizer.bos_token}{text}"
                elif isinstance(text, list):
                    text = [f"{self.tokenizer.bos_token}{s}" for s in text]

            text_encoding = self.tokenizer(
                text=text,
                add_special_tokens=(add_special_tokens and add_eos_token),
                # padding is deferred to the manual padding step below when images are also passed
                padding=padding and images is None,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                verbose=verbose,
                return_tensors=return_tensors if images is None else None,
                **kwargs,
            )
            encoding.update(text_encoding)
        if text is not None and images is not None:
            # Use the id of the first token after <unk>
            if first_image_token_id is None:
                first_image_token_id = self.tokenizer.unk_token_id + 1

            # To see if we need one more `0` (for `<s>`) at the beginning of `image_embeds_position_mask`.
            with_bos = add_special_tokens

            # The first (actual) `<image>` token is always at the 1st or 2nd place (after `<s>` if any). Here we look
            # for the second `<image>` token (which indicates the first image token).
            start_index = int(with_bos) + 1

            # Add `image_embeds_position_mask`: the leading and trailing `0` are for the `boi` and `eoi` tokens. The
            # `1`s indicate the places of image tokens.
            image_token_ids = list(range(first_image_token_id, first_image_token_id + num_image_tokens))
            base_image_embeds_position_mask = [0] + [1] * num_image_tokens + [0]

            # loop over `encoding["input_ids"]`
            input_ids = []
            image_embeds_position_mask = []
            all_input_ids = encoding["input_ids"]
            # not batched -> (changed to) batch of size 1
            if isinstance(text, str):
                all_input_ids = [all_input_ids]
                encoding["attention_mask"] = [encoding["attention_mask"]]
            for text_ids in all_input_ids:
                # change the ids of the fake `<image>` tokens in `input_ids`
                text_ids = text_ids[:start_index] + image_token_ids + text_ids[start_index + num_image_tokens :]
                input_ids.append(text_ids)

                mask = copy.copy(base_image_embeds_position_mask)
                if with_bos:
                    # for `<s>`
                    mask = [0] + mask
                # trailing part (which is not related to the image)
                mask += [0] * (len(text_ids) - len(mask))
                image_embeds_position_mask.append(mask)

            if isinstance(text, list):
                sorted_length = sorted(
                    [(idx, len(x)) for idx, x in enumerate(text_encoding.input_ids)], key=lambda x: x[-1]
                )
                _, min_len_not_padded = sorted_length[0]
                idx, _ = sorted_length[-1]

                text_encoding = self.tokenizer(
                    text=[text[idx]],
                    add_special_tokens=(add_special_tokens and add_eos_token),
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    pad_to_multiple_of=pad_to_multiple_of,
                    verbose=verbose,
                    return_tensors=None,
                    **kwargs,
                )
                max_len_padded = len(text_encoding.input_ids[0])

                if min_len_not_padded != max_len_padded:
                    if self.tokenizer.padding_side == "right":
                        input_ids = [x + [self.tokenizer.pad_token_id] * (max_len_padded - len(x)) for x in input_ids]
                        image_embeds_position_mask = [
                            x + [0] * (max_len_padded - len(x)) for x in image_embeds_position_mask
                        ]
                        encoding["attention_mask"] = [
                            x + [0] * (max_len_padded - len(x)) for x in encoding["attention_mask"]
                        ]
                    elif self.tokenizer.padding_side == "left":
                        input_ids = [[self.tokenizer.pad_token_id] * (max_len_padded - len(x)) + x for x in input_ids]
                        image_embeds_position_mask = [
                            [0] * (max_len_padded - len(x)) + x for x in image_embeds_position_mask
                        ]
                        encoding["attention_mask"] = [
                            [0] * (max_len_padded - len(x)) + x for x in encoding["attention_mask"]
                        ]

            # un-batch if necessary
            if isinstance(text, str) and return_tensors is None:
                input_ids = input_ids[0]
                encoding["attention_mask"] = encoding["attention_mask"][0]
                image_embeds_position_mask = image_embeds_position_mask[0]

            # update (with the target tensor type if specified)
            encoding.update(
                BatchEncoding(
                    data={
                        "input_ids": input_ids,
                        "attention_mask": encoding["attention_mask"],
                        "image_embeds_position_mask": image_embeds_position_mask,
                    },
                    tensor_type=return_tensors,
                )
            )

        return encoding
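    # Usage sketch for `__call__` above (illustrative; `image` stands for any PIL image or array
    # accepted by `CLIPImageProcessor`, and the checkpoint name is only an example):
    #
    #     processor = Kosmos2Processor.from_pretrained("microsoft/kosmos-2-patch14-224")
    #     inputs = processor(images=image, text="<grounding> An image of", return_tensors="pt")
    #     # -> keys: "pixel_values", "input_ids", "attention_mask", "image_embeds_position_mask"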

    def _check_bboxes_for_single_text(self, bboxes):
        """
        Check `bboxes` for a single text example. It could be
            - `None`: no bounding box associated with a text.
            - A list with each element being the bounding boxes associated with one `<phrase> ... </phrase>` pair
              found in a text. This could be:
                  - `None`: no bounding box associated with a `<phrase> ... </phrase>` pair.
                  - A tuple of 2 integers: a single bounding box specified by patch indices.
                  - A tuple of 4 floating point numbers: a single bounding box specified by (normalized) coordinates.
                  - A list containing the above 2 tuple types: multiple bounding boxes for a
                    `<phrase> ... </phrase>` pair.
        """
        if bboxes is None:
            return
        elif not isinstance(bboxes, list):
            raise ValueError("`bboxes` (for a single text example) should be `None` or a list.")

        # `bbox` is the bounding boxes for a single <phrase> </phrase> pair
        for bbox in bboxes:
            if bbox is None:
                continue
            elif not isinstance(bbox, list):
                bbox = [bbox]
            for element in bbox:
                if not isinstance(element, tuple) or not (
                    (len(element) == 2 and all(isinstance(x, int) for x in element))
                    or (len(element) == 4 and all(isinstance(x, float) for x in element))
                ):
                    raise ValueError(
                        "Each element in `bboxes` (for a single text example) should be either `None`, a tuple "
                        "containing 2 integers or 4 floating point numbers, or a list containing such tuples. Also "
                        "make sure the arguments `texts` and `bboxes` passed to `preprocess_examples` are both in "
                        "batches or both for a single example."
                    )

    def _preprocess_single_example(self, text, image, bboxes, img_info_tokens):
        text = text.strip()
        if image is not None:
            # Add `<image> ... (fake) image tokens ... </image>`
            text = f"{img_info_tokens} {text}"

        # Add `<object> <patch_index_xxxx> <patch_index_yyyy> </object>` after `<phrase> phrase text </phrase>`
        text = self._insert_patch_index_tokens(text, bboxes)
        return text

    def preprocess_examples(
        self,
        texts: Union[TextInput, List[TextInput]],
        images: ImageInput = None,
        bboxes: BboxInput = None,
        num_image_tokens: Optional[int] = 64,
    ) -> Union[str, List[str]]:
        """Add image and bounding box information to `texts` as image and patch index tokens.

        Args:
            texts (`Union[TextInput, List[TextInput]]`): The texts to be processed.
            images (`ImageInput`, *optional*): The images associated with `texts`.
            bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
                The bounding boxes associated with `texts`.
            num_image_tokens (`int`, *optional*, defaults to 64):
                The number of image tokens (used as latent queries). This should correspond to the
                `latent_query_num` attribute in `Kosmos2Config`.

        Returns:
            `Union[TextInput, List[TextInput]]`: The processed texts with image and patch index tokens.
        """
        # These are fake `<image>` tokens enclosed between (the actual) `<image>` token and `</image>`.
        img_tokens = [self.boi_token] * num_image_tokens
        img_info_tokens = " ".join([self.boi_token] + img_tokens + [self.eoi_token])

        # make batch to simplify processing logic
        batched = True
        if isinstance(texts, str):
            batched = False
            texts = [texts]

        if images is None:
            images = [None] * len(texts)
        elif not is_batched(images):
            images = [images]
        if len(texts) != len(images):
            raise ValueError(
                f"The number of examples in `texts` and `images` should be the same. Got {len(texts)} vs. {len(images)} instead."
            )

        if not batched:
            self._check_bboxes_for_single_text(bboxes)
            bboxes = [bboxes]
        elif bboxes is not None:
            if not isinstance(bboxes, list):
                raise ValueError("`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.")
            for x in bboxes:
                self._check_bboxes_for_single_text(x)
        else:
            bboxes = [None] * len(texts)

        if len(bboxes) != len(texts):
            raise ValueError(
                f"The number of examples in `texts` and `bboxes` should be the same. Got {len(texts)} vs. {len(bboxes)} instead."
            )

        result = [
            self._preprocess_single_example(text, image, bbox, img_info_tokens)
            for text, image, bbox in zip(texts, images, bboxes)
        ]
        # un-batch if necessary
        if not batched:
            result = result[0]

        return result
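    # Illustrative example for `preprocess_examples` above (patch indices chosen arbitrarily;
    # `image` is any single image, and the run of 64 `<image>` placeholders is abbreviated as "..."):
    #
    #     processor.preprocess_examples("<grounding><phrase> a snowman</phrase>", image, bboxes=[(44, 863)])
    #     # -> "<image> <image> ... </image> <grounding><phrase> a snowman</phrase><object> <patch_index_0044> <patch_index_0863> </object>"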

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_generation(self, text, cleanup_and_extract=True):
        caption = text.split(self.eoi_token)[-1]
        if cleanup_and_extract:
            return clean_text_and_extract_entities_with_bboxes(caption)
        return caption
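    # Sketch of post-processing a generated string (illustrative; `generated_text` is assumed to
    # be the decoded model output, which still contains the tag and patch index tokens):
    #
    #     caption, entities = processor.post_process_generation(generated_text)
    #     # `caption` is the cleaned text; `entities` is a list of
    #     # (entity_name, (start, end), [(x1, y1, x2, y2), ...]) tuples with normalized coordinates.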

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def _insert_patch_index_tokens(self, text: str, bboxes: Union[List[Tuple[int]], List[Tuple[float]]]) -> str:
        if bboxes is None or len(bboxes) == 0:
            return text

        matched_phrases = list(re.finditer(r"<phrase>.+?</phrase>", string=text))
        if len(matched_phrases) != len(bboxes):
            raise ValueError(
                f"The number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got {len(matched_phrases)} vs. {len(bboxes)} instead."
            )

        # insert the objects' patch index tokens right after the found `<phrase> ... </phrase>` pairs
        curr_pos = 0
        buffer = []
        for matched, bbox in zip(matched_phrases, bboxes):
            _, end = matched.span()
            buffer.append(text[curr_pos:end])
            curr_pos = end
            # A phrase without a bbox
            if bbox is None:
                continue
            # A phrase with a single bbox
            if isinstance(bbox, tuple):
                bbox = [bbox]
            patch_index_strings = []
            # A phrase could have multiple bboxes
            if not all(box is not None for box in bbox):
                raise ValueError(
                    "The multiple bounding boxes for a single phrase should not contain any `None` value."
                )
            for box in bbox:
                patch_index_1, patch_index_2 = self._convert_bbox_to_patch_index_tokens(box)
                patch_index_strings.append(f"{patch_index_1} {patch_index_2}")
            # `bbox` being an empty list
            if len(patch_index_strings) == 0:
                continue
            position_str = " </delimiter_of_multi_objects/> ".join(patch_index_strings)
            buffer.append(f"<object> {position_str} </object>")
        # remaining
        if curr_pos < len(text):
            buffer.append(text[curr_pos:])
        text = "".join(buffer)
        return text

    def _convert_bbox_to_patch_index_tokens(
        self, bbox: Union[Tuple[int, int], Tuple[float, float, float, float]]
    ) -> Tuple[str, str]:
        # already computed patch indices
        if len(bbox) == 2:
            idx_1, idx_2 = bbox
        # bbox specified with (normalized) coordinates
        else:
            # derive `num_patches_per_side` from the number of patch index tokens
            num_patches_per_side = int(math.sqrt(self.num_patch_index_tokens))
            idx_1, idx_2 = coordinate_to_patch_index(bbox, num_patches_per_side)

        token_1 = f"<patch_index_{str(idx_1).zfill(4)}>"
        token_2 = f"<patch_index_{str(idx_2).zfill(4)}>"

        return token_1, token_2

def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patches_per_side: int) -> Tuple[int, int]:
    """Convert a bounding box to a pair of patch indices.

    Args:
        bbox (`Tuple[float, float, float, float]`):
            The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left
            and lower-right corners of the box. It should have x2 > x1 and y2 > y1.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[int, int]`: A pair of patch indices representing the upper-left patch and lower-right patch.
    """
    (x1, y1, x2, y2) = bbox

    if not (x2 > x1 and y2 > y1):
        raise ValueError("The coordinates in `bbox` should be `(x1, y1, x2, y2)` with `x2 > x1` and `y2 > y1`.")

    ul_x = math.floor(x1 * num_patches_per_side)
    ul_y = math.floor(y1 * num_patches_per_side)

    lr_x = math.ceil(x2 * num_patches_per_side - 1)
    lr_y = math.ceil(y2 * num_patches_per_side - 1)

    ul_idx = ul_y * num_patches_per_side + ul_x
    lr_idx = lr_y * num_patches_per_side + lr_x

    return ul_idx, lr_idx
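# Worked example: with num_patches_per_side=32 and bbox=(0.390625, 0.046875, 0.984375, 0.828125)
# (the snowman box from the docstring example further below),
#     ul_x = floor(0.390625 * 32) = 12, ul_y = floor(0.046875 * 32) = 1
#     lr_x = ceil(0.984375 * 32 - 1) = 31, lr_y = ceil(0.828125 * 32 - 1) = 26
#     ul_idx = 1 * 32 + 12 = 44, lr_idx = 26 * 32 + 31 = 863
# i.e. the pair of patch indices (44, 863).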

# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
# (with format modifications)
def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
    """
    Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
    bounding box, returns the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).

    Args:
        ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box.
        lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box.
        num_patches_per_side (`int`): the number of patches along each side.

    Returns:
        `Tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
    """
    # Compute the size of each cell in the grid
    cell_size = 1.0 / num_patches_per_side

    # Compute the x and y indices of the upper-left and lower-right corners of the bounding box
    ul_x = ul_idx % num_patches_per_side
    ul_y = ul_idx // num_patches_per_side

    lr_x = lr_idx % num_patches_per_side
    lr_y = lr_idx // num_patches_per_side

    # Compute the normalized coordinates of the bounding box. When the two corners lie in the same
    # row or column of cells (which includes the single-cell case, `ul_idx == lr_idx`), use the
    # cell borders; otherwise use the cell centers.
    if ul_x == lr_x or ul_y == lr_y:
        x1 = ul_x * cell_size
        y1 = ul_y * cell_size
        x2 = lr_x * cell_size + cell_size
        y2 = lr_y * cell_size + cell_size
    else:
        x1 = ul_x * cell_size + cell_size / 2
        y1 = ul_y * cell_size + cell_size / 2
        x2 = lr_x * cell_size + cell_size / 2
        y2 = lr_y * cell_size + cell_size / 2

    return x1, y1, x2, y2
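# Worked example (inverse of the one above): with num_patches_per_side=32,
#     patch_index_to_coordinate(44, 863, 32)
#     # ul = (44 % 32, 44 // 32) = (12, 1); lr = (863 % 32, 863 // 32) = (31, 26)
#     # different rows and columns -> cell centers:
#     # (12/32 + 1/64, 1/32 + 1/64, 31/32 + 1/64, 26/32 + 1/64)
#     # = (0.390625, 0.046875, 0.984375, 0.828125)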

# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33
# (with format modifications)
def extract_entities_with_patch_indices(text):
    """Extract entities contained in `text`. The bounding boxes are given in the form of patch indices.

    This function is only intended to be used within `clean_text_and_extract_entities_with_bboxes`, where further
    processing happens, including converting to normalized coordinates and cleaning up whitespace characters.

    Examples:

    ```python
    >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
    >>> entities = extract_entities_with_patch_indices(text)
    >>> entities
    [(' a snowman', (31, 41), [(44, 863)]), (' a fire', (130, 137), [(5, 911)])]
    ```"""
    # The regular expression pattern for matching the required formats
    pattern = r"(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>"

    # Find all matches in the given string
    matches = re.finditer(pattern, text)

    # Initialize an empty list to store the valid patch_index combinations
    entities_with_patch_indices = []

    for match in matches:
        # span of a `phrase` that is between <phrase> and </phrase>
        span = match.span(2)
        phrase_tag, phrase, match_content = match.groups()
        if not phrase_tag:
            phrase = None
            # We take the starting position of `<object>`
            span = (match.span(0)[0], match.span(0)[0])

        # Split the match_content by the delimiter to get individual patch_index pairs
        patch_index_pairs = match_content.split("</delimiter_of_multi_objects/>")

        entity_bboxes = []
        for pair in patch_index_pairs:
            # Extract the xxxx and yyyy values from the patch_index pair. Dropping the first character of
            # `pair` makes the second search skip past the first `<patch_index_xxxx>` token.
            x = re.search(r"<patch_index_(\d+)>", pair)
            y = re.search(r"<patch_index_(\d+)>", pair[1:])
            if x and y:
                entity_bboxes.append((int(x.group(1)), int(y.group(1))))

        if phrase:
            entities_with_patch_indices.append((phrase, span, entity_bboxes))
        else:
            for bbox in entity_bboxes:
                # fake entity name
                entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>"
                entities_with_patch_indices.append((entity, span, [bbox]))

    return entities_with_patch_indices

def adjust_entity_positions(entity, text):
    """Adjust the positions of the entities in `text` to be relative to the text with special fields removed."""
    entity_name, (start, end) = entity
    # compute the length of strings with special fields (tag tokens, patch index tokens, etc.) removed
    adjusted_start = len(re.sub("<.*?>", "", text[:start]))
    adjusted_end = len(re.sub("<.*?>", "", text[:end]))
    adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
    return adjusted_entity
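# Worked example: for the snowman text used in the docstring examples, the raw span (31, 41) of
# " a snowman" counts the "<grounding>" and "<phrase>" tags (11 + 8 = 19 characters); removing
# them leaves len(" An image of") = 12, so the span is adjusted to (12, 22).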

def _cleanup_spaces(text, entities):
    """Strip the spaces around `text` and the entity names in it, adjusting the entity spans accordingly."""
    new_text = text.strip()
    leading_spaces = len(text) - len(text.lstrip())

    new_entities = []
    for entity_name, (start, end), bboxes in entities:
        entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
        entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())

        start = start - leading_spaces + entity_name_leading_spaces
        end = end - leading_spaces - entity_name_trailing_spaces
        entity_name = entity_name.strip()

        new_entities.append((entity_name, (start, end), bboxes))

    return new_text, new_entities
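# Continuing the example above: the full cleaned text has one leading space, and the entity
# name " a snowman" has one leading space of its own, so (12, 22) becomes
# (12 - 1 + 1, 22 - 1) = (12, 21) and the entity name is stripped to "a snowman".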

# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
# (with format modifications)
def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
    """Remove the tag tokens from `text`, and extract the entities in it with some cleanup of whitespace characters.

    Examples:

    ```python
    >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
    >>> clean_text, entities = clean_text_and_extract_entities_with_bboxes(text)
    >>> clean_text
    'An image of a snowman warming himself by a fire.'

    >>> entities
    [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
    ```"""
    # remove special fields (tag tokens, patch index tokens, etc.)
    processed_text = re.sub("<.*?>", "", text)

    entities_with_patch_indices = extract_entities_with_patch_indices(text)
    entities = []
    for item in entities_with_patch_indices:
        entity, bboxes = item[0:2], item[2]
        adjusted_entity = adjust_entity_positions(entity, text)
        bboxes_in_coords = [patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side) for bbox in bboxes]

        entities.append(adjusted_entity + (bboxes_in_coords,))

    return _cleanup_spaces(processed_text, entities)