Spaces:
Sleeping
Sleeping
File size: 5,298 Bytes
1b7e88c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import datetime
import re
import time
from enum import Enum
from itertools import groupby
from typing import ClassVar, Dict, List, Optional
from PIL import Image
from pydantic import BaseModel, field_validator, model_validator
from ...utils.general import encode_image
from ..od.schemas import Target
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class MessageType(str, Enum):
IMAGE = "image"
TEXT = "text"
File = "file"
MIXED = "mixed"
class ImageUrl(BaseModel):
url: str
detail: str = "auto"
@field_validator("detail")
@classmethod
def validate_detail(cls, detail: str) -> str:
if detail not in ["high", "low", "auto"]:
raise ValueError(
'The detail can only be one of "high", "low" and "auto". {} is not valid.'.format(
detail
)
)
return detail
class Content(BaseModel):
type: str = "text"
text: Optional[str] = None
image_url: Optional[ImageUrl] = None
@model_validator(mode="after")
def validate(self):
if self.type == "text":
if self.text is None:
raise ValueError(
"The text value must be valid when Content type is 'text'."
)
elif self.type == "image_url":
if self.image_url is None:
raise ValueError(
"The image_url value must be valid when Content type is 'image_url'."
)
else:
raise ValueError(
"Invalid Conyent type {}. Must be one of text and image_url".format(
self.type
)
)
return self
class Message(BaseModel):
"""
- role (str): The role of this message. Choose from 'user', 'assistant', 'system'.
- message_type (str): Type of the message. Choose from 'text', 'image' and 'mixed'.
- src_type (str): Type of the message content. If message is text, src_type='text'. If message is image, Choose from 'url', 'base64', 'local' and 'redis'.
- content (str): Message content.
- objects (List[schemas.Target]): The detected objects.
"""
role: Role = Role.USER
message_type: MessageType = MessageType.TEXT
content: List[Content | Dict] | Content | str
objects: List[Target] = []
kwargs: dict = {}
basic_data_types: ClassVar[List[type]] = [
str,
list,
tuple,
int,
float,
bool,
datetime.datetime,
datetime.time,
]
@classmethod
def merge_consecutive_text(cls, content) -> List:
result = []
current_str = ""
for part in content:
if isinstance(part, str):
current_str += part
else:
if current_str:
result.append(current_str)
current_str = ""
result.append(part)
if current_str: # 处理最后的字符串
result.append(current_str)
return result
@field_validator("content", mode="before")
@classmethod
def content_validator(
cls, content: List[Content | Dict] | Content | str
) -> List[Content] | Content:
if isinstance(content, str):
return Content(type="text", text=content)
elif isinstance(content, list):
# combine str elements in list
content = cls.merge_consecutive_text(content)
formatted = []
for c in content:
if not c:
continue
if isinstance(c, Content):
formatted.append(c)
elif isinstance(c, dict):
try:
formatted.append(Content(**c))
except Exception as e:
formatted.append(Content(type="text", text=str(c)))
elif isinstance(c, Image.Image):
formatted.append(
Content(
type="image_url",
image_url={
"url": f"data:image/jpeg;base64,{encode_image(c)}"
},
)
)
elif isinstance(c, tuple(cls.basic_data_types)):
formatted.append(Content(type="text", text=str(c)))
else:
raise ValueError(
f"Content list must contain [Content, str, list, dict, PIL.Image], got {type(c)}"
)
else:
raise ValueError(
"Content must be a string, a list of Content objects or list of dicts."
)
return formatted[0] if len(formatted) == 1 else formatted
@classmethod
def system(cls, content: str | List[str | Dict | Content]) -> "Message":
return cls(role=Role.SYSTEM, content=content)
@classmethod
def user(cls, content: str | List[str | Dict | Content]) -> "Message":
return cls(role=Role.USER, content=content)
@classmethod
def assistant(cls, content: str | List[str | Dict | Content]) -> "Message":
return cls(role=Role.ASSISTANT, content=content)
|