Spaces:
Sleeping
Sleeping
import datetime | |
import re | |
import time | |
from enum import Enum | |
from itertools import groupby | |
from typing import ClassVar, Dict, List, Optional | |
from PIL import Image | |
from pydantic import BaseModel, field_validator, model_validator | |
from ...utils.general import encode_image | |
from ..od.schemas import Target | |
class Role(str, Enum):
    """Author of a chat message.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (``Role.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
class MessageType(str, Enum):
    """Kind of payload carried by a message.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (``MessageType.TEXT == "text"``).
    """

    IMAGE = "image"
    TEXT = "text"
    # Historical mixed-case spelling; kept as the canonical member so that
    # existing ``MessageType.File`` references and ``.name`` stay unchanged.
    File = "file"
    MIXED = "mixed"
    # Alias of ``File`` (same value) that follows the UPPER_CASE convention of
    # the other members. Enum aliases do not show up when iterating the class.
    FILE = "file"
class ImageUrl(BaseModel):
    """Reference to an image together with the requested detail level.

    Attributes:
        url: The image URL.
        detail: One of "high", "low" or "auto" (the default).
    """

    url: str
    detail: str = "auto"

    # Registered as a pydantic field validator; without the decorators the
    # method is never invoked and any string would be accepted for `detail`.
    @field_validator("detail")
    @classmethod
    def validate_detail(cls, detail: str) -> str:
        """Return `detail` unchanged if it is an allowed level.

        Raises:
            ValueError: If `detail` is not "high", "low" or "auto".
        """
        if detail not in ("high", "low", "auto"):
            raise ValueError(
                'The detail can only be one of "high", "low" and "auto". '
                f"{detail} is not valid."
            )
        return detail
class Content(BaseModel):
    """One part of a message: either a text fragment or an image reference.

    Attributes:
        type: "text" or "image_url"; selects which payload field must be set.
        text: The text payload, required when `type` is "text".
        image_url: The image payload, required when `type` is "image_url".
    """

    type: str = "text"
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

    # Registered as an "after" model validator so it runs once all fields are
    # populated; without the decorator the consistency check never executes.
    # NOTE(review): the method name matches pydantic's deprecated
    # `BaseModel.validate` classmethod; it is kept because it is part of this
    # class's existing interface.
    @model_validator(mode="after")
    def validate(self) -> "Content":
        """Check that the payload field matching `type` is present.

        Returns:
            The instance itself, as required of an "after" model validator.

        Raises:
            ValueError: If `type` is unknown, or the matching payload is None.
        """
        if self.type == "text":
            if self.text is None:
                raise ValueError(
                    "The text value must be valid when Content type is 'text'."
                )
            return self

        if self.type == "image_url":
            if self.image_url is None:
                raise ValueError(
                    "The image_url value must be valid when Content type is 'image_url'."
                )
            return self

        raise ValueError(
            f"Invalid Content type {self.type}. Must be one of text and image_url"
        )
class Message(BaseModel):
    """A single chat message.

    Attributes:
        role: The role of this message: 'user', 'assistant' or 'system'.
        message_type: Type of the message: 'image', 'text', 'file' or 'mixed'.
        content: Message content. Raw input (a string, a Content, or a list
            mixing strings, dicts, Content objects, PIL images and basic
            values) is normalized by `content_validator` into a single Content
            or a list of Content.
        objects: The detected objects.
        kwargs: Extra data attached to the message; not read by this class.
    """

    role: Role = Role.USER
    message_type: MessageType = MessageType.TEXT
    content: List[Content | Dict] | Content | str
    objects: List[Target] = []
    kwargs: dict = {}

    # Plain value types that `content_validator` converts to text via `str()`.
    basic_data_types: ClassVar[List[type]] = [
        str,
        list,
        tuple,
        int,
        float,
        bool,
        datetime.datetime,
        datetime.time,
    ]

    @classmethod
    def merge_consecutive_text(cls, content) -> List:
        """Concatenate each run of adjacent strings in `content` into one string.

        Non-string items are kept in place and in order. A run whose
        concatenation is empty is dropped.
        """
        merged = []
        for is_text, run in groupby(content, key=lambda part: isinstance(part, str)):
            if not is_text:
                merged.extend(run)
                continue
            text = "".join(run)
            if text:
                merged.append(text)
        return merged

    # Runs before pydantic's own type validation because the raw input may
    # contain items (bare strings, PIL images, numbers) that do not match the
    # field annotation until they have been wrapped in Content.
    @field_validator("content", mode="before")
    @classmethod
    def content_validator(
        cls, content: List[Content | Dict] | Content | str
    ) -> List[Content] | Content:
        """Normalize raw message content into Content objects.

        Returns:
            A single Content when the input yields exactly one part, otherwise
            a list of Content (possibly empty).

        Raises:
            ValueError: If `content` is not a str, Content or list, or if a
                list item has an unsupported type.
        """
        if isinstance(content, str):
            return Content(type="text", text=content)
        if isinstance(content, Content):
            # Already in normalized form; the field annotation allows it.
            return content
        if not isinstance(content, list):
            raise ValueError(
                "Content must be a string, a list of Content objects or list of dicts."
            )

        basic_types = tuple(cls.basic_data_types)
        formatted = []
        # Adjacent strings are merged first so they become one text part.
        for c in cls.merge_consecutive_text(content):
            # Skip empty parts (None, "", {}, [], ...).
            # NOTE(review): this also drops 0, 0.0 and False — confirm intended.
            if not c:
                continue
            if isinstance(c, Content):
                formatted.append(c)
            elif isinstance(c, dict):
                try:
                    formatted.append(Content(**c))
                except Exception:
                    # Best effort: a dict that is not a valid Content spec is
                    # passed along as its string representation.
                    formatted.append(Content(type="text", text=str(c)))
            elif isinstance(c, Image.Image):
                formatted.append(
                    Content(
                        type="image_url",
                        image_url={
                            "url": f"data:image/jpeg;base64,{encode_image(c)}"
                        },
                    )
                )
            elif isinstance(c, basic_types):
                formatted.append(Content(type="text", text=str(c)))
            else:
                raise ValueError(
                    f"Content list must contain [Content, str, list, dict, PIL.Image], got {type(c)}"
                )
        return formatted[0] if len(formatted) == 1 else formatted

    @classmethod
    def system(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the 'system' role."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def user(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the 'user' role."""
        return cls(role=Role.USER, content=content)

    @classmethod
    def assistant(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the 'assistant' role."""
        return cls(role=Role.ASSISTANT, content=content)