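"""Gradio app for question answering over an uploaded PDF.

The PDF is read with PyPDF2, split into chunks, embedded with HuggingFace
sentence embeddings, indexed in an in-memory Chroma store, and queried through
a LangChain ConversationalRetrievalChain backed by the AI21 LLM.
"""
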
import os
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import AI21
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr

# Read AI21_API_KEY (and any other secrets) from a local .env file.
load_dotenv()

# Hosted AI21 completion model used to generate answers from the retrieved chunks.
llm = AI21(ai21_api_key=os.getenv('AI21_API_KEY'))

def process_pdf(pdf_file):
    """Read the uploaded PDF, split it into chunks, embed them, and build a retrieval QA chain."""
    # gr.File may pass a file path or a tempfile wrapper; PdfReader accepts a path, so prefer .name when present.
    pdf_reader = PdfReader(pdf_file.name if hasattr(pdf_file, "name") else pdf_file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can return None for pages with no extractable text.
        text += page.extract_text() or ""

    # Split the document into ~1000-character chunks on newline boundaries.
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=0
    )
    chunks = text_splitter.split_text(text)

    # Embed the chunks locally and index them in an in-memory Chroma store.
    embeddings = HuggingFaceEmbeddings()
    db = Chroma.from_texts(chunks, embeddings)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
    qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)
    return qa

def answer_question(pdf_file, question, chat_history):
    """Answer a question about the uploaded PDF and record the exchange in the chat history."""
    if not pdf_file:
        return "Please upload a PDF file first."

    # Note: the PDF is re-read and re-embedded on every question; caching the
    # chain per session would avoid the repeated work.
    qa = process_pdf(pdf_file)
    result = qa({"question": question, "chat_history": chat_history})
    # chat_history is the gr.State list; appending in place keeps the running history for the session.
    chat_history.append((question, result["answer"]))
    return result["answer"]

def main():
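    """Wire up the Gradio UI: PDF upload, question box, answer box, and per-session chat history."""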
    with gr.Blocks() as demo:
        gr.Markdown("# PDF QA")
        with gr.Row():
            pdf_file = gr.File(label="Upload your PDF", file_types=[".pdf"])
            question = gr.Textbox(label="Ask a question about the PDF")
        output = gr.Textbox(label="Answer")
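        # Per-session chat history: a list of (question, answer) pairs passed back into the chain.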
        chat_history = gr.State([])
        submit_btn = gr.Button("Submit")
        submit_btn.click(answer_question, inputs=[pdf_file, question, chat_history], outputs=output)
    
    demo.launch()

if __name__ == "__main__":
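    # Requires AI21_API_KEY in the environment (loaded from .env via load_dotenv above).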
    main()