Spaces:
Running
Running
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce Segment Anything functionality and refine README for clarity on model capabilities.
518d841
import modal | |
from smolagents import Tool | |
from modal_apps.app import app | |
from modal_apps.task_model_retriever import TaskModelRetrieverModalApp | |
class TaskModelRetrieverTool(Tool):
    """smolagents tool that looks up models able to perform a vision task.

    The lookup itself runs remotely: each supported task is dispatched to the
    matching ``*_search`` method of the deployed ``TaskModelRetrieverModalApp``
    Modal class, called via ``.remote(query)``.
    """

    name = "task_model_retriever"
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    - image-classification
    The query is a string that describes the class of objects the model needs to detect, segment or classify.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    # Single source of truth for supported tasks: task name -> name of the
    # remote search method on TaskModelRetrieverModalApp that serves it.
    _SEARCH_METHODS = {
        "object-detection": "object_detection_search",
        "image-segmentation": "image_segmentation_search",
        "image-classification": "image_classification_search",
    }

    def __init__(self):
        super().__init__()
        # Kept as a public list (same order as before) for existing callers.
        self.tasks = list(self._SEARCH_METHODS)
        # Look up the deployed Modal class by app name and class name; no
        # instance is created here (see setup()).
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the remote Modal class handle used by ``forward``."""
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        # The base Tool.setup() flips `is_initialized`; without it,
        # Tool.__call__ would re-run setup() (and rebuild the handle) on
        # every single invocation.
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Return the models that can perform ``task`` for objects matching ``query``.

        Args:
            task: One of ``self.tasks``.
            query: Free-text description of the object class to look for.

        Returns:
            Whatever the remote ``*_search`` method returns; per the tool
            description, a dict mapping model id to the labels it supports.

        Raises:
            ValueError: If ``task`` is not a supported task.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises instead of `assert`: asserts vanish under `python -O`,
        # which previously let unknown tasks reach an unbound `result`.
        if task not in self._SEARCH_METHODS:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        print(f"Retrieving models for task {task} with query {query}")
        search = getattr(self.tool, self._SEARCH_METHODS[task])
        return search.remote(query)