sahilnishad committed (verified)
Commit 2105ace · 1 Parent(s): 6356b35

Create app.py

Files changed (1)
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
import os
import torch

import streamlit as st
from streamlit_chat import message

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.embeddings.base import Embeddings
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

from constants import CHROMA_SETTINGS


st.set_page_config(layout="centered")

checkpoint = "meta-llama/Llama-2-7b-chat-hf"


@st.cache_resource
def load_model():
    # Cache the tokenizer and model so Streamlit does not reload the 7B
    # checkpoint on every rerun of the script.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float32,
    )
    return tokenizer, model


tokenizer, model = load_model()

class LlamaEmbeddings(Embeddings):
    # Chroma expects an Embeddings object exposing embed_documents/embed_query,
    # so the original bare embedding_function is wrapped here. Vectors are the
    # mean-pooled final hidden states of the same Llama model used for generation
    # (a causal LM only returns hidden states when output_hidden_states=True).
    def _embed(self, text):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[-1].mean(dim=1).squeeze(0).cpu().tolist()

    def embed_documents(self, texts):
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)


@st.cache_resource
def data_ingestion(filepath):
    # Load the PDF, split it into overlapping chunks, and persist a Chroma index.
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(texts, embedding=LlamaEmbeddings(), persist_directory="db")
    db.persist()
    db = None

@st.cache_resource
def llm_pipeline():
    # Wrap the model in a text-generation pipeline for LangChain.
    # max_new_tokens bounds only the generated continuation, so long retrieved
    # contexts are not cut off by an overall length cap.
    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm

@st.cache_resource
def qa_llm():
    # Build a RetrievalQA chain over the persisted Chroma index.
    llm = llm_pipeline()
    db = Chroma(persist_directory="db", embedding_function=LlamaEmbeddings())
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    return qa

def process_answer(instruction):
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text['result']
    return answer


def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))

def main():
    st.markdown("<h1 style='text-align: center;'>Chat with your PDF</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Upload your PDF</h2>", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("", type=["pdf"])

    if uploaded_file is not None:
        # Save the upload locally so PDFMinerLoader can read it from disk.
        os.makedirs("docs", exist_ok=True)
        filepath = "docs/" + uploaded_file.name
        with open(filepath, "wb") as temp_file:
            temp_file.write(uploaded_file.read())

        with st.spinner('Creating embeddings...'):
            data_ingestion(filepath)
        st.success('Embeddings created successfully!')

        user_input = st.text_input("", key="input")

        if "generated" not in st.session_state:
            st.session_state["generated"] = ["I am ready to help you"]
        if "past" not in st.session_state:
            st.session_state["past"] = ["Hey there!"]

        if user_input:
            answer = process_answer({'query': user_input})
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(answer)

        display_conversation(st.session_state)


if __name__ == "__main__":
    main()
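
Note: app.py imports CHROMA_SETTINGS from a constants module that is not part of this commit. A minimal sketch of what such a constants.py might contain, assuming it only carries chromadb client settings (the field values below are assumptions, not taken from this repository):

from chromadb.config import Settings

# Hypothetical Chroma client configuration; adjust to the installed chromadb version.
CHROMA_SETTINGS = Settings(
    persist_directory="db",
    anonymized_telemetry=False,
)

With a constants.py in place and the dependencies installed (streamlit, streamlit-chat, transformers, torch, langchain, chromadb, pdfminer.six), the app is launched with: streamlit run app.py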