Spaces:
Running
Running
File size: 1,783 Bytes
111afa2 518d841 111afa2 518d841 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import modal
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from .app import app
from .image import image
from .volume import volume
@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class TaskModelRetrieverModalApp:
    """Modal app serving similarity search over per-task FAISS vector stores.

    On container start-up, one FAISS index per supported task is loaded from
    the mounted volume. Each ``*_search`` method is a remote entry point that
    queries the store for a single task.
    """

    @modal.enter()
    def setup(self) -> None:
        """Load one FAISS vector store per supported task from the volume.

        Populates ``self.vector_stores``: a mapping of task name -> FAISS
        store, keyed by the same task strings the search methods use.
        """
        tasks = ["object-detection", "image-segmentation", "image-classification"]
        self.vector_stores = {}
        for task in tasks:
            self.vector_stores[task] = FAISS.load_local(
                folder_path=f"/volume/vector_store/{task}",
                embeddings=HuggingFaceEmbeddings(
                    model_name="all-MiniLM-L6-v2",
                    model_kwargs={"device": "cuda"},
                    encode_kwargs={"normalize_embeddings": True},
                    show_progress=True,
                ),
                index_name="faiss_index",
                # FAISS.load_local unpickles docstore files; opting in is safe
                # here only because the index on the volume is produced by
                # this project, not untrusted input.
                allow_dangerous_deserialization=True,
            )

    def forward(self, task: str, query: str) -> dict:
        """Return the top-k models matching *query* in *task*'s store.

        Args:
            task: One of the task keys loaded in :meth:`setup`.
            query: Free-text query to embed and search with.

        Returns:
            Mapping of ``model_id`` -> ``model_labels`` for the 7 most
            similar documents. (The original annotated ``-> str`` although a
            dict is returned; annotation corrected.)

        Raises:
            KeyError: If *task* was not loaded, or a document's metadata
                lacks the expected keys.
        """
        docs = self.vector_stores[task].similarity_search(query, k=7)
        # doc.metadata is assumed to carry "model_id" and "model_labels"
        # keys, set when the vector store was built — TODO confirm against
        # the index-building code.
        return {doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs}

    @modal.method()
    def object_detection_search(self, query: str) -> dict:
        """Remote entry point: search the object-detection store."""
        return self.forward("object-detection", query)

    @modal.method()
    def image_segmentation_search(self, query: str) -> dict:
        """Remote entry point: search the image-segmentation store."""
        return self.forward("image-segmentation", query)

    @modal.method()
    def image_classification_search(self, query: str) -> dict:
        """Remote entry point: search the image-classification store."""
        return self.forward("image-classification", query)
|