|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores.faiss import FAISS |
|
from langchain import OpenAI, Cohere |
|
from langchain.chains.qa_with_sources import load_qa_with_sources_chain |
|
from embeddings import OpenAIEmbeddings |
|
from langchain.llms import OpenAI |
|
from langchain.docstore.document import Document |
|
from langchain.vectorstores import FAISS, VectorStore |
|
import docx2txt |
|
from typing import List, Dict, Any |
|
import re |
|
import numpy as np |
|
from io import StringIO |
|
from io import BytesIO |
|
import streamlit as st |
|
from prompts import STUFF_PROMPT |
|
from pypdf import PdfReader |
|
from openai.error import AuthenticationError |
|
|
|
@st.experimental_memo()
def parse_docx(file: BytesIO) -> str:
    """Extract the text of an uploaded .docx file and normalize blank lines."""
    raw = docx2txt.process(file)
    # Collapse runs of whitespace-only lines into a single blank line.
    return re.sub(r"\n\s*\n", "\n\n", raw)
|
|
|
|
|
@st.experimental_memo()
def parse_pdf(file: BytesIO) -> List[str]:
    """Extract and clean the text of each PDF page, one string per page."""
    pages = []
    for page in PdfReader(file).pages:
        cleaned = page.extract_text()
        # Re-join words that were hyphenated across a line break.
        cleaned = re.sub(r"(\w+)-\n(\w+)", r"\1\2", cleaned)
        # Replace single newlines with spaces, but keep paragraph breaks.
        cleaned = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", cleaned.strip())
        # Collapse runs of whitespace-only lines into a single blank line.
        cleaned = re.sub(r"\n\s*\n", "\n\n", cleaned)
        pages.append(cleaned)
    return pages
|
|
|
|
|
@st.experimental_memo()
def parse_txt(file: BytesIO) -> str:
    """Decode an uploaded plain-text file as UTF-8 and normalize blank lines."""
    raw = file.read().decode("utf-8")
    # Collapse runs of whitespace-only lines into a single blank line.
    return re.sub(r"\n\s*\n", "\n\n", raw)
|
|
|
@st.experimental_memo()
def parse_csv(uploaded_file) -> str:
    """Return the raw contents of an uploaded CSV file as a UTF-8 string.

    The original wrapped the decoded bytes in a StringIO only to read
    them straight back out; the round-trip added nothing, so we decode
    and return directly. Behavior is unchanged.
    """
    return uploaded_file.getvalue().decode("utf-8")
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def text_to_docs(text: str | List[str]) -> List[Document]:
    """Converts a string or list of strings to a list of Documents
    with metadata.

    Each input string is treated as one page (metadata "page", 1-indexed);
    pages are split into <=800-char chunks (metadata "chunk", 0-indexed)
    and each chunk gets a "source" key of the form "page-chunk".
    """
    if isinstance(text, str):
        # A single string is one page.
        text = [text]
    page_docs = [Document(page_content=page) for page in text]

    for i, doc in enumerate(page_docs):
        doc.metadata["page"] = i + 1

    # The splitter is loop-invariant; the original rebuilt it per page.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        chunk_overlap=0,
    )

    doc_chunks = []
    for page_doc in page_docs:
        chunks = text_splitter.split_text(page_doc.page_content)
        for i, chunk in enumerate(chunks):
            # Distinct name: the original rebound `doc` inside this loop,
            # shadowing the outer loop variable — it worked only because
            # the "page" value was copied forward each iteration.
            chunk_doc = Document(
                page_content=chunk,
                metadata={"page": page_doc.metadata["page"], "chunk": i},
            )
            chunk_doc.metadata["source"] = (
                f"{chunk_doc.metadata['page']}-{chunk_doc.metadata['chunk']}"
            )
            doc_chunks.append(chunk_doc)
    return doc_chunks
|
|
|
|
|
@st.cache(allow_output_mutation=True, show_spinner=False)
def embed_docs(docs: List[Document]) -> VectorStore:
    """Embeds a list of Documents and returns a FAISS index.

    Raises:
        AuthenticationError: if no OpenAI API key is present in session state.
    """
    # Look the key up once (the original read session_state twice) and use
    # a guard clause instead of an `else` after `raise`.
    api_key = st.session_state.get("OPENAI_API_KEY")
    if not api_key:
        raise AuthenticationError(
            "Enter your OpenAI API key in the sidebar. You can get a key at https://platform.openai.com/account/api-keys."
        )

    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    return FAISS.from_documents(docs, embeddings)
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def search_docs(index: VectorStore, query: str) -> List[Document]:
    """Searches a FAISS index for similar chunks to the query
    and returns a list of Documents."""
    # Top-5 nearest chunks by embedding similarity.
    return index.similarity_search(query, k=5)
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def get_answer(docs: List[Document], query: str) -> Dict[str, Any]:
    """Gets an answer to a question from a list of Documents."""
    # Build the LLM and the "stuff" QA chain as named steps rather than
    # one long call expression.
    llm = OpenAI(
        temperature=0,
        openai_api_key=st.session_state.get("OPENAI_API_KEY"),
    )
    chain = load_qa_with_sources_chain(
        llm, chain_type="stuff", prompt=STUFF_PROMPT
    )
    return chain(
        {"input_documents": docs, "question": query}, return_only_outputs=True
    )
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def get_sources(answer: Dict[str, Any], docs: List[Document]) -> List[Document]:
    """Gets the source documents for an answer.

    The answer's "output_text" is expected to end with a trailer like
    "SOURCES: 1-0, 2-3"; each key is matched against doc.metadata["source"].
    """
    # The original wrapped .split(", ") in an identity comprehension
    # ([s for s in ...]) — dropped here; result is identical.
    source_keys = answer["output_text"].split("SOURCES: ")[-1].split(", ")

    # Filter with a comprehension instead of a manual append loop.
    return [doc for doc in docs if doc.metadata["source"] in source_keys]
|
|
|
|
|
def wrap_text_in_html(text: str | List[str]) -> str: |
|
"""Wraps each text block separated by newlines in <p> tags""" |
|
if isinstance(text, list): |
|
|
|
text = "\n<hr/>\n".join(text) |
|
return "".join([f"<p>{line}</p>" for line in text.split("\n")]) |