# Page residue from the file viewer: zwgao's picture / add file / 3fdcc70
import json
from collections import OrderedDict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Callable, List, Optional
@dataclass
class Action:
    """A single tool call written as an assignment.

    The action reads as ``outputs = tool_name(**inputs)``, for example::

        mask = segmentation_by_mask(image=image, prompt_mask=prompt_mask)
        image = image_inpainting(image=image, mask=mask)

    Attributes:
        tool_name: Name of the tool being called; ``None`` when unset.
        inputs: Mapping of argument name to argument value.
        outputs: Targets the results are assigned to. Entries are usually
            strings, but any object with a sensible ``str()`` (such as a
            ``pathlib.Path``) is accepted.
    """

    # The defaults used to be the one-element tuple ``(None,)`` because of a
    # stray trailing comma, which made ``Action()`` unusable: a tuple has no
    # ``.items()``. Fresh empty containers are used instead.
    tool_name: Optional[str] = None
    inputs: dict = field(default_factory=dict)
    outputs: List[str] = field(default_factory=list)

    def __str__(self) -> str:
        """Render the action as ``out1, out2 = tool(arg=value, ...)``."""
        args = ", ".join(f"{k}={v}" for k, v in (self.inputs or {}).items())
        # Stringify each target: ``str.join`` rejects non-str items, and
        # ``dict()`` already tolerates them.
        targets = ", ".join(str(o) for o in (self.outputs or []))
        return f"{targets} = {self.tool_name}({args})"

    def dict(self):
        """Return a JSON-friendly view of the action.

        Returns:
            A dict with keys ``"tool"`` (the tool name as stored),
            ``"inputs"`` (argument names and values, both stringified) and
            ``"outputs"`` (a list of strings).
        """
        args = {str(k): str(v) for k, v in (self.inputs or {}).items()}
        rets = [o if isinstance(o, str) else str(o) for o in (self.outputs or [])]
        return {
            "tool": self.tool_name,
            "inputs": args,
            "outputs": rets,
        }
class DataType(Enum):
    """Tags for the kinds of data a tool argument or resource can carry.

    Members compare equal to other members with the same value and also to
    the plain string value itself (``DataType.TEXT == "text"``). Hashing is
    based on the value, so members can be used as dict keys and set
    elements, and they collide with their string value as the equality
    rule requires.
    """

    TEXT = "text"
    TAGS = "tags"
    TITLE = "title"
    # HTML was previously spelled "text.html".
    HTML = "html"
    LOCATION = "location"
    WEATHER = "weather"
    TIME = "time"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    ANY = "any"
    NONE = "none"
    SEGMENTATION = "image.segmentation"
    EDGE = "image.edge"
    LINE = "image.line"
    HED = "image.hed"
    CANNY = "image.canny"
    SCRIBBLE = "image.scribble"
    POSE = "image.pose"
    DEPTH = "image.depth"
    NORMAL = "image.normal"
    MASK = "image.mask"  # SAM mask
    POINT = "point"
    BBOX = "bbox"  # {'label': 'dog', 'box': [1,2,3,4], 'score': 0.9}
    CATEGORY = "category"
    LIST = "list"

    def __str__(self):
        """Return the raw string value of the member."""
        return self.value

    def __eq__(self, other):
        """Compare by value against another member or a plain string."""
        if isinstance(other, str):
            return self.value == other
        elif isinstance(other, self.__class__):
            return self.value == other.value
        else:
            return False

    def __hash__(self):
        """Hash by value so the hash agrees with ``__eq__``.

        Defining ``__eq__`` alone makes Python set ``__hash__`` to ``None``,
        which left members unusable as dict keys or set elements.
        """
        return hash(self.value)
