|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
from dataclasses import dataclass |
|
import io |
|
from enum import Enum |
|
from PIL import Image |
|
from typing import List, Tuple |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from .constants import EVA_IMAGE_SIZE, GRD_SYMBOL, BOP_SYMBOL, EOP_SYMBOL, BOO_SYMBOL, EOO_SYMBOL |
|
from .constants import DEFAULT_VIDEO_TOKEN, DEFAULT_EOS_TOKEN, USER_TOKEN, ASSISTANT_TOKEN, FAKE_VIDEO_END_TOKEN |
|
|
|
from .utils import gen_id, frontend_logger as logging |
|
|
|
|
|
class Role(Enum):
    """Who authored a conversation message."""

    # Values are plain ints. A trailing comma after each literal would make
    # the value a one-element tuple such as ``(0,)`` and break ``Role(1)``.
    UNKNOWN = 0
    USER = 1
    ASSISTANT = 2
|
|
|
|
|
class DataType(Enum):
    """Kind of payload carried by a conversation element."""

    # Values are plain ints. A trailing comma after each literal would make
    # the value a one-element tuple such as ``(0,)`` and break ``DataType(1)``.
    UNKNOWN = 0
    TEXT = 1
    IMAGE = 2
    GROUNDING = 3
    VIDEO = 4
    ERROR = 5
