import modal
from smolagents import Tool

from modal_apps.app import app
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp


class TaskModelRetrieverTool(Tool):
    """smolagents tool that looks up candidate models for a vision task.

    The search itself runs remotely in the ``TaskModelRetrieverModalApp``
    Modal class registered under ``app.name``. This tool validates the
    arguments and dispatches the query to the remote search method that
    matches the requested task.
    """

    name = "task_model_retriever"
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    - image-classification

    The query is a string that describes the task the model needs to perform.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """

    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Single source of truth: supported task -> name of the remote search
        # method on TaskModelRetrieverModalApp that serves it.
        self._search_methods = {
            "object-detection": "object_detection_search",
            "image-segmentation": "image_segmentation_search",
            "image-classification": "image_classification_search",
        }
        # Kept as a public list (same order as before) for existing readers.
        self.tasks = list(self._search_methods)
        # Look up the deployed Modal class by app and class name; it is
        # instantiated later, in setup().
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Create the remote Modal class instance once, before first use.

        smolagents calls ``setup()`` lazily from ``Tool.__call__`` while
        ``is_initialized`` is False. Delegating to ``super().setup()`` sets
        that flag, so the remote instance is created once instead of on
        every tool call.
        """
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Retrieve models able to perform ``task`` that match ``query``.

        Args:
            task: One of ``self.tasks``.
            query: Free-text description of the objects/classes to handle.

        Returns:
            Whatever the remote search method returns. The tool description
            documents it as a dict of model id -> labels.

        Raises:
            ValueError: If ``task`` is not a supported task.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises instead of ``assert``: asserts are stripped under
        # ``python -O``. The old code would then have reached an unbound
        # ``result`` for unknown tasks.
        if task not in self._search_methods:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        print(f"Retrieving models for task {task} with query {query}")
        search = getattr(self.tool, self._search_methods[task])
        return search.remote(query)