@dataclass
class Resource:
    """A named value tagged with a ``DataType``.

    Attributes:
        name: Identifier of the resource.
        type: ``DataType`` tag attached to the value.
        value: The stored payload; ``dict()`` emits it as a string.
    """

    name: str
    type: DataType
    # NOTE(review): annotated as ``None`` although ``dict()`` stringifies
    # whatever is stored here; confirm the intended type.
    value: None
    # A ``description`` field is currently disabled.

    def dict(self):
        """Return the resource as a plain dict.

        ``type`` and ``value`` are passed through ``str()``; ``name`` is
        emitted as stored.
        """
        payload = {"name": self.name}
        payload["type"] = str(self.type)
        payload["value"] = str(self.value)
        return payload
@dataclass
class Tool:
    """Description of a callable tool and the arguments it takes and returns.

    Attributes:
        name: Tool name.
        description: Human-readable description of the tool.
        domain: The ``Tool.Domain`` the tool is filed under.
        model: The callable stored for this tool.
        usages: List of usage strings; empty by default.
        args: Declared input arguments; empty by default.
        returns: Declared return values; empty by default.
    """

    class Domain(Enum):
        """Domain labels a tool can be filed under."""

        IMAGE_PERCEPTION = "image-perception"
        IMAGE_GENERATION = "image-generation"
        IMAGE_EDITING = "image-editing"
        IMAGE_PROCESSING = "image-processing"
        AUDIO_PERCEPTION = "audio-perception"
        AUDIO_GENERATION = "audio-generation"
        VIDEO_PERCEPTION = "video-perception"
        VIDEO_GENERATION = "video-generation"
        VIDEO_PROCESSING = "video-processing"
        VIDEO_EDITING = "video-editing"
        VIDEO_CUTTING = "video-cutting"
        NATURAL_LANGUAGE_PROCESSING = "natural-language-processing"
        CODE_GENERATION = "code-generation"
        VISUAL_QUESTION_ANSWERING = "visual-question-answering"
        QUESTION_ANSWERING = "question-answering"
        GENERAL = "general"

        def __str__(self):
            """Return the raw string value of the domain."""
            return self.value

    @dataclass
    class Argument:
        """One named, typed argument or return value of a tool."""

        name: str
        type: DataType
        description: str

        def dict(self):
            """Return the argument as a plain dict with a stringified type."""
            spec = {"name": self.name}
            spec["type"] = str(self.type)
            spec["description"] = self.description
            return spec

    name: str
    description: str
    domain: Domain
    model: Callable
    usages: List[str] = field(default_factory=list)
    args: List[Argument] = field(default_factory=list)
    returns: List[Argument] = field(default_factory=list)

    def dict(self):
        """Return the tool as a plain dict.

        Only ``name``, ``description``, ``domain``, ``args`` and ``returns``
        are emitted; ``model`` and ``usages`` are left out.
        """
        summary = {
            "name": self.name,
            "description": self.description,
            "domain": str(self.domain),
        }
        summary["args"] = [argument.dict() for argument in self.args]
        summary["returns"] = [result.dict() for result in self.returns]
        return summary
# Data types grouped under the "non-file" label.
# NOTE(review): presumably these are values carried inline rather than as
# files on disk; confirm against the callers of this constant.
NON_FILE_TYPES = [
    # textual values
    DataType.TAGS, DataType.TEXT, DataType.TITLE,
    # structured values
    DataType.BBOX, DataType.CATEGORY, DataType.LIST,
    # context values
    DataType.LOCATION, DataType.POINT, DataType.WEATHER, DataType.TIME,
]
if __name__ == "__main__":
    # Demo: dump a nested list of actions as JSON, letting ``json`` fall back
    # to each object's ``dict()`` method for anything it cannot encode.
    wrapped_action = Action("a", {"aa": [Path("/a/d/e/t.txt")]}, [Path("/a/aa.txt")])
    plain_action = Action("b", {"bb": "bbb"}, ["bbb"])
    demo_actions = [[wrapped_action], plain_action]

    def _as_plain_dict(obj):
        return obj.dict()

    print(json.dumps(demo_actions, indent=4, default=_as_plain_dict))