import os
import torch

import streamlit as st
from streamlit_chat import message

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from constants import CHROMA_SETTINGS
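
# Streamlit "chat with your PDF" app: an uploaded PDF is chunked, embedded into a
# persistent Chroma index, and queried through a LangChain RetrievalQA chain that
# is driven by a local Llama-2 text-generation pipeline.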



st.set_page_config(layout="centered")

checkpoint = "meta-llama/Llama-2-7b-chat-hf"
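# Llama-2 checkpoints are gated on the Hugging Face Hub; HF_TOKEN must belong to
# an account that has been granted access to this repository.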
token = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=token)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, 
    use_auth_token=token,
    device_map="auto",
    torch_dtype=torch.float32
)
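# Note: torch.float32 keeps full precision; on a GPU, torch.float16 (or bfloat16)
# roughly halves the memory needed for the 7B checkpoint.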

class LlamaEmbeddings:
    """Embeds text by mean-pooling the final hidden states of the causal LM.

    Exposes the embed_documents / embed_query interface that LangChain's Chroma
    vector store expects from an embedding object.
    """

    def embed_query(self, text):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            # Causal-LM outputs expose hidden states only when explicitly requested;
            # they have no last_hidden_state attribute of their own.
            hidden = model(**inputs, output_hidden_states=True).hidden_states[-1]
        return hidden.mean(dim=1).squeeze(0).cpu().numpy().tolist()

    def embed_documents(self, texts):
        return [self.embed_query(text) for text in texts]


@st.cache_resource
def data_ingestion(filepath):
    # Load the PDF, split it into overlapping chunks, and persist the vectors to disk.
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(texts, embedding=LlamaEmbeddings(), persist_directory="db")
    db.persist()
    db = None

@st.cache_resource
def llm_pipeline():
    # Wrap the local model in a text-generation pipeline that LangChain can call.
    # max_new_tokens bounds only the generated answer, so long retrieved contexts
    # are not cut off the way a small max_length would cut them off.
    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm

@st.cache_resource
def qa_llm():
    # Connect the local LLM to the persisted Chroma index via a "stuff" RetrievalQA chain.
    llm = llm_pipeline()
    db = Chroma(persist_directory="db", embedding_function=LlamaEmbeddings())
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    return qa

def process_answer(instruction):
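    # RetrievalQA expects a dict with a "query" key; the generated answer is
    # returned in the "result" field of the output dict.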
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text['result']
    return answer

def display_conversation(history):
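    # Render the chat history as alternating user / assistant messages.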
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))

def main():
    st.markdown("<h1 style='text-align: center;'>Chat with your PDF</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Upload your PDF</h2>", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("", type=["pdf"])

    if uploaded_file is not None:
        # Save the upload to disk so PDFMinerLoader can read it from a file path.
        os.makedirs("docs", exist_ok=True)
        filepath = os.path.join("docs", uploaded_file.name)
        with open(filepath, "wb") as temp_file:
            temp_file.write(uploaded_file.read())

        with st.spinner('Creating embeddings...'):
            data_ingestion(filepath)
        st.success('Embeddings created successfully!')

        user_input = st.text_input("Ask a question about your PDF", key="input")

        if "generated" not in st.session_state:
            st.session_state["generated"] = ["I am ready to help you"]
        if "past" not in st.session_state:
            st.session_state["past"] = ["Hey there!"]

        if user_input:
            answer = process_answer({'query': user_input})
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(answer)

        display_conversation(st.session_state)

if __name__ == "__main__":
    main()
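
# Usage sketch (the script name below is an assumption; use this file's real name):
#   pip install streamlit streamlit-chat transformers torch langchain chromadb pdfminer.six
#   streamlit run chat_with_pdf.py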