"""
Indexing with vector database - updated for Weaviate, FAISS, Qdrant, Pinecone
Compatible with latest LangChain and HuggingFaceEmbeddings
"""

from pathlib import Path
import re
import os
from unidecode import unidecode

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks."""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
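
# Hypothetical usage sketch (the file path and sizes below are assumptions,
# not from the original source); splits are LangChain Document objects:
#   splits = load_doc(["example.pdf"], chunk_size=600, chunk_overlap=40)
#   splits[0].page_content  # chunk text
#   splits[0].metadata      # e.g. {"source": "example.pdf", "page": 0}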


def create_collection_name(filepath):
    """Derive a vector-store-safe collection name from a file path:
    ASCII-only, hyphen-separated, 3-50 chars, alphanumeric at both ends."""
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name += "xyz"
    if not collection_name[0].isalnum():
        collection_name = "A" + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + "Z"
    print("\n\nFilepath:", filepath)
    print("Collection name:", collection_name)
    return collection_name
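
# Worked example (hypothetical input):
#   create_collection_name("My Résumé (2024).pdf")  ->  "My-Resume-2024Z"
# (spaces become hyphens, accents are stripped, punctuation runs collapse
#  to "-", and the trailing "-" is replaced by "Z" so the name ends
#  alphanumerically)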


def create_db(splits, collection_name, db_type="ChromaDB"):
    """Embed the document splits and index them in the selected vector database."""
    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

    if db_type == "ChromaDB":
        import chromadb
        from langchain_chroma import Chroma

        # reset any cached client state so repeated calls start clean
        chromadb.api.client.SharedSystemClient.clear_system_cache()
        vectordb = Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            client=chromadb.EphemeralClient(),  # in-memory, non-persistent store
            collection_name=collection_name,
        )
        return vectordb

    elif db_type == "Weaviate":
        import weaviate
        from langchain_weaviate.vectorstores import WeaviateVectorStore

        client = weaviate.connect_to_local("http://localhost:8080",
        grpc_port=50051)
        vectordb = WeaviateVectorStore.from_documents(
            splits,
            embedding,
            client=client,
            index_name=collection_name,
            text_key="text"
        )
        return vectordb

    elif db_type == "FAISS":
        from langchain.vectorstores import FAISS

        vectordb = FAISS.from_documents(splits, embedding)
        vectordb.save_local(f"{collection_name}_index")
        return vectordb

    elif db_type == "Qdrant":
        from qdrant_client import QdrantClient
        from langchain.vectorstores import Qdrant

        client = QdrantClient("::memory::")
        vectordb = Qdrant.from_documents(splits, embedding, client=client, collection_name=collection_name)
        return vectordb

    elif db_type == "Pinecone":
        import pinecone
        from langchain_pinecone import PineconeVectorStore

        pinecone_api_key = os.environ.get("PINECONE_API_KEY")
        pc = pinecone.Pinecone(api_key=pinecone_api_key)

        index_name = collection_name
        dim = len(embedding.embed_query("test"))
        if index_name not in [i.name for i in pc.list_indexes()]:
            pc.create_index(name=index_name, dimension=dim, metric="cosine")

        index = pc.Index(index_name)
        vectordb = PineconeVectorStore.from_documents(docs=splits, index=index, embedding=embedding)
        return vectordb

    else:
        raise ValueError(f"Unsupported vector DB type: {db_type}")
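

# --- Hypothetical end-to-end usage sketch (not part of the original module) ---
# The PDF path, chunk sizes, and query below are assumptions for illustration.
if __name__ == "__main__":
    pdf_files = ["example.pdf"]  # assumed input file
    splits = load_doc(pdf_files, chunk_size=600, chunk_overlap=40)
    name = create_collection_name(pdf_files[0])
    vectordb = create_db(splits, name, db_type="ChromaDB")
    # quick sanity check: retrieve the most similar chunk for a query
    docs = vectordb.similarity_search("What is this document about?", k=1)
    print(docs[0].page_content[:200])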