import modal

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

from .app import app
from .image import image
from .volume import volume
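
The relative imports above pull in shared Modal objects defined elsewhere in the package. Their contents are not shown in this section; the following is a minimal sketch, assuming typical Modal definitions (the app name, pinned packages, and volume name are illustrative assumptions, not confirmed by the source):

# app.py -- shared Modal App object (name is an assumption)
import modal

app = modal.App("task-model-retriever")

# image.py -- container image with the retrieval dependencies (package list is an assumption)
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "langchain-community",
    "langchain-huggingface",
    "sentence-transformers",
    "faiss-cpu",
)

# volume.py -- persisted Volume holding the prebuilt FAISS indexes (name is an assumption)
volume = modal.Volume.from_name("vector-store", create_if_missing=True)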

# Note: the Modal decorators on this class (@app.cls, @modal.enter, @modal.method)
# are reconstructed assumptions; the imported app/image/volume objects and the
# CUDA embedding device imply them, but the exact GPU type is not confirmed.
@app.cls(image=image, volumes={"/volume": volume}, gpu="T4")
class TaskModelRetrieverModalApp:
    @modal.enter()
    def setup(self):
        # Load one prebuilt FAISS index per supported task from the mounted volume,
        # embedding queries with all-MiniLM-L6-v2 on the GPU.
        tasks = ["object-detection", "image-segmentation"]
        self.vector_stores = {}
        for task in tasks:
            self.vector_stores[task] = FAISS.load_local(
                folder_path=f"/volume/vector_store/{task}",
                embeddings=HuggingFaceEmbeddings(
                    model_name="all-MiniLM-L6-v2",
                    model_kwargs={"device": "cuda"},
                    encode_kwargs={"normalize_embeddings": True},
                    show_progress=True,
                ),
                index_name="faiss_index",
                allow_dangerous_deserialization=True,
            )

    def forward(self, task: str, query: str) -> dict:
        # Retrieve the 7 most similar documents from the task's index and map
        # each candidate model_id to its labels.
        docs = self.vector_stores[task].similarity_search(query, k=7)
        model_ids = [doc.metadata["model_id"] for doc in docs]
        model_labels = [doc.metadata["model_labels"] for doc in docs]
        return dict(zip(model_ids, model_labels))

    @modal.method()
    def object_detection_search(self, query: str) -> dict:
        return self.forward("object-detection", query)

    @modal.method()
    def image_segmentation_search(self, query: str) -> dict:
        return self.forward("image-segmentation", query)