import datasets
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from rag.settings import get_embeddings_model


def get_vector_store():
    embeddings = get_embeddings_model()
    # Probe the embedding model once to discover the vector dimension,
    # then build an exact L2 (Euclidean) FAISS index of that size.
    index = faiss.IndexFlatL2(len(embeddings.embed_query("hello world")))
    vector_store = FAISS(
        embedding_function=embeddings,
        index=index,
        docstore=InMemoryDocstore(),  # chunks live in memory, keyed by id
        index_to_docstore_id={},
    )
    return vector_store
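
IndexFlatL2 does exact, brute-force Euclidean search, which is perfectly adequate at this scale. If you would rather rank by cosine similarity, a minimal variant is sketched below, reusing the imports above; get_cosine_vector_store is a hypothetical helper, and it assumes your langchain_community version exposes the normalize_L2 flag on the FAISS constructor:

def get_cosine_vector_store():
    # Hypothetical variant, not part of the original script: inner product
    # over L2-normalized vectors is equivalent to cosine similarity.
    embeddings = get_embeddings_model()
    dim = len(embeddings.embed_query("hello world"))
    return FAISS(
        embedding_function=embeddings,
        index=faiss.IndexFlatIP(dim),  # inner product instead of L2 distance
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=True,  # assumed constructor flag; normalizes stored vectors
    )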
def get_docs(dataset):
    # One Document per model: the model card text becomes the content,
    # while the model id and labels ride along as metadata.
    source_docs = [
        Document(
            page_content=model["model_card"],
            metadata={
                "model_id": model["model_id"],
                "model_labels": model["model_labels"],
            },
        )
        for model in dataset
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,  # characters per chunk
        chunk_overlap=50,  # overlap between chunks to maintain context
        add_start_index=True,  # record each chunk's offset in its source doc
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],  # priority order for splitting
    )
    docs_processed = text_splitter.split_documents(source_docs)
    print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
    return docs_processed
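
Because add_start_index=True is set, each chunk's metadata carries a start_index offset alongside the fields copied from its parent Document. A quick sanity check, assuming you hold the return value in docs_processed:

# Peek at the first chunk to confirm the metadata survived splitting.
chunk = docs_processed[0]
print(chunk.page_content[:200])       # first 200 characters of the chunk
print(chunk.metadata["model_id"])     # copied from the parent Document
print(chunk.metadata["start_index"])  # offset added by the splitter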
if __name__ == "__main__":
    dataset = datasets.load_dataset(
        "stevenbucaille/object-detection-models-dataset", split="train"
    )
    docs_processed = get_docs(dataset)
    vector_store = get_vector_store()
    vector_store.add_documents(docs_processed)
    vector_store.save_local(
        folder_path="vector_store",
        index_name="object_detection_models_faiss_index",
    )
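
To query the saved index later, reload it with the same embedding model. A minimal sketch: recent langchain_community releases require the explicit allow_dangerous_deserialization opt-in because the docstore is pickled (older versions don't accept the flag), and the query string here is just an example:

from langchain_community.vectorstores import FAISS

from rag.settings import get_embeddings_model

# Reload the persisted index; the embedding model must match the one used
# at build time, or query vectors will land in the wrong space.
vector_store = FAISS.load_local(
    folder_path="vector_store",
    embeddings=get_embeddings_model(),
    index_name="object_detection_models_faiss_index",
    allow_dangerous_deserialization=True,  # docstore is a pickle file
)

# Retrieve the four chunks closest to the query.
for doc in vector_store.similarity_search("model that can detect pedestrians", k=4):
    print(doc.metadata["model_id"], doc.page_content[:80])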