ScouterAI / tools /task_model_retriever.py
stevenbucaille's picture
Enhance image processing capabilities and update project structure
111afa2
raw
history blame
1.75 kB
import modal
from smolagents import Tool
from modal_apps.app import app
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp
class TaskModelRetrieverTool(Tool):
    """smolagents tool that finds models able to perform a given vision task.

    The search itself runs remotely: ``__init__`` looks up the deployed Modal
    class ``TaskModelRetrieverModalApp`` by name under ``app.name``, and
    ``setup()`` instantiates it. ``forward()`` then dispatches to the remote
    search method that matches the requested task.
    """

    name = "task_model_retriever"
    # LLM-facing text: the agent reads this to decide how to call the tool, so
    # it must agree with the per-input descriptions in ``inputs`` below.
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Values accepted for the ``task`` input; forward() dispatches on these.
        self.tasks = ["object-detection", "image-segmentation"]
        # Handle to the deployed Modal class, looked up by app name and class name.
        # The class is not instantiated here; that is deferred to setup().
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the remote Modal class once, before the first search."""
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        # The base Tool.setup() marks the tool as initialized. Without it,
        # smolagents' Tool.__call__ would re-run setup() (and re-instantiate the
        # remote class) on every invocation.
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Retrieve the models that can perform *task* for the objects in *query*.

        Args:
            task: One of ``self.tasks`` ("object-detection" or "image-segmentation").
            query: Free-text class of objects the model must handle.

        Returns:
            The value returned by the remote search method. Per the tool
            description, this is a mapping of model id to the labels that model
            can detect.

        Raises:
            ValueError: If *task* is not one of the supported tasks.
            TypeError: If *query* is not a string.
        """
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``. That would let an unknown task reach the dispatch below.
        if task not in self.tasks:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        # forward() may be called directly, bypassing Tool.__call__ (which is
        # what normally triggers setup()); make sure the remote instance exists.
        if getattr(self, "tool", None) is None:
            self.setup()

        print(f"Retrieving models for task {task} with query {query}")
        if task == "object-detection":
            return self.tool.object_detection_search.remote(query)
        if task == "image-segmentation":
            return self.tool.image_segmentation_search.remote(query)
        # Only reachable if self.tasks lists a task with no search method wired up.
        raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")