File size: 5,298 Bytes
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import datetime
import re
import time
from enum import Enum
from itertools import groupby
from typing import ClassVar, Dict, List, Optional

from PIL import Image
from pydantic import BaseModel, field_validator, model_validator

from ...utils.general import encode_image
from ..od.schemas import Target


class Role(str, Enum):
    """String-valued enumeration of the roles a message can carry."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class MessageType(str, Enum):
    """String-valued enumeration of message content types."""

    IMAGE = "image"
    TEXT = "text"
    # Mixed-case spelling kept as the canonical member so that existing
    # ``MessageType.File`` references and its ``.name`` stay unchanged.
    File = "file"
    MIXED = "mixed"
    # Upper-case alias (same value, hence the same member) that matches the
    # naming convention of the other members.
    FILE = "file"


class ImageUrl(BaseModel):
    """Image reference consisting of a URL string and a detail level.

    Attributes:
        url: The image URL string.
        detail: One of "high", "low" or "auto"; defaults to "auto".
    """

    url: str
    detail: str = "auto"

    @field_validator("detail")
    @classmethod
    def validate_detail(cls, detail: str) -> str:
        """Return ``detail`` unchanged when it is an accepted level.

        Raises:
            ValueError: If ``detail`` is not "high", "low" or "auto".
        """
        accepted_levels = ("high", "low", "auto")
        if detail in accepted_levels:
            return detail

        message = 'The detail can only be one of "high", "low" and "auto". {} is not valid.'.format(
            detail
        )
        raise ValueError(message)


class Content(BaseModel):
    """A single piece of message content: either text or an image reference.

    Attributes:
        type: Content discriminator, "text" (default) or "image_url".
        text: The text payload; must be set when ``type`` is "text".
        image_url: The image payload; must be set when ``type`` is
            "image_url".
    """

    type: str = "text"
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

    # NOTE(review): the method name ``validate`` coincides with pydantic's own
    # deprecated ``BaseModel.validate`` classmethod; it is kept unchanged for
    # compatibility -- confirm whether a rename is wanted.
    @model_validator(mode="after")
    def validate(self):
        """Check that the payload matching ``type`` is present.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: If ``type`` is "text" but ``text`` is None, if
                ``type`` is "image_url" but ``image_url`` is None, or if
                ``type`` is neither of the two.
        """
        if self.type == "text":
            if self.text is None:
                raise ValueError(
                    "The text value must be valid when Content type is 'text'."
                )
        elif self.type == "image_url":
            if self.image_url is None:
                raise ValueError(
                    "The image_url value must be valid when Content type is 'image_url'."
                )
        else:
            raise ValueError(
                f"Invalid Content type {self.type}. Must be one of text and image_url"
            )
        return self


class Message(BaseModel):
    """A chat message whose content is normalized to ``Content`` objects.

    Attributes:
        role: The role of this message: 'user', 'assistant' or 'system'.
        message_type: Type of the message, a ``MessageType`` value
            (defaults to text).
        content: Message content. Accepts a string, a single ``Content``,
            a single dict of ``Content`` fields, or a list mixing strings,
            ``Content`` objects, dicts, PIL images and basic scalar values.
            ``content_validator`` normalizes it into one ``Content`` or a
            list of them.
        objects: The detected objects.
        kwargs: Additional dict payload; not interpreted by this class.
        basic_data_types: Class-level list of types that are turned into
            text content via ``str()``.
    """

    role: Role = Role.USER
    message_type: MessageType = MessageType.TEXT
    content: List[Content | Dict] | Content | str
    objects: List[Target] = []
    kwargs: dict = {}
    basic_data_types: ClassVar[List[type]] = [
        str,
        list,
        tuple,
        int,
        float,
        bool,
        datetime.datetime,
        datetime.time,
    ]

    @classmethod
    def merge_consecutive_text(cls, content) -> List:
        """Concatenate each run of adjacent strings in ``content`` into one.

        Non-string items keep their position and separate the runs. A run
        that concatenates to an empty string is dropped.

        Args:
            content: Iterable of content parts.

        Returns:
            A new list with every run of strings replaced by one string.
        """
        result = []
        for is_text, run in groupby(content, key=lambda part: isinstance(part, str)):
            if is_text:
                merged = "".join(run)
                if merged:
                    result.append(merged)
            else:
                result.extend(run)
        return result

    @field_validator("content", mode="before")
    @classmethod
    def content_validator(
        cls, content: List[Content | Dict] | Content | str
    ) -> List[Content] | Content:
        """Normalize raw ``content`` into ``Content`` objects.

        Args:
            content: A string, a ``Content``, a dict of ``Content`` fields,
                or a list of content parts.

        Returns:
            A single ``Content`` when exactly one part results, otherwise a
            list of ``Content`` (possibly empty).

        Raises:
            ValueError: If ``content`` is of an unsupported type, or a list
                element is of an unsupported type.
        """
        if isinstance(content, str):
            return Content(type="text", text=content)

        if isinstance(content, (Content, dict)):
            # A lone Content (allowed by the field annotation) or a lone dict
            # (e.g. the dumped form of a Content) is handled like a
            # one-element list.
            parts = [content]
        elif isinstance(content, list):
            # Combine adjacent str elements of the list.
            parts = cls.merge_consecutive_text(content)
        else:
            raise ValueError(
                "Content must be a string, a list of Content objects or list of dicts."
            )

        basic_types = tuple(cls.basic_data_types)
        formatted = []
        for c in parts:
            # Skip empty placeholders (None, "", empty containers) but keep
            # falsy scalars such as 0, 0.0 and False: they are listed in
            # basic_data_types and are stringified below.
            if not c and not isinstance(c, (int, float)):
                continue
            if isinstance(c, Content):
                formatted.append(c)
            elif isinstance(c, dict):
                try:
                    formatted.append(Content(**c))
                except Exception:
                    # Best effort: a dict that is not a valid Content is kept
                    # as its text representation instead of failing.
                    formatted.append(Content(type="text", text=str(c)))
            elif isinstance(c, Image.Image):
                formatted.append(
                    Content(
                        type="image_url",
                        image_url={
                            "url": f"data:image/jpeg;base64,{encode_image(c)}"
                        },
                    )
                )
            elif isinstance(c, basic_types):
                formatted.append(Content(type="text", text=str(c)))
            else:
                raise ValueError(
                    f"Content list must contain [Content, str, list, dict, PIL.Image], got {type(c)}"
                )
        return formatted[0] if len(formatted) == 1 else formatted

    @classmethod
    def system(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the system role from ``content``."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def user(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the user role from ``content``."""
        return cls(role=Role.USER, content=content)

    @classmethod
    def assistant(cls, content: str | List[str | Dict | Content]) -> "Message":
        """Build a message with the assistant role from ``content``."""
        return cls(role=Role.ASSISTANT, content=content)