Spaces:
Running
Running
File size: 1,913 Bytes
111afa2 518d841 111afa2 518d841 111afa2 518d841 111afa2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import modal
from smolagents import Tool
from modal_apps.app import app
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp
class TaskModelRetrieverTool(Tool):
    """Agent tool that retrieves the models able to perform a given vision task.

    The search itself runs remotely: each supported task maps to one search
    method exposed by the deployed ``TaskModelRetrieverModalApp`` Modal class.
    """

    name = "task_model_retriever"
    description = """
For a given task, retrieve the models that can perform that task.
The supported tasks are:
- object-detection
- image-segmentation
- image-classification
The query is a string that describes the task the model needs to perform.
The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
"""
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        """Record the supported tasks and look up the deployed Modal class."""
        super().__init__()
        # Single source of truth: supported task -> name of the search method
        # on the Modal class that serves it.
        self._search_methods = {
            "object-detection": "object_detection_search",
            "image-segmentation": "image_segmentation_search",
            "image-classification": "image_classification_search",
        }
        # Kept as a public list (same content and order as before) for any
        # code that reads `tool.tasks`.
        self.tasks = list(self._search_methods)
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the Modal class handle used by :meth:`forward`."""
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        # Chain to the base hook so the tool gets flagged as initialized;
        # without this the framework would re-run setup() on every call.
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Run the remote model search for ``task`` with ``query``.

        Args:
            task: One of the entries of ``self.tasks``.
            query: Search string forwarded unchanged to the remote method.

        Returns:
            Whatever the remote search method returns; according to the tool
            description, a dictionary mapping model ids to the labels each
            model can detect.

        Raises:
            ValueError: If ``task`` is not a supported task.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises instead of `assert`: assertions are removed under
        # `python -O`, which would let invalid input reach the remote call.
        if task not in self.tasks:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        print(f"Retrieving models for task {task} with query {query}")

        search = getattr(self.tool, self._search_methods[task])
        return search.remote(query)
|