Govind committed on
Commit
2728e29
·
1 Parent(s): 9dc3e64

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ # import fitz # PyMuPDF for extracting text from PDFs
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.vectorstores import Chroma
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.docstore.document import Document
7
+ from langchain.llms import HuggingFacePipeline
8
+ from langchain.chains import RetrievalQA
9
+ from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
10
+ import torch
11
+ import re
12
+ import transformers
13
+ from torch import bfloat16
14
+ from langchain_community.document_loaders import DirectoryLoader
15
+
16
# Initialize embeddings and build the Chroma vector store from local PDFs.
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load every PDF under ./pdf and split it into overlapping chunks so that
# retrieval can return focused passages rather than whole documents.
loader = DirectoryLoader('./pdf', glob="**/*.pdf", use_multithreading=True)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

# Fix: the original persisted the store and then immediately re-opened the
# same data from disk as a second Chroma instance; reuse the freshly built
# store instead of paying for a redundant reload.
vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="pdf_db")
books_db = vectordb  # same store object; kept so downstream names are unchanged

books_db_client = books_db.as_retriever()
30
+
31
# Initialize the chat model and tokenizer (4-bit quantized StableLM Zephyr 3B).
model_name = "stabilityai/stablelm-zephyr-3b"

# NF4 double-quantized 4-bit weights with bfloat16 compute keeps the 3B model
# small enough for a single consumer GPU.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map=device,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # keep the prompt in the output; the regex below relies on it
    torch_dtype=torch.float16,
    device_map=device,
    # Fix: temperature/top_p/top_k are ignored (with a warning) unless
    # sampling is explicitly enabled.
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=256,
)

llm = HuggingFacePipeline(pipeline=query_pipeline)
66
+
67
# Wire the quantized LLM and the Chroma retriever into a "stuff" RetrievalQA
# chain: all retrieved chunks are concatenated into a single prompt.
_qa_kwargs = {
    "llm": llm,
    "chain_type": "stuff",
    "retriever": books_db_client,
    "verbose": True,
}
books_db_client_retriever = RetrievalQA.from_chain_type(**_qa_kwargs)
73
+
74
st.title("RAG System with ChromaDB")

# Streamlit re-executes this script on every interaction; create the Q&A
# history container only on the first run so earlier answers survive reruns.
if "history" not in st.session_state:
    st.session_state["history"] = []
79
+
80
+ # Function to retrieve answer using the RAG system
81
+ def test_rag(qa, query):
82
+ return qa.run(query)
83
+
84
query = st.text_input("Enter your question:")

if st.button("Submit"):
    if query:
        # Run the question through the RAG chain; with return_full_text=True
        # the result contains the whole prompt plus the model's completion.
        books_retriever = test_rag(books_db_client_retriever, query)

        # The default "stuff" QA prompt ends with "Helpful Answer:"; everything
        # after that marker is the model's actual answer.
        corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)

        if corrected_text_match:
            corrected_text_books = corrected_text_match.group(1).strip()
        else:
            corrected_text_books = "No helpful answer found."

        # Persist the Q&A pair so it is still shown after the next rerun.
        st.session_state.history.append({"question": query, "answer": corrected_text_books})
    else:
        # Fix: the original silently ignored an empty submission.
        st.warning("Please enter a question before submitting.")
101
+
102
# Display previous questions and answers, oldest first.
# Fix: the enumerate() index was never used, and the truthiness guard was
# redundant — iterating an empty history is already a no-op.
for item in st.session_state.history:
    st.write(f"**Question:** {item['question']}")
    st.write(f"**Answer:** {item['answer']}")
    st.write("---")
109
+