import copy
from typing import Optional

from PIL import Image

from .single_image_convsation import SingleImageConvDatasetMixin


class SingleImageInteractive(SingleImageConvDatasetMixin):
    """A one-sample, one-image conversation that is built up interactively.

    Rather than indexing into stored data, this class holds a single
    conversation (``__len__`` is always 1) that callers extend message by
    message with :meth:`append_message`. Every box/point a message refers to
    is de-duplicated into ``self.boxes`` / ``self.points``; messages store
    indices into those shared lists (``boxes_seq`` / ``points_seq``).
    """

    # NOTE(review): presumably stops the mixin from printing a debug sample;
    # confirm against SingleImageConvDatasetMixin.
    _printed_sample = True

    # Text token marking where the image goes in the conversation text.
    # NOTE(review): these token strings must match the ones used by the
    # preprocessor/tokenizer (LLaVA-style defaults assumed) -- confirm.
    image_placeholder = '<image>'
    # Tokens that the formatted conversation presumably contains in place of
    # the placeholder; they are stripped/collapsed again before display in
    # to_gradio_chatbot_new_messages.
    image_patch_token = '<im_patch>'
    image_start_token = '<im_start>'
    image_end_token = '<im_end>'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image: Optional[Image.Image] = None
        self.roles = ('human', 'gpt')
        # De-duplicated coordinates; messages reference them by index.
        self.boxes = []
        self.points = []
        # NOTE(review): not used in this file; kept for external callers.
        self.raw_conv = []
        # List of {'from', 'value', 'boxes_seq', 'points_seq'} dicts.
        self.conversations = []

    def set_image(self, image: Image.Image):
        """Attach the conversation's image. May only be called once."""
        assert self.image is None, f"{image}"
        self.image = image

    def append_message(
        self,
        role: str,
        message: str,
        *,
        boxes=None,
        points=None,
        boxes_seq=None,
        points_seq=None,
    ):
        """Append a new message.

        Args:
            role: must be one of ``self.roles``.
            message: the message text.
            boxes: box coordinates local to this message (4-element
                tuples/lists); ``boxes_seq`` indexes into it.
            points: point coordinates local to this message (2-element
                tuples/lists); ``points_seq`` indexes into it.
            boxes_seq: iterable of index groups into ``boxes``, or None.
            points_seq: iterable of index groups into ``points``, or None.

        Local indices are translated into indices of the shared
        ``self.boxes`` / ``self.points`` lists before the message is stored.
        When an image has been set, the image placeholder is prepended to the
        message if no earlier message contains it and this one does not
        either; if an earlier message already contains it, it is stripped
        from this message so the placeholder is not repeated.
        """
        assert role in self.roles

        def convert_idx(objs_seq, objs_value, get_obj_idx_func):
            # Map each group of message-local indices to shared indices.
            if objs_seq is None:
                return None
            return tuple(
                tuple(get_obj_idx_func(objs_value[idx]) for idx in objs_idx)
                for objs_idx in objs_seq
            )

        # These are fresh tuples of ints, so no defensive copy is needed.
        boxes_seq = convert_idx(boxes_seq, boxes, self._get_box_idx)
        points_seq = convert_idx(points_seq, points, self._get_point_idx)

        if self.image is not None:
            placeholder = self.image_placeholder
            previous_message_has_image_placeholder = any(
                placeholder in item['value'] for item in self.conversations
            )
            if not previous_message_has_image_placeholder and placeholder not in message:
                message = placeholder + ' ' + message
            if previous_message_has_image_placeholder and placeholder in message:
                message = message.replace(placeholder, '')

        self.conversations.append(
            {
                'from': role,
                'value': message,
                'boxes_seq': boxes_seq,
                'points_seq': points_seq,
            }
        )

    def get_raw_item(self, index=None):
        """Return a deep copy of the single sample in the mixin's raw format.

        ``index`` is ignored because there is only one sample. The first
        message must come from ``self.roles[0]``. If the last message is also
        from ``self.roles[0]``, an empty ``self.roles[1]`` reply is appended
        (presumably so generation has a turn to fill in).
        """
        assert self.conversations, "append at least one message before building the sample"
        ret = copy.deepcopy({
            'image': self.image,
            'target': {
                'boxes': self.boxes,
                'points': self.points,
            },
            'conversations': self.conversations,
        })
        assert ret['conversations'][0]['from'] == self.roles[0]
        if ret['conversations'][-1]['from'] == self.roles[0]:
            ret['conversations'].append(
                {
                    'from': self.roles[1],
                    'value': '',
                }
            )
        return ret

    def to_model_input(self, device=None):
        """Build a batch-of-one model input from the processed sample.

        Args:
            device: target device for the tensors. ``None`` (the default)
                keeps the historical behavior of calling ``.cuda()``.

        Returns:
            dict with ``'input_ids'`` (batch dim added) and ``'images'``
            (batch dim added, or ``None`` when the sample has no image).
        """
        item = self.__getitem__(0)

        def _move(tensor):
            return tensor.cuda() if device is None else tensor.to(device)

        ret = {'input_ids': _move(item['input_ids'].unsqueeze(0))}
        if 'image' in item and item['image'] is not None:
            ret['images'] = _move(item['image'].unsqueeze(0))
        else:
            ret['images'] = None
        return ret

    def to_gradio_chatbot_new_messages(self):
        """Return the last two ``(role, text)`` pairs of the formatted conversation.

        Image patch and end tokens are dropped, and the start token is shown
        as the plain image placeholder. The result is meant for display in a
        Gradio chatbot.
        """
        conv = self.__getitem__(0, return_conv=True)
        ret_messages = []
        for role, text in conv.messages[-2:]:
            text = (
                text.replace(self.image_patch_token, '')
                .replace(self.image_end_token, '')
                .replace(self.image_start_token, self.image_placeholder)
            )
            ret_messages.append((role, text))
        return ret_messages

    @staticmethod
    def _dedup_index(store, obj):
        """Return the index of ``obj`` in ``store``, appending it first if absent."""
        try:
            return store.index(obj)
        except ValueError:
            store.append(obj)
            return len(store) - 1

    def _get_box_idx(self, box):
        """Return the shared index of a 4-number box, registering it if new."""
        assert isinstance(box, (tuple, list)), f"{type(box)}"
        assert len(box) == 4, f"expected 4 box coordinates, got {len(box)}"
        assert isinstance(box[0], (int, float)), f"{type(box[0])}"
        return self._dedup_index(self.boxes, tuple(box))

    def _get_point_idx(self, point):
        """Return the shared index of a 2-number point, registering it if new."""
        assert isinstance(point, (tuple, list)), f"{type(point)}"
        assert len(point) == 2, f"expected 2 point coordinates, got {len(point)}"
        assert isinstance(point[0], (int, float)), f"{type(point[0])}"
        return self._dedup_index(self.points, tuple(point))

    def __len__(self):
        return 1