|
|
|
|
|
@dataclass |
|
class DataMeta: |
|
datatype: DataType = DataType.UNKNOWN |
|
text: str = None |
|
image: Image.Image = None |
|
mask: Image.Image = None |
|
coordinate: List[int] = None |
|
frames: List[Image.Image] = None |
|
stack_frame: Image.Image = None |
|
|
|
@property |
|
def grounding(self): |
|
return self.coordinate is not None |
|
|
|
@property |
|
def text_str(self): |
|
return self.text |
|
|
|
@property |
|
def image_str(self): |
|
return self.image2str(self.image) |
|
|
|
@property |
|
def video_str(self): |
|
ret = f'<div style="overflow:scroll"><b>[VIDEO]</b></div>{self.image2str(self.stack_frame)}' |
|
return ret |
|
|
|
@property |
|
def grounding_str(self): |
|
ret = "" |
|
if self.text is not None: |
|
ret += f'<div style="overflow:scroll"><b>[PHRASE]</b>{self.text}</div>' |
|
|
|
ret += self.image2str(self.mask) |
|
|
|
if self.image is not None: |
|
ret += self.image2str(self.image) |
|
return ret |
|
|
|
def image2str(self, image): |
|
buf = io.BytesIO() |
|
image.save(buf, format="WEBP") |
|
i_str = base64.b64encode(buf.getvalue()).decode() |
|
return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>' |
|
|
|
def format_chatbot(self): |
|
match self.datatype: |
|
case DataType.TEXT | DataType.ERROR: |
|
return self.text_str |
|
case DataType.IMAGE: |
|
return self.image_str |
|
case DataType.VIDEO: |
|
return self.video_str |
|
case DataType.GROUNDING: |
|
return self.grounding_str |
|
case _: |
|
return "" |
|
|
|
def format_prompt(self) -> List[str | Image.Image]: |
|
match self.datatype: |
|
case DataType.TEXT: |
|
return [self.text] |
|
case DataType.IMAGE: |
|
return [self.image] |
|
case DataType.VIDEO: |
|
return [DEFAULT_VIDEO_TOKEN] + self.frames + [FAKE_VIDEO_END_TOKEN] |
|
case DataType.GROUNDING: |
|
ret = [] |
|
if self.text is not None: |
|
ret.append(f"{BOP_SYMBOL}{self.text}{EOP_SYMBOL}") |
|
ret += [BOO_SYMBOL, self.mask, EOO_SYMBOL] |
|
if self.image is not None: |
|
ret.append(self.image) |
|
return ret |
|
case _: |
|
return [] |
|
|
|
def __str__(self): |
|
s = "" |
|
if self.text is not None: |
|
s += f"T:{self.text}" |
|
|
|
if self.image is not None: |
|
w, h = self.image.size |
|
s += f"[I:{h}x{w}]" |
|
|
|
if self.coordinate is not None: |
|
l, t, r, b = self.coordinate |
|
s += f"[C:({l:03d},{t:03d}),({r:03d},{b:03d})]" |
|
|
|
if self.frames is not None: |
|
w, h = self.frames[0].size |
|
s += f"[V:{len(self.frames)}x{h}x{w}]" |
|
|
|
return s |
|
|
|
@classmethod |
|
def build(cls, text=None, image=None, coordinate=None, frames=None, is_error=False, *, resize: bool = True): |
|
ins = cls() |
|
ins.text = text if text != "" else None |
|
ins.image = cls.resize(image, force=resize) |
|
|
|
ins.coordinate = cls.fix(coordinate) |
|
ins.frames = cls.resize(frames, force=resize) |
|
|
|
|
|
if is_error: |
|
ins.datatype = DataType.ERROR |
|
elif coordinate is not None: |
|
ins.datatype = DataType.GROUNDING |
|
ins.draw_box() |
|
elif image is not None: |
|
ins.datatype = DataType.IMAGE |
|
elif text is not None: |
|
ins.datatype = DataType.TEXT |
|
else: |
|
ins.datatype = DataType.VIDEO |
|
ins.stack() |
|
|
|
return ins |
|
|
|
@classmethod |
|
def fix(cls, coordinate): |
|
if coordinate is None: |
|
return None |
|
|
|
l, t, r, b = coordinate |
|
l = min(EVA_IMAGE_SIZE, max(0, l)) |
|
t = min(EVA_IMAGE_SIZE, max(0, t)) |
|
r = min(EVA_IMAGE_SIZE, max(0, r)) |
|
b = min(EVA_IMAGE_SIZE, max(0, b)) |
|
return min(l, r), min(t, b), max(l, r), max(t, b) |
|
|
|
@classmethod |
|
def resize(cls, image: Image.Image | List[Image.Image] | None, *, force: bool = True): |
|
if image is None: |
|
return None |
|
|
|
if not force: |
|
return image |
|
|
|
if isinstance(image, Image.Image): |
|
image = [image] |
|
|
|
for idx, im in enumerate(image): |
|
w, h = im.size |
|
if w < EVA_IMAGE_SIZE or h < EVA_IMAGE_SIZE: |
|
continue |
|
|
|
if w < h: |
|
h = int(EVA_IMAGE_SIZE / w * h) |
|
w = EVA_IMAGE_SIZE |
|
else: |
|
w = int(EVA_IMAGE_SIZE / h * w) |
|
h = EVA_IMAGE_SIZE |
|
|
|
image[idx] = im.resize((w, h)) |
|
|
|
return image if len(image) > 1 else image[0] |
|
|
|
def draw_box(self): |
|
left, top, right, bottom = self.coordinate |
|
mask = np.zeros((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, 3), dtype=np.uint8) |
|
mask = cv2.rectangle(mask, (left, top), (right, bottom), (255, 255, 255), 3) |
|
self.mask = Image.fromarray(mask) |
|
|
|
def stack(self): |
|
w, h = self.frames[0].size |
|
n = len(self.frames) |
|
stack_frame = Image.new(mode="RGB", size=(w*n, h)) |
|
for idx, f in enumerate(self.frames): |
|
stack_frame.paste(f, (idx*w, 0)) |
|
self.stack_frame = stack_frame |
|
|
|
|
|
class ConvMeta: |
|
|
|
def __init__(self): |
|
self.system: str = "You are a helpful assistant, dedicated to delivering comprehensive and meticulous responses." |
|
self.message: List[Tuple[Role, DataMeta]] = [] |
|
self.log_id: str = gen_id() |
|
|
|
logging.info(f"{self.log_id}: create new round of chat") |
|
|
|
def append(self, r: Role, p: DataMeta): |
|
logging.info(f"{self.log_id}: APPEND [{r.name}] prompt element, type: {p.datatype.name}, message: {p}") |
|
self.message.append((r, p)) |
|
|
|
def format_chatbot(self): |
|
ret = [] |
|
for r, p in self.message: |
|
cur_p = p.format_chatbot() |
|
if r == Role.USER: |
|
ret.append((cur_p, None)) |
|
else: |
|
ret.append((None, cur_p)) |
|
return ret |
|
|
|
def format_prompt(self): |
|
ret = [] |
|
has_coor = False |
|
for _, p in self.message: |
|
has_coor |= (p.datatype == DataType.GROUNDING) |
|
ret += p.format_prompt() |
|
|
|
if has_coor: |
|
ret.insert(0, GRD_SYMBOL) |
|
|
|
logging.info(f"{self.log_id}: format generation prompt: {ret}") |
|
return ret |
|
|
|
def format_chat(self): |
|
ret = [self.system] |
|
|
|
prev_r = None |
|
for r, p in self.message: |
|
if prev_r != r: |
|
if prev_r == Role.ASSISTANT: |
|
ret.append(f"{DEFAULT_EOS_TOKEN}{USER_TOKEN}: ") |
|
elif prev_r is None: |
|
ret.append(f" {USER_TOKEN}: ") |
|
else: |
|
ret.append(f" {ASSISTANT_TOKEN}: ") |
|
ret += p.format_prompt() |
|
prev_r = r |
|
else: |
|
ret += p.format_prompt() |
|
|
|
ret.append(f" {ASSISTANT_TOKEN}:") |
|
|
|
logging.info(f"{self.log_id}: format chat prompt: {ret}") |
|
return ret |
|
|
|
def clear(self): |
|
logging.info(f"{self.log_id}: clear chat history, end current chat round.") |
|
del self.message |
|
self.message = [] |
|
self.log_id = gen_id() |
|
|
|
def pop(self): |
|
if self.has_gen: |
|
logging.info(f"{self.log_id}: pop out previous generation / chat result") |
|
self.message.pop() |
|
|
|
def pop_error(self): |
|
new_message = [] |
|
for r, p in self.message: |
|
if p.datatype == DataType.ERROR: |
|
logging.info(f"{self.log_id}: pop error message: {p.text_str}") |
|
else: |
|
new_message.append((r, p)) |
|
del self.message |
|
self.message = new_message |
|
|
|
@property |
|
def has_gen(self): |
|
if len(self.message) == 0: |
|
return False |
|
if self.message[-1][0] == Role.USER: |
|
return False |
|
return True |
|
|
|
def __len__(self): |
|
return len(self.message) |
|
|