# RexSeek-3B / preprocessing_rexseek.py
from PIL import Image
import re
from typing import List, Union
import numpy as np
import torch
import torchvision.transforms.functional as F
from transformers import AutoTokenizer
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
# For Objects
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300
# For Grounding
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
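# Note (hedged): DEFAULT_OBJECT_TOKEN is a template whose "<i>" is replaced
# with the object index when the prompt is built, so three boxes yield the
# sequence "<obj0><objfeat><obj1><objfeat><obj2><objfeat>" (see process() below).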
def xyxy_to_xywh(boxes):
"""
Convert boxes from xywh to xyxy format.
Parameters:
boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
Each box is represented as [x, y, x, y].
Returns:
numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, w, h].
"""
boxes = np.array(boxes)
x_min, y_min, x_max, y_max = (
boxes[:, 0],
boxes[:, 1],
boxes[:, 2],
boxes[:, 3],
)
w = x_max - x_min
h = y_max - y_min
return np.stack([x_min, y_min, w, h], axis=1)
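# Quick sanity check (illustrative values, not from this repo):
#   xyxy_to_xywh([[10, 20, 50, 80]])  ->  array([[10, 20, 40, 60]])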
def xywh_to_xyxy(boxes):
"""
Convert boxes from xywh to xyxy format.
Parameters:
boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
Each box is represented as [x, y, width, height].
Returns:
numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, x_max, y_max].
"""
boxes = np.array(boxes)
x, y, width, height = (
boxes[:, 0],
boxes[:, 1],
boxes[:, 2],
boxes[:, 3],
)
x_max = x + width
y_max = y + height
return np.stack([x, y, x_max, y_max], axis=1)
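# Inverse of xyxy_to_xywh (illustrative values):
#   xywh_to_xyxy([[10, 20, 40, 60]])  ->  array([[10, 20, 50, 80]])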
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
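# Example (hedged): a 640x480 RGB image becomes 640x640, with the original
# pasted at y=80 so 80-pixel bands of background_color fill the top and bottom.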
def pad_boxes(gt_boxes, old_size):
old_w, old_h = old_size
gt_boxes = np.array(gt_boxes).astype(np.float32)
# Calculate the padding added
if old_w > old_h:
pad_top = (old_w - old_h) // 2
pad_bottom = old_w - old_h - pad_top
pad_left, pad_right = 0, 0
else:
pad_left = (old_h - old_w) // 2
pad_right = old_h - old_w - pad_left
pad_top, pad_bottom = 0, 0
# Adjust the boxes for padding
gt_boxes[:, 0] += pad_left # x
gt_boxes[:, 1] += pad_top # y
return gt_boxes
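# Example (hedged): with old_size=(640, 480) the image is padded to 640x640
# with pad_top=80, so every xywh box shifts by (+0, +80); widths and heights
# are untouched.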
def resize_boxes(gt_boxes, old_size, new_size):
old_w, old_h = old_size
new_h, new_w = new_size
gt_boxes = np.array(gt_boxes).astype(np.float32)
# Calculate scale factors
scale_x = new_w / max(old_w, old_h)
scale_y = new_h / max(old_w, old_h)
# Resize the boxes
gt_boxes[:, 0] *= scale_x # x
gt_boxes[:, 1] *= scale_y # y
gt_boxes[:, 2] *= scale_x # w
gt_boxes[:, 3] *= scale_y # h
return gt_boxes
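# Example (hedged): old_size=(640, 480), new_size=(336, 336). The padded square
# side is max(640, 480)=640, so scale_x = scale_y = 336/640 = 0.525 is applied
# to all four xywh fields.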
def split_special_strings(input_string: str, special_strings: List[str]):
    """Split the input string into a list of strings, keeping the special strings.
    Args:
        input_string (str): The input string to split.
        special_strings (List[str]): The special tokens to split on and keep.
    Example:
        input_string = "<image>\n<obj0><objfeat><obj1><objfeat>\n I am happy today."
        output = ['<image>', '\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\n I am happy today.']
    Returns:
        list: A list of strings, with the special strings separated from the rest of the input string.
    """
# Create a regex pattern to match the special strings
pattern = "|".join(map(re.escape, special_strings))
# Split the input string using the pattern, keeping the special strings in the result
split_list = re.split(f"({pattern})", input_string)
# Remove empty strings from the list
split_list = [s for s in split_list if s]
return split_list
def tokenizer_image_object_token(prompt, tokenizer):
bos_token_id = tokenizer.bos_token_id
split_tokens = [DEFAULT_IMAGE_TOKEN, DEFAULT_OBJECT_FEATURE_TOKEN]
chunks = split_special_strings(prompt, split_tokens)
input_encode = [bos_token_id] if bos_token_id else []
for chunk in chunks:
if chunk == DEFAULT_IMAGE_TOKEN:
input_encode.append(IMAGE_TOKEN_INDEX)
elif chunk == DEFAULT_OBJECT_FEATURE_TOKEN:
input_encode.append(DEFAULT_OBJECT_INDEX)
else:
input_encode.extend(tokenizer.encode(chunk, add_special_tokens=False))
return input_encode
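# Example (hedged; concrete ids depend on the tokenizer): for the prompt
# "<image>\nhi", the output is [bos_token_id, IMAGE_TOKEN_INDEX] followed by
# the ids of "\nhi". The negative sentinel ids (-200, -300) are presumably
# swapped for image/object features downstream in the model.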
class RexSeekProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer: AutoTokenizer = None, **kwargs):
super().__init__(image_processor, tokenizer)
self._special_tokens = None
self.template = dict(
SYSTEM=("<|im_start|>system\n{system}<|im_end|>\n"),
INSTRUCTION=(
"<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n"
),
SUFFIX="<|im_end|>",
SUFFIX_AS_EOS=True,
SEP="\n",
STOP_WORDS=["<|im_end|>", "<|endoftext|>"],
)
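        # For reference (hedged), a single-turn INSTRUCTION renders as:
        #   "<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"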
def process(
self,
image: Union[str, Image.Image],
bbox: List[List[int]],
question: str,
):
"""Prepare input data for inference.
Args:
image (Union[str, Image.Image]): The image to process.
            bbox (List[List[int]]): A list of bounding boxes for the image. Each bounding
                box should be in [x_min, y_min, x_max, y_max] format.
            question (str): The question to ask about the image.
        Returns:
            dict: A dict with "pixel_values", "pixel_values_aux", "gt_boxes", and
                "input_ids", each with a leading batch dimension of 1.
        """
data_dict = {}
# step1 load image
        if isinstance(image, str):
image = Image.open(image).convert("RGB")
ori_w, ori_h = F.get_image_size(image)
image = expand2square(
image,
tuple(int(x * 255) for x in self.image_processor.image_mean),
)
pad_w, pad_h = F.get_image_size(image)
image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
"pixel_values"
][0]
resize_h, resize_w = image_aux.shape[-2:]
data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)
image = image_aux.clone()
image = torch.nn.functional.interpolate(
image[None],
size=[336, 336],
mode="bilinear",
align_corners=False,
)[0]
data_dict["pixel_values"] = image.unsqueeze(0)
# step2 load boxes
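        # Boxes follow the same geometry as the image: convert to xywh, shift
        # for the square padding, scale to the processor's input resolution,
        # then convert back to xyxy for the model.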
bbox = xyxy_to_xywh(bbox)
bbox = pad_boxes(bbox, (ori_w, ori_h))
bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w))
data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0)
# step3 prepare question
total_num_boxes = len(bbox)
obj_tokens = [
DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) for i in range(total_num_boxes)
]
obj_tokens = (
DEFAULT_OBJECT_FEATURE_TOKEN.join(obj_tokens) + DEFAULT_OBJECT_FEATURE_TOKEN
)
question = question.replace(DEFAULT_IMAGE_TOKEN, "")
question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question
inputs = ""
inputs += self.template["INSTRUCTION"].format(input=question, round=1)
# step4 tokenize question
input_ids = tokenizer_image_object_token(inputs, self.tokenizer)
data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)
return data_dict
RexSeekProcessor.register_for_auto_class()
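# Minimal usage sketch (hedged: the checkpoint path, image path, box, and
# question below are illustrative placeholders, not part of this file):
#
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained(
#       "path/to/RexSeek-3B", trust_remote_code=True
#   )
#   inputs = processor.process(
#       image="demo.jpg",
#       bbox=[[48, 60, 213, 320]],  # candidate boxes in [x_min, y_min, x_max, y_max]
#       question="Please detect the person in this image.",
#   )
#   # inputs: pixel_values, pixel_values_aux, gt_boxes, input_ids (batch dim 1)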