File size: 1,783 Bytes
111afa2
 
 
 
 
 
 
 
 
 
 
 
 
518d841
111afa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518d841
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import modal
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

from .app import app
from .image import image
from .volume import volume


@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class TaskModelRetrieverModalApp:
    """Modal GPU service that looks up candidate models for a vision task.

    When the container starts, one FAISS vector store per task is loaded from
    the mounted ``/volume``. Each ``*_search`` method runs a similarity search
    for a free-text query against the store for its task and returns the
    ``model_id`` -> ``model_labels`` metadata of the hits.
    """

    @modal.enter()
    def setup(self):
        """Load the per-task FAISS indexes from ``/volume/vector_store/<task>``.

        A single embedding model instance is shared by every store, so the
        embedding weights are loaded onto the GPU once rather than per task.
        """
        tasks = ["object-detection", "image-segmentation", "image-classification"]
        # NOTE(review): presumably this must be the same embedding model (and
        # normalization) used to build the stored indexes — keep them in sync.
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
            show_progress=True,
        )
        self.vector_stores = {}
        for task in tasks:
            # NOTE(security): allow_dangerous_deserialization lets load_local
            # unpickle the stored docstore. That is only acceptable because the
            # index files live on our own volume; never point folder_path at
            # untrusted data.
            self.vector_stores[task] = FAISS.load_local(
                folder_path=f"/volume/vector_store/{task}",
                embeddings=embeddings,
                index_name="faiss_index",
                allow_dangerous_deserialization=True,
            )

    def forward(self, task: str, query: str, k: int = 7) -> dict:
        """Return the models matching ``query`` in the vector store for ``task``.

        Args:
            task: Name of a task whose store was loaded in ``setup``.
            query: Free-text query to embed and search for.
            k: Number of documents to retrieve (default 7).

        Returns:
            A dict mapping each retrieved document's ``model_id`` metadata to
            its ``model_labels`` metadata. If two hits share a ``model_id``,
            the one appearing later in the search results overwrites the
            earlier one.

        Raises:
            KeyError: If no vector store was loaded for ``task``.
        """
        try:
            store = self.vector_stores[task]
        except KeyError:
            raise KeyError(
                f"Unknown task {task!r}; expected one of {sorted(self.vector_stores)}"
            ) from None
        docs = store.similarity_search(query, k=k)
        return {doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs}

    @modal.method()
    def object_detection_search(self, query: str) -> dict:
        """Search the object-detection store; see ``forward`` for the result shape."""
        return self.forward("object-detection", query)

    @modal.method()
    def image_segmentation_search(self, query: str) -> dict:
        """Search the image-segmentation store; see ``forward`` for the result shape."""
        return self.forward("image-segmentation", query)

    @modal.method()
    def image_classification_search(self, query: str) -> dict:
        """Search the image-classification store; see ``forward`` for the result shape."""
        return self.forward("image-classification", query)