|
""" |
|
Indexing with vector databases - supports ChromaDB, Weaviate, FAISS, Qdrant, Pinecone |
|
Compatible with latest LangChain and HuggingFaceEmbeddings |
|
""" |
|
|
|
from pathlib import Path |
|
import re |
|
import os |
|
from unidecode import unidecode |
|
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
|
|
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Iterable of paths to PDF files.
        chunk_size: ``chunk_size`` forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` forwarded to
            ``RecursiveCharacterTextSplitter``.

    Returns:
        The result of ``split_documents`` applied to all loaded pages.
    """
    # Build every loader up front, then read them in order.
    pdf_loaders = [PyPDFLoader(pdf_path) for pdf_path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
|
|
|
|
|
def create_collection_name(filepath):
    """Derive a collection name from a file path.

    The file stem is transliterated with ``unidecode``, every run of
    characters outside ``A-Za-z0-9`` becomes a single ``-``, and the result is
    cut to 50 characters. Names shorter than 3 characters get ``"xyz"``
    appended, and a non-alphanumeric first/last character is replaced by
    ``"A"``/``"Z"`` respectively. The path and the resulting name are printed.

    Args:
        filepath: Path of the source file.

    Returns:
        The sanitised collection name.
    """
    stem = Path(filepath).stem
    ascii_stem = unidecode(stem.replace(" ", "-"))
    name = re.sub("[^A-Za-z0-9]+", "-", ascii_stem)[:50]

    # Pad very short names so the result is at least 3 characters long.
    if len(name) < 3:
        name = name + "xyz"

    # Force alphanumeric characters at both ends.
    if not name[0].isalnum():
        name = f"A{name[1:]}"
    if not name[-1].isalnum():
        name = f"{name[:-1]}Z"

    print("\n\nFilepath:", filepath)
    print("Collection name:", name)
    return name
|
|
|
|
|
def create_db(splits, collection_name, db_type="ChromaDB"):
    """Embed *splits* and store them in the requested vector database.

    Args:
        splits: Document chunks to index (e.g. the output of ``load_doc``).
        collection_name: Name of the collection/index. Used unchanged for
            ChromaDB and Qdrant, as the prefix of the ``<name>_index`` folder
            written by FAISS, and normalised to the backend's naming rules for
            Weaviate and Pinecone.
        db_type: One of ``"ChromaDB"``, ``"Weaviate"``, ``"FAISS"``,
            ``"Qdrant"`` or ``"Pinecone"``.

    Returns:
        The LangChain vector store populated with *splits*.

    Raises:
        ValueError: If *db_type* is not one of the supported backends.
    """
    supported = ("ChromaDB", "Weaviate", "FAISS", "Qdrant", "Pinecone")
    # Validate first so a bad db_type fails before the embedding model loads.
    if db_type not in supported:
        raise ValueError(f"Unsupported vector DB type: {db_type}")

    embedding = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    if db_type == "ChromaDB":
        import chromadb
        from langchain_chroma import Chroma

        # Clear Chroma's shared client cache before creating a new
        # in-process (ephemeral) client.
        chromadb.api.client.SharedSystemClient.clear_system_cache()
        return Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            client=chromadb.EphemeralClient(),
            collection_name=collection_name,
        )

    if db_type == "Weaviate":
        import weaviate
        from langchain_weaviate.vectorstores import WeaviateVectorStore

        # Weaviate collection names must start with an upper-case letter and
        # may only contain letters, digits and underscores, so hyphenated or
        # digit-leading names are normalised here.
        index_name = re.sub(r"[^0-9A-Za-z_]", "_", collection_name)
        if not index_name[:1].isalpha():
            index_name = "C" + index_name
        index_name = index_name[0].upper() + index_name[1:]

        # connect_to_local expects a bare host and a port, not a URL.
        # The client is left open on purpose: the returned store uses it.
        client = weaviate.connect_to_local(
            host="localhost",
            port=8080,
            grpc_port=50051,
        )
        return WeaviateVectorStore.from_documents(
            splits,
            embedding,
            client=client,
            index_name=index_name,
            text_key="text",
        )

    if db_type == "FAISS":
        from langchain_community.vectorstores import FAISS

        vectordb = FAISS.from_documents(splits, embedding)
        # Persist the index to "<collection_name>_index" in the working dir.
        vectordb.save_local(f"{collection_name}_index")
        return vectordb

    if db_type == "Qdrant":
        from langchain_community.vectorstores import Qdrant

        # ":memory:" selects Qdrant's in-process, non-persistent mode; the
        # store creates its own client from this location.
        return Qdrant.from_documents(
            splits,
            embedding,
            location=":memory:",
            collection_name=collection_name,
        )

    # Only "Pinecone" can reach this point (db_type was validated above).
    import pinecone
    from langchain_pinecone import PineconeVectorStore

    pc = pinecone.Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))

    # Pinecone index names must be lower-case.
    index_name = collection_name.lower()
    if index_name not in pc.list_indexes().names():
        # The index dimension must match the embedding size, so probe it once.
        dim = len(embedding.embed_query("test"))
        pc.create_index(
            name=index_name,
            dimension=dim,
            metric="cosine",
            # A spec is mandatory; cloud/region may be overridden through the
            # PINECONE_CLOUD / PINECONE_REGION environment variables.
            spec=pinecone.ServerlessSpec(
                cloud=os.environ.get("PINECONE_CLOUD", "aws"),
                region=os.environ.get("PINECONE_REGION", "us-east-1"),
            ),
        )

    return PineconeVectorStore.from_documents(
        splits,
        embedding,
        index_name=index_name,
    )