import json
from abc import ABC
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import yaml
from omagent_core.base import BotBase
from omagent_core.models.od.schemas import Target
from omagent_core.services.handlers.sql_data_handler import SQLDataHandler
from omagent_core.utils.error import VQLError
from omagent_core.utils.logger import logging
from omagent_core.utils.plot import Annotator
from PIL import Image
from pydantic import BaseModel, model_validator

try:
    from distutils.util import strtobool
except ImportError:  # distutils was removed from the stdlib in Python 3.12

    def strtobool(val: str) -> int:
        """Fallback with the same semantics as ``distutils.util.strtobool``.

        Args:
            val: The string to interpret.

        Returns:
            1 for a true-ish string, 0 for a false-ish string.

        Raises:
            ValueError: If ``val`` is not a recognised truth value.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


# Maps Python type names to the JSON-schema style names used by ArgInfo.type.
_PY_TO_SCHEMA_TYPE = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
}


def _coerce_arg(name: str, value: Any, expected: str) -> Any:
    """Convert ``value`` to the schema type ``expected``.

    Args:
        name: The argument name, used only in error messages.
        value: The raw argument value.
        expected: One of "string", "integer", "number" or "boolean".

    Returns:
        The converted value.

    Raises:
        ValueError: If the conversion fails or ``expected`` is not supported.
    """
    if expected == "string":
        try:
            return str(value)
        except Exception as err:
            raise ValueError(
                "Parameter {} type expect a str value, but got a {} {}".format(
                    name, type(value), value
                )
            ) from err
    if expected == "integer":
        try:
            return int(value)
        except Exception as err:
            raise ValueError(
                "Parameter {} type expect an int value, but got a {} {}".format(
                    name, type(value), value
                )
            ) from err
    if expected == "number":
        try:
            return float(value)
        except Exception as err:
            raise ValueError(
                "Parameter {} type expect a float value, but got a {} {}".format(
                    name, type(value), value
                )
            ) from err
    if expected == "boolean":
        if isinstance(value, bool):
            return value
        try:
            # strtobool returns 1/0; normalise to a real bool.
            return bool(strtobool(str(value)))
        except Exception as err:
            raise ValueError(
                "Parameter {} type expect a boolean value, but got a {} {}".format(
                    name, type(value), value
                )
            ) from err
    # The schema itself declares a type this validator cannot handle.
    raise ValueError(
        "Parameter {} type expect one of string, integer, number and boolean, but got a {} {}".format(
            name, expected, type(value)
        )
    )


class ArgSchema(BaseModel):
    """ArgSchema defines the tool input schema. Only support one layer definition. Please prevent using complex structure."""

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    class ArgInfo(BaseModel):
        """Description of a single tool argument."""

        description: Optional[str]
        # NOTE: the default is the Python-style name "str"; validate_args
        # treats it as an alias of the schema name "string".
        type: str = "str"
        enum: Optional[List] = None
        required: Optional[bool] = True

    @model_validator(mode="before")
    @classmethod
    def validate_all(cls, values):
        """Normalise every field of the schema into an ``ArgInfo``.

        A plain string is treated as the argument description, a dict is
        expanded into ``ArgInfo`` keyword arguments, and an existing
        ``ArgInfo`` is kept untouched.

        Args:
            values: The raw mapping of argument name to its definition.

        Returns:
            The same mapping with every value replaced by an ``ArgInfo``.

        Raises:
            ValueError: If a value is neither a string, a dict nor an ArgInfo.
        """
        for key, value in values.items():
            if isinstance(value, cls.ArgInfo):
                continue
            if isinstance(value, str):
                # ArgInfo has no "name" field and requires "description", so
                # the string shorthand has to populate the description.
                values[key] = cls.ArgInfo(description=value)
            elif isinstance(value, dict):
                values[key] = cls.ArgInfo(**value)
            else:
                raise ValueError(
                    "The arg type must be one of string, dict or self.ArgInfo."
                )
        return values

    @classmethod
    def from_file(cls, schema_file: Union[str, Path]):
        """Build a schema from a JSON or YAML file.

        Args:
            schema_file: Path to a ``.json``, ``.yaml`` or ``.yml`` file.

        Returns:
            The schema instance built from the file content.

        Raises:
            ValueError: If the file suffix is not supported.
        """
        schema_file = Path(schema_file)
        suffix = schema_file.suffix
        if suffix == ".json":
            with open(schema_file, "r", encoding="utf-8") as f:
                schema = json.load(f)
        elif suffix in (".yaml", ".yml"):
            with open(schema_file, "r", encoding="utf-8") as f:
                # SECURITY NOTE(review): FullLoader is not meant for untrusted
                # input; consider yaml.safe_load if schema files can come from
                # outside the project.
                schema = yaml.load(f, Loader=yaml.FullLoader)
        else:
            raise ValueError("Only support json and yaml file.")
        return cls(**schema)

    def generate_schema(self) -> Tuple[dict, list]:
        """Split the schema into parameter definitions and required names.

        Returns:
            A ``(parameters, required_args)`` tuple. ``parameters`` maps each
            argument name to its dumped definition without the "required"
            key; ``required_args`` lists the names flagged as required.
        """
        required_args = []
        parameters = {}
        for key, value in self.model_dump(exclude_none=True).items():
            # "required" is absent when it was set to None (exclude_none).
            if value.pop("required", None):
                required_args.append(key)
            parameters[key] = value
        return parameters, required_args

    def validate_args(self, args: dict) -> dict:
        """Validate and coerce tool arguments against this schema.

        Unknown arguments are dropped with a warning. Values whose Python
        type does not match the declared type are converted when possible.
        Enum constraints are checked on the final (possibly converted) value.

        Args:
            args: The raw arguments keyed by argument name.

        Returns:
            A new dict holding only the known arguments, type-coerced.

        Raises:
            ValueError: If ``args`` is not a dict, a value cannot be
                converted, a value violates its enum, or the schema declares
                an unsupported type.
            VQLError: If a required argument is missing.
        """
        if not isinstance(args, dict):
            raise ValueError(
                "ArgSchema validate only support dict, not {}".format(type(args))
            )

        # Dump once; the original re-dumped the model for every lookup.
        schema = self.model_dump()
        required_fields = {k for k, v in schema.items() if v["required"]}

        new_args = {}
        for name, value in args.items():
            if name not in schema:
                logging.warning(
                    "The input args includes an unnecessary parameter {}. Removed from the args.".format(
                        name
                    )
                )
                continue

            info = schema[name]
            # Accept Python-style names ("str", "int", ...) as aliases so the
            # ArgInfo default type "str" validates like "string".
            expected = _PY_TO_SCHEMA_TYPE.get(info["type"], info["type"])

            # .get avoids a KeyError for values such as lists, dicts or None.
            if _PY_TO_SCHEMA_TYPE.get(type(value).__name__) == expected:
                final_value = value
            else:
                final_value = _coerce_arg(name, value, expected)

            enum = info.get("enum")
            if enum and final_value not in enum:
                raise ValueError(
                    "The value of {} should be one of {}, but got {}".format(
                        name, str(enum), final_value
                    )
                )
            new_args[name] = final_value

        missing = required_fields - set(new_args)
        if missing:
            raise VQLError("The required fields {} are missing.".format(missing))
        return new_args


class BaseTool(BotBase, ABC):
    """Base class for tools that wrap a callable behind an optional schema."""

    description: str
    func: Optional[Callable] = None
    args_schema: Optional[ArgSchema]
    special_params: Dict = {}

    def model_post_init(self, __context: Any) -> None:
        """Register this tool as the parent of every BotBase attribute."""
        for attr_value in self.__dict__.values():
            if isinstance(attr_value, BotBase):
                attr_value._parent = self

    @property
    def workflow_instance_id(self) -> Optional[str]:
        """Workflow instance id of the parent, or None without a parent."""
        if hasattr(self, "_parent"):
            return self._parent.workflow_instance_id
        return None

    @workflow_instance_id.setter
    def workflow_instance_id(self, value: str):
        """Forward the id to the parent; a no-op when there is no parent."""
        if hasattr(self, "_parent"):
            self._parent.workflow_instance_id = value

    def _run(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        return self.func(**input)

    async def _arun(self, **input) -> str:
        """Implement this function or pass 'func' arg when initializing."""
        return await self.func(**input)

    def run(self, input: Any) -> str:
        """Validate ``input`` against the schema, if any, and run the tool.

        Args:
            input: The tool arguments. Must be a dict when ``args_schema``
                is set.

        Returns:
            Whatever ``_run`` returns.

        Raises:
            ValueError: If ``args_schema`` is set and ``input`` is not a dict.
        """
        if self.args_schema is not None:
            if not isinstance(input, dict):
                raise ValueError(
                    "The input type must be dict when args_schema is specified."
                )
            # Use the validated result so coerced values and the removal of
            # unknown parameters actually take effect.
            input = self.args_schema.validate_args(input)
        return self._run(**input, **self.special_params)

    async def arun(self, input: Any) -> str:
        """Async counterpart of ``run``.

        Args:
            input: The tool arguments. Must be a dict when ``args_schema``
                is set.

        Returns:
            Whatever ``_arun`` returns.

        Raises:
            ValueError: If ``args_schema`` is set and ``input`` is not a dict.
        """
        if self.args_schema is not None:
            if not isinstance(input, dict):
                raise ValueError(
                    "The input type must be dict when args_schema is specified."
                )
            # Use the validated result so coerced values and the removal of
            # unknown parameters actually take effect.
            input = self.args_schema.validate_args(input)
        return await self._arun(**input, **self.special_params)

    def generate_schema(self):
        """Build the function-calling description of this tool.

        Returns:
            A dict with "type" set to "function" and a "function" entry
            holding the tool name, description and parameters.
        """
        if not self.args_schema:
            return {
                "type": "function",
                # Kept at the top level for backward compatibility.
                "description": self.description,
                "function": {
                    "name": self.name,
                    # Mirrors the layout used when args_schema is set.
                    "description": self.description,
                    "parameters": {
                        "type": "object",
                        "name": "input",
                        "required": ["input"],
                    },
                },
            }

        properties, required = self.args_schema.generate_schema()
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }


class BaseModelTool(BaseTool, ABC):
    """Base class for tools backed by a detection model."""

    # data_handler: Optional[SQLDataHandler]

    def visual_prompting(
        self,
        image: Image.Image,
        annotation: List[Target],
        prompting_type: str = "label_on_img",
        include_labels: Optional[Union[List, set, tuple]] = None,
        exclude_labels: Optional[Union[List, set, tuple]] = None,
    ) -> List[Image.Image]:
        """Draw the bounding boxes of the selected targets on ``image``.

        Args:
            image: The image to annotate.
            annotation: The targets to draw.
            prompting_type: Currently not read by this method.
            include_labels: When given, only targets with these labels are
                drawn.
            exclude_labels: When given, targets with these labels are skipped.

        Returns:
            The value of ``Annotator.result()`` for the annotated image.
        """
        annotator = Annotator(image)
        for obj in annotation:
            excluded = exclude_labels is not None and obj.label in exclude_labels
            not_included = (
                include_labels is not None and obj.label not in include_labels
            )
            if excluded or not_included:
                continue
            if obj.bbox:
                annotator.box_label(obj.bbox, obj.label, color="red")
            # TODO: Add polygon support
        return annotator.result()

    def infer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """The model inference step. Only support OD type detection.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """

    def ainfer(self, images: List[Image.Image], kwargs) -> List[List[Target]]:
        """The async version of model inference step. Only support OD type detection.

        Args:
            images (List[Image.Image]): The list of input images. Image should be PIL Image object.
            kwargs (dict): The additional arguments for the model.

        Returns:
            List[List[Target]]: The detection results.
        """


class MemoryTool(BaseTool):
    """Tool backed by a SQL data handler."""

    memory_handler: Optional[SQLDataHandler]

    def generate_schema(self) -> dict:
        """Generate the data table schema in dict format.

        Returns:
            dict: The data table schema. Including the table name, and the name, data type and additional information of each column.
        """
        table = self.memory_handler.table
        schema = {"table_name": table.__tablename__, "columns": []}
        for column in table.__table__.columns:
            schema["columns"].append(
                {
                    "name": column.name,
                    "type": column.type.__visit_name__,
                    "info": column.info,
                }
            )
        return schema

    def generate_prompt(self):
        """Placeholder; currently does nothing."""
        pass

    def _run(self):
        """Execute the handler's SQL.

        NOTE(review): the result of ``execute_sql`` is discarded and no
        keyword arguments are accepted - confirm this is intended.
        """
        self.memory_handler.execute_sql()

    async def _arun(self):
        """Async entry point; calls the same ``execute_sql`` as ``_run``."""
        self.memory_handler.execute_sql()