import os
import shutil
from typing import Optional

from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from loguru import logger
from tqdm import tqdm

from .parser import parse_pdf

# Chinese prompt template for retrieval-augmented QA. In English it reads roughly:
# "Known information: {context}. Based on the known information above, answer the
# user's question concisely and professionally. If no answer can be derived from it,
# reply 'the question cannot be answered from the known information' or 'not enough
# relevant information was provided'; do not fabricate content, and answer in Chinese.
# The question is: {question}"
PROMPT_TEMPLATE = """已知信息:
{context}

根据上述已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。
问题是:{question}"""


def _get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load a single file and split it into chunks with RecursiveCharacterTextSplitter."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    file_type = os.path.splitext(filepath)[1].lower()

    logger.info(f"Loading file: {filepath}")
    # ``texts`` is always a list of Documents so the final split_documents call
    # works for every branch below.
    texts = [Document(page_content="", metadata={"source": filepath})]
    try:
        if file_type == ".pdf":
            logger.debug("Loading PDF...")
            try:
                pdftext = parse_pdf(filepath, two_column).text
            except Exception:
                # Fall back to PyPDF2 if the custom parser fails.
                logger.debug("parse_pdf failed, falling back to PyPDF2...")
                from PyPDF2 import PdfReader

                pdftext = ""
                with open(filepath, "rb") as pdf_file:
                    pdf_reader = PdfReader(pdf_file)
                    for page in tqdm(pdf_reader.pages):
                        pdftext += page.extract_text()

            texts = [Document(page_content=pdftext, metadata={"source": filepath})]

        elif file_type == ".docx":
            from langchain.document_loaders import UnstructuredWordDocumentLoader

            logger.debug("Loading Word document...")
            loader = UnstructuredWordDocumentLoader(filepath)
            texts = loader.load()
        elif file_type == ".pptx":
            from langchain.document_loaders import UnstructuredPowerPointLoader

            logger.debug("Loading PowerPoint...")
            loader = UnstructuredPowerPointLoader(filepath)
            texts = loader.load()
        elif file_type == ".epub":
            from langchain.document_loaders import UnstructuredEPubLoader

            logger.debug("Loading EPUB...")
            loader = UnstructuredEPubLoader(filepath)
            texts = loader.load()
        elif file_type == ".md":
            # Markdown elements are returned as-is, without further splitting.
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load()
        else:
            loader = UnstructuredFileLoader(filepath, mode="elements")
            return loader.load_and_split(text_splitter=text_splitter)
    except Exception:
        logger.exception(f"Error loading file: {filepath}")

    return text_splitter.split_documents(texts)


def get_documents(filepath, chunk_size=500, chunk_overlap=0, two_column=False):
    """Load one file (a path string) or several files (an iterable of paths) into split documents."""
    documents = []
    logger.debug("Loading documents...")
    # Accept either a single path or an iterable of paths.
    filepaths = [filepath] if isinstance(filepath, (str, os.PathLike)) else filepath
    for file in filepaths:
        documents.extend(
            _get_documents(
                file,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                two_column=two_column,
            )
        )
    logger.debug("Documents loaded.")
    return documents


def generate_prompt(related_docs, query: str, prompt_template=PROMPT_TEMPLATE) -> str:
    """Fill the prompt template with the retrieved (Document, score) pairs and the user query."""
    context = "\n".join([doc[0].page_content for doc in related_docs])
    return prompt_template.replace("{question}", query).replace("{context}", context)


class DocQAPromptAdapter:
    """Builds retrieval-augmented QA prompts on top of a FAISS vector store."""

    def __init__(self, chunk_size: Optional[int] = 500, chunk_overlap: Optional[int] = 0, api_key: Optional[str] = "xxx"):
        self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.vector_store = None

    def create_vector_store(self, file_path, vs_path, embeddings=None):
        """Embed the given file(s) into a FAISS index and persist it to ``vs_path``."""
        documents = get_documents(file_path, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        self.vector_store = FAISS.from_documents(documents, embeddings or self.embeddings)
        self.vector_store.save_local(vs_path)

    def reset_vector_store(self, vs_path, embeddings=None):
        """Replace the in-memory index with one loaded from ``vs_path``."""
        self.vector_store = FAISS.load_local(vs_path, embeddings or self.embeddings)

    @staticmethod
    def delete_files(files):
        """Delete the given files (or directories) from disk if they exist."""
        for file in files:
            if os.path.exists(file):
                if os.path.isfile(file):
                    os.remove(file)
                else:
                    shutil.rmtree(file)

    def __call__(self, query, vs_path=None, topk=6):
        """Retrieve the top-k chunks related to ``query`` and return the filled prompt."""
        if vs_path is not None and os.path.exists(vs_path):
            self.reset_vector_store(vs_path)
        if self.vector_store is None:
            raise ValueError("No vector store loaded; call create_vector_store() first or pass a valid vs_path.")
        # Make sure the (possibly re-loaded) store embeds queries with this adapter's embeddings.
        self.vector_store.embedding_function = self.embeddings.embed_query
        related_docs_with_score = self.vector_store.similarity_search_with_score(query, k=topk)
        return generate_prompt(related_docs_with_score, query)
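

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# The source file, the index directory, and the OPENAI_API_KEY environment
# variable below are assumptions made for the example; real callers will
# supply their own values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    source_file = "example.pdf"          # hypothetical input document
    index_dir = "vector_store/example"   # hypothetical FAISS index directory

    adapter = DocQAPromptAdapter(
        chunk_size=500,
        chunk_overlap=0,
        api_key=os.environ.get("OPENAI_API_KEY", "xxx"),
    )

    # Build the index once, then reuse it when constructing prompts.
    adapter.create_vector_store(source_file, index_dir)
    prompt = adapter("What are the main conclusions of the document?", vs_path=index_dir, topk=6)
    print(prompt)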