Spaces:
Running
Running
import modal | |
from smolagents import Tool | |
from modal_apps.app import app | |
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp | |
class TaskModelRetrieverTool(Tool):
    """Agent tool that looks up models able to perform a given vision task.

    The search itself is delegated to the ``TaskModelRetrieverModalApp``
    class, resolved through Modal by app name and class name; this class
    only validates the arguments and routes the query to the matching
    remote search method.
    """

    name = "task_model_retriever"
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    The query is a string that describes the task the model needs to perform.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        """Record the supported tasks and resolve the Modal class handle."""
        super().__init__()
        self.tasks = ["object-detection", "image-segmentation"]
        # Only a handle to the class is obtained here; the instance used for
        # remote calls is created later in ``setup``.
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the Modal class whose methods ``forward`` calls.

        NOTE(review): ``forward`` reads ``self.tool``, so this must have run
        first; presumably the ``Tool`` base class invokes it before the
        first call - confirm against the smolagents version in use.
        """
        self.tool: TaskModelRetrieverModalApp = self.tool_class()

    def forward(self, task: str, query: str) -> dict:
        """Return the models matching ``query`` for the requested ``task``.

        Args:
            task: One of the entries in ``self.tasks``.
            query: The class of objects the model needs to detect.

        Returns:
            Whatever the remote search method returns; per the tool
            description, a dictionary mapping model id to the labels that
            model can detect.

        Raises:
            ValueError: If ``task`` is not a supported task.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises instead of ``assert``: assertions are removed under
        # ``python -O``, which would let invalid input reach the dispatch.
        if task not in self.tasks:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        print(f"Retrieving models for task {task} with query {query}")

        if task == "object-detection":
            return self.tool.object_detection_search.remote(query)
        if task == "image-segmentation":
            return self.tool.image_segmentation_search.remote(query)

        # Reached only if ``self.tasks`` lists a task with no search method
        # wired up above; fail loudly rather than return an unbound name.
        raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")