Somnath3570 committed
Commit 8f6cb7b · verified · 1 Parent(s): b460be5

Create app.py

Files changed (1): app.py (+196, -0)
app.py ADDED
@@ -0,0 +1,196 @@
import os
import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import base64
from gtts import gTTS

# Use environment variable for Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/db_faiss"
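
# NOTE: HF_TOKEN must be set in the environment (e.g. exported locally or
# configured as a Space secret); without it, the Inference API calls below fail.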

def load_pdf_files(data_path):
    """Load PDF files from the specified directory"""
    loader = DirectoryLoader(data_path,
                             glob='*.pdf',
                             loader_cls=PyPDFLoader)
    documents = loader.load()
    return documents

def create_chunks(extracted_data):
    """Split documents into chunks"""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
                                                   chunk_overlap=50)
    text_chunks = text_splitter.split_documents(extracted_data)
    return text_chunks

def get_embedding_model():
    """Get the embedding model"""
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedding_model
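
# The same embedding model must be used when building the FAISS index and
# when loading it back; mixing models silently degrades retrieval quality.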

def create_embeddings():
    """Create embeddings and save to FAISS database"""
    # Step 1: Load PDFs
    documents = load_pdf_files(data_path=DATA_PATH)
    st.info(f"Loaded {len(documents)} documents")

    # Step 2: Create chunks
    text_chunks = create_chunks(extracted_data=documents)
    st.info(f"Created {len(text_chunks)} text chunks")

    # Step 3: Get embedding model
    embedding_model = get_embedding_model()

    # Step 4: Create and save embeddings
    os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(DB_FAISS_PATH)
    st.success("Embeddings created and saved successfully!")
    return db

def set_custom_prompt(custom_prompt_template):
    """Set custom prompt template"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
    return prompt

def load_llm(huggingface_repo_id):
    """Load Hugging Face LLM"""
    # Pass the API token and generation limit as explicit arguments rather
    # than via model_kwargs, which is reserved for extra generation parameters;
    # "max_length" is also not a valid text-generation parameter here, so
    # max_new_tokens is used instead.
    llm = HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        task="text-generation",
        temperature=0.5,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN
    )
    return llm

def get_vectorstore():
    """Get or create vector store"""
    if os.path.exists(DB_FAISS_PATH):
        st.info("Loading existing vector store...")
        embedding_model = get_embedding_model()
        try:
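            # allow_dangerous_deserialization opts in to pickle-based loading;
            # only enable it for index files this app created itself.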
            db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
            return db
        except Exception as e:
            st.error(f"Error loading vector store: {e}")
            st.info("Creating new vector store...")
            return create_embeddings()
    else:
        st.info("Creating new vector store...")
        return create_embeddings()

def text_to_speech(text):
    """Convert text to speech and get the audio HTML for playback"""
    try:
        # Create a temporary directory for audio files if it doesn't exist
        os.makedirs("temp", exist_ok=True)

        # Generate the audio file using gTTS
        tts = gTTS(text=text, lang='en', slow=False)
        audio_file_path = "temp/response.mp3"
        tts.save(audio_file_path)

        # Read the audio file and encode it to base64
        with open(audio_file_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        audio_base64 = base64.b64encode(audio_bytes).decode()

        # Build an auto-playing audio element; note that some browsers block
        # autoplay until the user has interacted with the page
        audio_html = f"""
        <audio autoplay>
            <source src="data:audio/mp3;base64,{audio_base64}" type="audio/mpeg">
            Your browser does not support the audio element.
        </audio>
        """
        return audio_html

    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def main():
    st.title("BeepKart FAQ Chatbot")
    st.markdown("Ask questions about buying or selling bikes on BeepKart!")

    # Initialize session state for messages
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['content'])

    # Get user input
    prompt = st.chat_input("Ask a question about BeepKart...")

    # Custom prompt template - modified to request concise answers
    CUSTOM_PROMPT_TEMPLATE = """
    Use the pieces of information provided in the context to answer the user's question in 1-2 sentences maximum.
    If you don't know the answer, just say that you don't know; don't try to make up an answer.

    Be extremely concise and direct. No explanations or additional information unless specifically requested.

    Context: {context}
    Question: {question}

    Start the answer directly. No small talk please.
    """

    if prompt:
        # Display user message
        st.chat_message('user').markdown(prompt)
        st.session_state.messages.append({'role': 'user', 'content': prompt})

        try:
            with st.spinner("Thinking..."):
                # Get vector store
                vectorstore = get_vectorstore()

                # Create QA chain
                qa_chain = RetrievalQA.from_chain_type(
                    llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID),
                    chain_type="stuff",
                    retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
                    return_source_documents=True,
                    chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
                )

                # Get response
                response = qa_chain.invoke({'query': prompt})

                # Extract result only (no sources)
                result = response["result"]

                # Keep only the first two sentences if the response runs long
                # (a naive split that can break on abbreviations like "e.g.")
                sentences = result.split('. ')
                if len(sentences) > 2:
                    result = '. '.join(sentences[:2]) + '.'

                # Display the result
                st.chat_message('assistant').markdown(result)
                st.session_state.messages.append({'role': 'assistant', 'content': result})

                # Generate speech from the result and play it
                audio_html = text_to_speech(result)
                if audio_html:
                    st.markdown(audio_html, unsafe_allow_html=True)

        except Exception as e:
            error_message = f"Error: {str(e)}"
            st.error(error_message)
            st.error("Please check your HuggingFace token and model access permissions")
            st.session_state.messages.append({'role': 'assistant', 'content': error_message})

if __name__ == "__main__":
    main()
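
To try this out locally (a minimal sketch — the package list below is an assumption inferred from the imports, and source PDFs must be placed in data/ before the first run):

pip install streamlit langchain langchain-community langchain-huggingface langchain-core faiss-cpu sentence-transformers pypdf gtts
export HF_TOKEN=<your Hugging Face access token>
streamlit run app.py

The first question asked triggers index creation from the PDFs in data/; later runs reuse the saved FAISS index in vectorstore/db_faiss.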