import datasets
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from rag.settings import get_embeddings_model


def get_vector_store():
    embeddings = get_embeddings_model()
    # Embed a probe string once to discover the embedding dimensionality,
    # then build an exact L2 index of that size.
    embedding_dim = len(embeddings.embed_query("hello world"))
    index = faiss.IndexFlatL2(embedding_dim)

    vector_store = FAISS(
        embedding_function=embeddings,
        index=index,
        docstore=InMemoryDocstore(),  # chunk texts are kept in memory
        index_to_docstore_id={},
    )
    return vector_store
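
# Note: IndexFlatL2 performs exact (brute-force) L2 search, which is fine for
# a knowledge base of a few thousand chunks. For much larger corpora an
# approximate index could be swapped in instead (a sketch, not used here):
#   index = faiss.IndexHNSWFlat(embedding_dim, 32)  # 32 = HNSW connectivity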


def get_docs(dataset):
    page_content_template = """
    Model card:
    {model_id}
    List of labels:
    {model_labels}
    """
    source_docs = [
        Document(
            # Render the template per model so each document carries its own
            # model id and label set; the raw fields are mirrored in metadata.
            page_content=page_content_template.format(
                model_id=model["model_id"],
                model_labels=model["model_labels"],
            ),
            metadata={
                "model_id": model["model_id"],
                "model_labels": model["model_labels"],
            },
        )
        for model in dataset
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,  # Characters per chunk
        chunk_overlap=50,  # Overlap between chunks to maintain context
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],  # Priority order for splitting
    )
    docs_processed = text_splitter.split_documents(source_docs)
    print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
    return docs_processed
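
# Quick sanity check (a sketch, assuming the dataset loaded in __main__):
#   docs = get_docs(dataset)
#   print(docs[0].page_content)
#   print(docs[0].metadata)  # model_id, model_labels, plus the splitter's start_index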


if __name__ == "__main__":
    dataset = datasets.load_dataset("stevenbucaille/image-classification-models-dataset", split="train")
    docs_processed = get_docs(dataset)
    vector_store = get_vector_store()
    vector_store.add_documents(docs_processed)
    vector_store.save_local(
        folder_path="vector_store/image-classification",
        index_name="faiss_index",
    )
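
    # Reloading the saved index for retrieval (a sketch; the opt-in
    # deserialization flag follows current langchain_community behavior,
    # since the docstore is pickled on disk):
    #   store = FAISS.load_local(
    #       folder_path="vector_store/image-classification",
    #       embeddings=get_embeddings_model(),
    #       index_name="faiss_index",
    #       allow_dangerous_deserialization=True,
    #   )
    #   hits = store.similarity_search("model that classifies plant diseases", k=4)
    #   for doc in hits:
    #       print(doc.metadata["model_id"])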