Commit 7e327f2
Parent: 4fbddf5
Add initial project structure with core functionality for image processing agents
- Created .gitignore to exclude Python-generated files and virtual environments.
- Added .python-version to specify Python version 3.10.
- Implemented main LLM functionality in llm.py.
- Defined project metadata and dependencies in pyproject.toml and requirements.txt.
- Developed image processing tools including bounding box drawing, cropping, and upscaling.
- Integrated object detection and model retrieval capabilities in remote tools.
- Established dataset creation and knowledge base preparation scripts.
- Set up the initial Modal application structure for remote processing (a usage sketch follows below).
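
A minimal sketch of how the modules below fit together (my own illustration, not part of the commit): build the default LLM, create the master agent, and run a task on a preloaded image. Note that importing llm.py requires ANTHROPIC_API_KEY to be set, since the module asserts it at import time; `additional_args` is assumed to be the smolagents mechanism for injecting variables such as "image".

# Hypothetical wiring of the modules added in this commit.
from PIL import Image

from llm import get_default_model
from agents.all_agents import get_master_agent

llm = get_default_model()
agent = get_master_agent(llm)

image = Image.open("photo.jpg")  # assumed input image
# Expose "image" to the agent as a preloaded variable, as the agent
# description below expects.
result = agent.run(
    "Detect all cats in the image and draw their bounding boxes.",
    additional_args={"image": image},
)
print(result)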
- .gitignore +12 -0
- .python-version +1 -0
- agents/all_agents.py +37 -0
- llm.py +22 -0
- pyproject.toml +32 -0
- rag/__init__.py +0 -0
- rag/create_dataset.py +81 -0
- rag/prepare_knowledge_base.py +56 -0
- rag/settings.py +22 -0
- remote_tools/app.py +5 -0
- remote_tools/deploy.py +10 -0
- remote_tools/image.py +28 -0
- remote_tools/object_detection_tool.py +86 -0
- remote_tools/rag_tool.py +65 -0
- remote_tools/upscaler.py +66 -0
- remote_tools/volume.py +3 -0
- requirements.txt +165 -0
- tools/bbox_drawing_tool.py +58 -0
- tools/cropping_tool.py +55 -0
- tools/hf_api_tool.py +26 -0
- tools/rag_tool.py +38 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+.env
.python-version
ADDED
@@ -0,0 +1 @@
+3.10
agents/all_agents.py
ADDED
@@ -0,0 +1,37 @@
+from smolagents import CodeAgent, LogLevel
+from remote_tools.rag_tool import RemoteObjectDetectionModelRetrieverTool
+from tools.bbox_drawing_tool import BBoxDrawingTool
+from tools.cropping_tool import CroppingTool
+from remote_tools.object_detection_tool import RemoteObjectDetectionTool
+from remote_tools.upscaler import RemoteUpscalerTool
+
+
+def get_master_agent(llm):
+    description = """
+    You are an agent that can perform tasks on an image.
+    You can use the following tools to perform tasks on an image:
+    - object_detection_tool: to detect objects in an image, you must provide the image to the tool.
+    - object_detection_model_retriever: to retrieve object detection models, you must provide the type of class that a model can detect.
+
+    If you don't know which model to use, use the object_detection_model_retriever tool to retrieve one.
+    Never invent a model name; always use a model name provided by the object_detection_model_retriever tool.
+    Use batching to perform tasks on multiple images at once when a tool supports it.
+    You have access to the variable "image", which is the image to perform tasks on; it is already loaded, so there is no need to load it.
+    You can also use opencv to draw the bounding boxes on the image.
+    Always use the variable "image" to draw the bounding boxes on the image.
+    """
+    master_agent = CodeAgent(
+        name="master_agent",
+        description=description,
+        model=llm,
+        tools=[
+            RemoteObjectDetectionTool(),
+            BBoxDrawingTool(),
+            CroppingTool(),
+            RemoteUpscalerTool(),
+            RemoteObjectDetectionModelRetrieverTool(),
+        ],
+        verbosity_level=LogLevel.DEBUG,
+    )
+    print("Loaded master agent")
+    return master_agent
llm.py
ADDED
@@ -0,0 +1,22 @@
+from smolagents import OpenAIServerModel, LiteLLMModel
+import os
+
+LOCAL_LLM_SETTINGS = {
+    "api_base": "http://127.0.0.1:1234/v1",
+    "api_key": "api-key",
+    "model_id": "gemma-3-12b-it-qat",
+}
+
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+
+assert ANTHROPIC_API_KEY is not None, "ANTHROPIC_API_KEY is not set"
+
+
+def get_default_model():
+    model = LiteLLMModel(
+        model_id="claude-3-7-sonnet-20250219",
+        api_key=os.getenv("ANTHROPIC_API_KEY"),
+        reasoning_effort="low",
+    )
+    print("Loaded LLM model")
+    return model
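
LOCAL_LLM_SETTINGS is defined but never referenced in this commit, and OpenAIServerModel is imported but unused. A sketch of the presumed intent, pointing smolagents at a local OpenAI-compatible server (an assumption on my part, not code from the diff):

# Hypothetical helper: use the local server settings with the imported
# OpenAIServerModel instead of the Anthropic default.
from smolagents import OpenAIServerModel

from llm import LOCAL_LLM_SETTINGS

local_model = OpenAIServerModel(
    model_id=LOCAL_LLM_SETTINGS["model_id"],
    api_base=LOCAL_LLM_SETTINGS["api_base"],
    api_key=LOCAL_LLM_SETTINGS["api_key"],
)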
pyproject.toml
ADDED
@@ -0,0 +1,32 @@
+[project]
+name = "image-processing-agent"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+    "accelerate>=1.7.0",
+    "datasets>=3.6.0",
+    "diffusers>=0.33.1",
+    "faiss-cpu>=1.11.0",
+    "faiss-gpu>=1.7.2",
+    "gradio>=5.33.0",
+    "hf-transfer>=0.1.9",
+    "huggingface-hub[cli]>=0.32.4",
+    "langchain>=0.3.25",
+    "langchain-community>=0.3.24",
+    "langchain-huggingface>=0.2.0",
+    "langchain-openai>=0.3.19",
+    "matplotlib>=3.10.3",
+    "modal>=1.0.3",
+    "opencv-python>=4.11.0.86",
+    "pandas>=2.3.0",
+    "rank-bm25>=0.2.2",
+    "safetensors>=0.5.3",
+    "scipy>=1.15.3",
+    "sentence-transformers>=4.1.0",
+    "smolagents[litellm,openai]>=1.17.0",
+    "timm>=1.0.15",
+    "torch>=2.7.1",
+    "transformers>=4.52.4",
+]
rag/__init__.py
ADDED
File without changes
rag/create_dataset.py
ADDED
@@ -0,0 +1,81 @@
+import json
+from smolagents import Tool
+from huggingface_hub import HfApi, hf_hub_download, ModelCard
+from datasets import Dataset, Features, Value
+
+
+def get_model_ids(pipeline_tag: str) -> list[str]:
+    hf_api = HfApi()
+    models = hf_api.list_models(
+        library=["transformers"],
+        pipeline_tag=pipeline_tag,
+        gated=False,
+        fetch_config=True,
+    )
+    models = list(models)
+    model_ids = [model.id for model in models]
+    return model_ids
+
+
+def get_model_card(model_id: str) -> str:
+    try:
+        model_card = ModelCard.load(model_id)
+        return model_card.text
+    except Exception:
+        return ""
+
+
+def get_model_labels(model_id: str) -> list[str]:
+    hf_api = HfApi()
+    if hf_api.file_exists(model_id, filename="config.json"):
+        config_path = hf_hub_download(model_id, filename="config.json")
+        with open(config_path, "r") as f:
+            try:
+                model_config = json.load(f)
+            except json.JSONDecodeError:
+                return [""]
+        if "id2label" in model_config:
+            labels = list(model_config["id2label"].values())
+            labels = [str(label).lower() for label in labels]
+            return labels
+        else:
+            return [""]
+    else:
+        return [""]
+
+
+def create_dataset(pipeline_tag: str):
+    def dataset_gen(model_ids: list[str]):
+        for model_id in model_ids:
+            model_card = get_model_card(model_id)
+            model_labels = get_model_labels(model_id)
+            if len(model_labels) > 1 and len(model_card) > 0:
+                yield {
+                    "model_id": model_id,
+                    "model_card": model_card,
+                    "model_labels": model_labels,
+                }
+
+    model_ids = get_model_ids(pipeline_tag)
+
+    dataset = Dataset.from_generator(
+        dataset_gen,
+        gen_kwargs={"model_ids": model_ids},
+        features=Features(
+            {
+                "model_id": Value("string"),
+                "model_card": Value("string"),
+                "model_labels": [Value("string")],
+            }
+        ),
+        num_proc=12,
+    )
+
+    return dataset
+
+
+if __name__ == "__main__":
+    dataset = create_dataset("object-detection")
+    print(dataset)
+    dataset.push_to_hub("stevenbucaille/object-detection-models-dataset")
+    # dataset.push_to_hub("stevenbucaille/object-detection-models-dataset")
rag/prepare_knowledge_base.py
ADDED
@@ -0,0 +1,56 @@
+import datasets
+from langchain_core.documents import Document
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import FAISS
+import faiss
+from langchain_community.docstore.in_memory import InMemoryDocstore
+from rag.settings import get_embeddings_model
+
+
+def get_vector_store():
+    embeddings = get_embeddings_model()
+    index = faiss.IndexFlatL2(len(embeddings.embed_query("hello world")))
+
+    vector_store = FAISS(
+        embedding_function=embeddings,
+        index=index,
+        docstore=InMemoryDocstore(),
+        index_to_docstore_id={},
+    )
+    return vector_store
+
+
+def get_docs(dataset):
+    source_docs = [
+        Document(
+            page_content=model["model_card"],
+            metadata={
+                "model_id": model["model_id"],
+                "model_labels": model["model_labels"],
+            },
+        )
+        for model in dataset
+    ]
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,  # characters per chunk
+        chunk_overlap=50,  # overlap between chunks to maintain context
+        add_start_index=True,
+        strip_whitespace=True,
+        separators=["\n\n", "\n", ".", " ", ""],  # priority order for splitting
+    )
+    docs_processed = text_splitter.split_documents(source_docs)
+    print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
+    return docs_processed
+
+
+if __name__ == "__main__":
+    dataset = datasets.load_dataset(
+        "stevenbucaille/object-detection-models-dataset", split="train"
+    )
+    docs_processed = get_docs(dataset)
+    vector_store = get_vector_store()
+    vector_store.add_documents(docs_processed)
+    vector_store.save_local(
+        folder_path="vector_store",
+        index_name="object_detection_models_faiss_index",
+    )
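
Note that this script saves the FAISS index to a local vector_store/ folder, while the remote retriever in remote_tools/rag_tool.py loads it from /volume/vector_store on the Modal volume named "hackathon" (see remote_tools/volume.py). The index presumably has to be uploaded to the volume after this script runs, e.g. with something like `modal volume put hackathon vector_store vector_store` (a hypothetical invocation; the commit itself does not include an upload step).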
rag/settings.py
ADDED
@@ -0,0 +1,22 @@
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+
+
+def get_embeddings_model():
+    embeddings = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/all-MiniLM-L6-v2",
+        model_kwargs={"device": "cuda"},
+        encode_kwargs={"normalize_embeddings": True},
+        show_progress=True,
+    )
+    print("Loaded embeddings model")
+    return embeddings
+
+
+def get_vector_store():
+    return FAISS.load_local(
+        folder_path="vector_store",
+        embeddings=get_embeddings_model(),
+        index_name="object_detection_models_faiss_index",
+        allow_dangerous_deserialization=True,
+    )
remote_tools/app.py
ADDED
@@ -0,0 +1,5 @@
+import modal
+
+from .image import image
+
+app = modal.App("image-agent-tools", image=image)
remote_tools/deploy.py
ADDED
@@ -0,0 +1,10 @@
+import modal
+
+from .app import app
+from .object_detection_tool import app as object_detection_tool_app
+from .upscaler import app as upscaler_tool_app
+from .rag_tool import app as rag_tool_app
+
+app.include(object_detection_tool_app)
+app.include(upscaler_tool_app)
+app.include(rag_tool_app)
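
Since deploy.py uses package-relative imports, the combined app is presumably deployed as a module from the repository root, e.g. `modal deploy -m remote_tools.deploy` (assuming a Modal CLI version that accepts module paths); this registers the shared "image-agent-tools" app with all three tool apps included.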
remote_tools/image.py
ADDED
@@ -0,0 +1,28 @@
+import modal
+
+cuda_version = "12.4.0"  # should be no greater than host CUDA version
+flavor = "devel"  # includes full CUDA toolkit
+operating_sys = "ubuntu22.04"
+tag = f"{cuda_version}-{flavor}-{operating_sys}"
+
+cuda_dev_image = modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
+
+image = (
+    cuda_dev_image.apt_install(
+        "git",
+        "libglib2.0-0",
+        "libsm6",
+        "libxrender1",
+        "libxext6",
+        "ffmpeg",
+        "libgl1",
+    )
+    .add_local_file("requirements.txt", "/app_requirements.txt", copy=True)
+    .run_commands(
+        [
+            "cat /app_requirements.txt",
+            "uv pip install --system --requirement /app_requirements.txt",
+        ]
+    )
+    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": "/cache"})
+)
remote_tools/object_detection_tool.py
ADDED
@@ -0,0 +1,86 @@
+import modal
+from transformers import AutoModelForObjectDetection, AutoImageProcessor
+import torch
+from smolagents import Tool
+
+from .app import app
+from .image import image
+
+
+@app.cls(gpu="T4", image=image)
+class RemoteObjectDetectionModalApp:
+    model_name: str = modal.parameter()
+
+    @modal.method()
+    def forward(self, image):
+        self.model = AutoModelForObjectDetection.from_pretrained(self.model_name)  # loaded on each call; model_name is a Modal parameter
+        self.processor = AutoImageProcessor.from_pretrained(self.model_name)
+        self.model.eval()
+
+        # Preprocess image
+        inputs = self.processor(images=image, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
+        results = self.processor.post_process_object_detection(
+            outputs, target_sizes=target_sizes, threshold=0.5
+        )[0]
+
+        boxes = []
+        for score, label, box in zip(
+            results["scores"], results["labels"], results["boxes"]
+        ):
+            boxes.append(
+                {
+                    "box": box.tolist(),  # [xmin, ymin, xmax, ymax]
+                    "score": score.item(),
+                    "label": self.model.config.id2label[label.item()],
+                }
+            )
+        return boxes
+
+
+class RemoteObjectDetectionTool(Tool):
+    name = "object_detection"
+    description = """
+    Given an image, detect objects and return bounding boxes.
+    The image is a PIL image.
+    The output is a list of dictionaries containing the bounding boxes with the following keys:
+    - box: a list of 4 numbers [xmin, ymin, xmax, ymax]
+    - score: a number between 0 and 1
+    - label: a string
+    The bounding boxes are in the format [xmin, ymin, xmax, ymax].
+    You need to provide the model name to use for object detection.
+    The tool returns a list of bounding boxes for all the objects in the image.
+    """
+
+    inputs = {
+        "image": {
+            "type": "image",
+            "description": "The image to detect objects in",
+        },
+        "model_name": {
+            "type": "string",
+            "description": "The name of the model to use for object detection",
+        },
+    }
+    output_type = "object"
+
+    def __init__(self):
+        super().__init__()
+        self.tool_class = modal.Cls.from_name(
+            app.name, RemoteObjectDetectionModalApp.__name__
+        )
+
+    def forward(
+        self,
+        image,
+        model_name: str,
+    ):
+        self.tool = self.tool_class(model_name=model_name)
+        bboxes = self.tool.forward.remote(image)
+        for bbox in bboxes:
+            print(
+                f"Found {bbox['label']} with score: {bbox['score']} at box: {bbox['box']}"
+            )
+        return bboxes
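
A sketch of calling the remote detector from local code (my illustration, not part of the diff; it assumes the Modal app has already been deployed and uses facebook/detr-resnet-50, a standard object-detection checkpoint, as an example model name):

# Assumes the "image-agent-tools" Modal app is deployed.
from PIL import Image

from remote_tools.object_detection_tool import RemoteObjectDetectionTool

tool = RemoteObjectDetectionTool()
image = Image.open("street.jpg")  # assumed input image
bboxes = tool.forward(image, model_name="facebook/detr-resnet-50")
# Each entry: {"box": [xmin, ymin, xmax, ymax], "score": float, "label": str}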
remote_tools/rag_tool.py
ADDED
@@ -0,0 +1,65 @@
+from langchain_community.vectorstores import FAISS
+from langchain_huggingface import HuggingFaceEmbeddings
+from smolagents import Tool
+import modal
+
+from .app import app
+from .image import image
+from .volume import volume
+
+
+@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
+class RemoteObjectDetectionModelRetrieverModalApp:
+    @modal.enter()
+    def setup(self):
+        self.vector_store = FAISS.load_local(
+            folder_path="/volume/vector_store",
+            embeddings=HuggingFaceEmbeddings(
+                model_name="all-MiniLM-L6-v2",
+                model_kwargs={"device": "cuda"},
+                encode_kwargs={"normalize_embeddings": True},
+                show_progress=True,
+            ),
+            index_name="object_detection_models_faiss_index",
+            allow_dangerous_deserialization=True,
+        )
+
+    @modal.method()
+    def forward(self, query: str) -> dict:
+        docs = self.vector_store.similarity_search(query, k=7)
+        model_ids = [doc.metadata["model_id"] for doc in docs]
+        model_labels = [doc.metadata["model_labels"] for doc in docs]
+        models_dict = {
+            model_id: model_labels
+            for model_id, model_labels in zip(model_ids, model_labels)
+        }
+        return models_dict
+
+
+class RemoteObjectDetectionModelRetrieverTool(Tool):
+    name = "object_detection_model_retriever"
+    description = """
+    For a given class of objects, retrieve the models that can detect that class.
+    The query is a string that describes the class of objects the model needs to detect.
+    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
+    """
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The class of objects the model needs to detect.",
+        }
+    }
+    output_type = "object"
+
+    def __init__(self):
+        super().__init__()
+        self.tool_class = modal.Cls.from_name(
+            app.name, RemoteObjectDetectionModelRetrieverModalApp.__name__
+        )
+
+    def forward(self, query: str) -> dict:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        tool = self.tool_class()
+        result = tool.forward.remote(query)
+        return result
remote_tools/upscaler.py
ADDED
@@ -0,0 +1,66 @@
+import modal
+import torch
+from smolagents import AgentImage, Tool
+from diffusers import StableDiffusionUpscalePipeline
+
+from .app import app
+from .image import image
+
+
+@app.cls(gpu="T4", image=image, scaledown_window=60 * 5)
+class RemoteUpscalerModalApp:
+    @modal.enter()
+    def setup(self):
+        model_id = "stabilityai/stable-diffusion-x4-upscaler"
+        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+            model_id, torch_dtype=torch.float16
+        )
+        self.pipeline = self.pipeline.to("cuda")
+
+    @modal.batched(max_batch_size=4, wait_ms=1000)
+    def forward(self, low_res_imgs, prompts: list[str]):
+        print(len(low_res_imgs))
+        print(low_res_imgs)
+        print(prompts)
+        low_res_imgs = [  # cap inputs at 512x512 before the 4x upscale
+            img.resize(
+                (min(512, img.width), min(512, img.height))
+            ) for img in low_res_imgs
+        ]
+        upscaled_images = self.pipeline(prompt=prompts, image=low_res_imgs).images
+        return upscaled_images
+
+
+class RemoteUpscalerTool(Tool):
+    name = "upscaler"
+    description = """
+    Perform upscaling on images.
+    The "low_res_imgs" are PIL images.
+    The "prompts" are strings.
+    The output is a list of PIL images.
+    You can upscale multiple images at once.
+    """
+
+    inputs = {
+        "low_res_imgs": {
+            "type": "array",
+            "description": "The low resolution images to upscale",
+        },
+        "prompts": {
+            "type": "array",
+            "description": "The prompts to upscale the images",
+        },
+    }
+    output_type = "object"
+
+    def __init__(self):
+        super().__init__()
+        tool_class = modal.Cls.from_name(app.name, RemoteUpscalerModalApp.__name__)
+        self.tool = tool_class()
+
+    def forward(self, low_res_imgs: list[AgentImage], prompts: list[str]):
+        # Modal's forward.map() handles batching internally
+        # We can use it synchronously since Modal manages the async execution
+        upscaled_images = self.tool.forward.map(low_res_imgs, prompts)
+        # Convert the generator to a list to get all results
+        return list(upscaled_images)
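
The local tool maps over individual images while the remote method is decorated with @modal.batched, so Modal groups the mapped calls into batches of up to 4 (or whatever arrives within the 1000 ms window). A usage sketch (my illustration; assumes the deployed app and some local PIL images):

# Upscale two crops at once through the deployed Modal app.
from PIL import Image

from remote_tools.upscaler import RemoteUpscalerTool

upscaler = RemoteUpscalerTool()
images = [Image.open("crop_0.png"), Image.open("crop_1.png")]  # assumed inputs
prompts = ["a sharp photo of a cat", "a sharp photo of a dog"]
upscaled = upscaler.forward(images, prompts)  # list of upscaled PIL images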
remote_tools/volume.py
ADDED
@@ -0,0 +1,3 @@
+import modal
+
+volume = modal.Volume.from_name("hackathon")
requirements.txt
ADDED
@@ -0,0 +1,165 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml -o requirements.txt --no-annotate
+accelerate==1.7.0
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.9
+aiosignal==1.3.2
+annotated-types==0.7.0
+anyio==4.9.0
+async-timeout==4.0.3
+attrs==25.3.0
+certifi==2025.4.26
+charset-normalizer==3.4.2
+click==8.1.8
+contourpy==1.3.2
+cycler==0.12.1
+dataclasses-json==0.6.7
+datasets==3.6.0
+diffusers==0.33.1
+dill==0.3.8
+distro==1.9.0
+exceptiongroup==1.3.0
+faiss-cpu==1.11.0
+faiss-gpu==1.7.2
+fastapi==0.115.12
+ffmpy==0.6.0
+filelock==3.18.0
+fonttools==4.58.1
+frozenlist==1.6.2
+fsspec==2025.3.0
+gradio==5.33.0
+gradio-client==1.10.2
+greenlet==3.2.3
+groovy==0.1.2
+grpclib==0.4.7
+h11==0.16.0
+h2==4.2.0
+hf-transfer==0.1.9
+hf-xet==1.1.3
+hpack==4.1.0
+httpcore==1.0.9
+httpx==0.28.1
+httpx-sse==0.4.0
+huggingface-hub==0.32.4
+hyperframe==6.1.0
+idna==3.10
+importlib-metadata==8.7.0
+inquirerpy==0.3.4
+jinja2==3.1.6
+jiter==0.10.0
+joblib==1.5.1
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.24.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+langchain==0.3.25
+langchain-community==0.3.24
+langchain-core==0.3.64
+langchain-huggingface==0.2.0
+langchain-openai==0.3.21
+langchain-text-splitters==0.3.8
+langsmith==0.3.45
+litellm==1.72.1
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+marshmallow==3.26.1
+matplotlib==3.10.3
+mdurl==0.1.2
+modal==1.0.3
+mpmath==1.3.0
+multidict==6.4.4
+multiprocess==0.70.16
+mypy-extensions==1.1.0
+networkx==3.4.2
+numpy==2.2.6
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+openai==1.84.0
+opencv-python==4.11.0.86
+orjson==3.10.18
+packaging==24.2
+pandas==2.3.0
+pfzy==0.3.4
+pillow==11.2.1
+prompt-toolkit==3.0.51
+propcache==0.3.1
+protobuf==6.31.1
+psutil==7.0.0
+pyarrow==20.0.0
+pydantic==2.11.5
+pydantic-core==2.33.2
+pydantic-settings==2.9.1
+pydub==0.25.1
+pygments==2.19.1
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+rank-bm25==0.2.2
+referencing==0.36.2
+regex==2024.11.6
+requests==2.32.3
+requests-toolbelt==1.0.0
+rich==14.0.0
+rpds-py==0.25.1
+ruff==0.11.12
+safehttpx==0.1.6
+safetensors==0.5.3
+scikit-learn==1.7.0
+scipy==1.15.3
+semantic-version==2.10.0
+sentence-transformers==4.1.0
+setuptools==80.9.0
+shellingham==1.5.4
+sigtools==4.0.1
+six==1.17.0
+smolagents==1.17.0
+sniffio==1.3.1
+sqlalchemy==2.0.41
+starlette==0.46.2
+sympy==1.14.0
+synchronicity==0.9.13
+tenacity==9.1.2
+threadpoolctl==3.6.0
+tiktoken==0.9.0
+timm==1.0.15
+tokenizers==0.21.1
+toml==0.10.2
+tomlkit==0.13.3
+torch==2.7.1
+torchvision==0.22.1
+tqdm==4.67.1
+transformers==4.52.4
+triton==3.3.1
+typer==0.16.0
+types-certifi==2021.10.8.3
+types-toml==0.10.8.20240310
+typing-extensions==4.14.0
+typing-inspect==0.9.0
+typing-inspection==0.4.1
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.34.3
+watchfiles==1.0.5
+wcwidth==0.2.13
+websockets==15.0.1
+xxhash==3.5.0
+yarl==1.20.0
+zipp==3.22.0
+zstandard==0.23.0
tools/bbox_drawing_tool.py
ADDED
@@ -0,0 +1,58 @@
+from typing import List, Union, Dict
+from smolagents import Tool, AgentImage
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class BBoxDrawingTool(Tool):
+    name = "bbox_drawing"
+    description = """
+    Given an image and a list of bounding boxes, draw the bounding boxes on the image.
+    The image is a PIL image.
+    The bounding boxes are a list of dictionaries with the following keys:
+    - box: a list of 4 numbers [xmin, ymin, xmax, ymax]
+    - score: a number between 0 and 1
+    - label: a string.
+    The output is the image with the bounding boxes drawn on it.
+    """
+
+    inputs = {
+        "image": {
+            "type": "image",
+            "description": "The image to draw the bounding boxes on",
+        },
+        "bboxes": {
+            "type": "array",
+            "description": "The list of bounding boxes to draw on the image",
+        },
+    }
+    output_type = "image"
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        image: AgentImage,
+        bboxes: List[Dict[str, Union[str, float, List]]],
+    ):
+        np_image = np.array(image)
+        cv2_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+
+        for bbox in bboxes:
+            print(bbox)
+            print(bbox["box"])
+            cv2_image = self.draw_bbox(cv2_image, bbox["box"])
+
+        pil_image = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))
+        return pil_image
+
+    def draw_bbox(self, image: np.ndarray, bbox: List[int]):  # image is an OpenCV BGR array here
+        x1, y1, x2, y2 = tuple(bbox)
+        x1 = int(x1)
+        y1 = int(y1)
+        x2 = int(x2)
+        y2 = int(y2)
+        image = cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
+        return image
tools/cropping_tool.py
ADDED
@@ -0,0 +1,55 @@
+from smolagents import Tool, AgentImage
+
+
+class CroppingTool(Tool):
+    name = "cropping"
+    description = """
+    Given a list of images and a list of bounding boxes, crop the images to the specified regions.
+    The images are PIL images.
+    The bounding boxes are lists of 4 numbers [xmin, ymin, xmax, ymax] for each image.
+    The output is a list of cropped PIL images.
+    You can crop multiple images at once.
+    You need the same number of images and bounding boxes.
+    """
+
+    inputs = {
+        "images": {
+            "type": "array",
+            "description": "The images to crop",
+        },
+        "bboxes": {
+            "type": "array",
+            "description": "The bounding box coordinates [xmin, ymin, xmax, ymax] for each image",
+        },
+    }
+    output_type = "array"
+
+    def __init__(self):
+        super().__init__()
+
+    def setup(self):
+        pass
+
+    def forward(self, images: list[AgentImage], bboxes: list[list]):
+        if len(images) != len(bboxes):
+            raise ValueError(
+                "The number of images and bounding boxes must be the same."
+            )
+
+        cropped_images = []
+        for image, bbox in zip(images, bboxes):
+            # Convert bbox to integers
+            xmin, ymin, xmax, ymax = map(int, bbox)
+
+            # Ensure coordinates are within image bounds
+            width, height = image.size
+            xmin = max(0, min(xmin, width))
+            ymin = max(0, min(ymin, height))
+            xmax = max(0, min(xmax, width))
+            ymax = max(0, min(ymax, height))
+
+            # Crop the image
+            cropped_image = image.crop((xmin, ymin, xmax, ymax))
+            cropped_images.append(cropped_image)
+
+        return cropped_images
tools/hf_api_tool.py
ADDED
@@ -0,0 +1,26 @@
+from smolagents import Tool
+from huggingface_hub import HfApi
+
+
+class HFAPITool(Tool):
+    name = "hf_api"
+    description = "Use the HuggingFace API to search for models"
+    inputs = {
+        "prompt": {
+            "type": "string",
+            "description": "The prompt to search for models",
+        },
+    }
+    output_type = "object"
+
+    def __init__(self):
+        super().__init__()
+        self.api = HfApi()
+
+    def forward(self, prompt: str):
+        models = self.api.list_models(
+            library=["transformers"], pipeline_tag="object-detection", fetch_config=True
+        )
+        print(models)
+        # NOTE: `prompt` is unused for now; return the model ids so the tool produces an output.
+        return [model.id for model in models]
tools/rag_tool.py
ADDED
@@ -0,0 +1,38 @@
+from langchain_community.vectorstores import FAISS
+from smolagents import Tool
+from rag.settings import get_vector_store
+
+
+class ObjectDetectionModelRetrieverTool(Tool):
+    name = "object_detection_model_retriever"
+    description = """
+    For a given class of objects, retrieve the models that can detect that class.
+    The query is a string that describes the class of objects the model needs to detect.
+    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
+    """
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The class of objects the model needs to detect.",
+        }
+    }
+    output_type = "object"
+
+    def __init__(self):
+        super().__init__()
+
+    def setup(self):
+        self.vector_store = get_vector_store()
+        print("Loaded vector store")
+
+    def forward(self, query: str) -> dict:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        docs = self.vector_store.similarity_search(query, k=7)
+        model_ids = [doc.metadata["model_id"] for doc in docs]
+        model_labels = [doc.metadata["model_labels"] for doc in docs]
+        models_dict = {
+            model_id: model_labels
+            for model_id, model_labels in zip(model_ids, model_labels)
+        }
+        return models_dict
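
This local retriever mirrors the remote one in remote_tools/rag_tool.py, but loads the FAISS index from the local vector_store/ folder written by rag/prepare_knowledge_base.py and runs the embeddings model on a local CUDA device. A usage sketch (my illustration; calling the tool instance triggers setup() lazily, which is how smolagents tools are expected to be invoked):

# Assumes ./vector_store exists (built by rag/prepare_knowledge_base.py).
from tools.rag_tool import ObjectDetectionModelRetrieverTool

retriever = ObjectDetectionModelRetrieverTool()
models = retriever("zebra")  # setup() runs on first call
print(models)  # {model_id: [labels, ...], ...}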