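"""BeepKart FAQ chatbot.

A Streamlit RAG app: PDF FAQs are chunked, embedded with MiniLM, indexed in
FAISS, and answered through a Mistral-7B-Instruct endpoint, with each reply
also read aloud via gTTS.

Run locally (filename assumed): streamlit run app.py
Requires a Hugging Face API token in the HF_TOKEN environment variable.
"""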
import base64
import os

import streamlit as st
from gtts import gTTS
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
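# Assumed (unpinned) dependencies behind the imports above: streamlit,
# langchain, langchain-community, langchain-core, langchain-huggingface,
# faiss-cpu, sentence-transformers, pypdf, gtts.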
# Use an environment variable for the Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/db_faiss"
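# Expected layout: source PDFs live under data/, and the FAISS index is
# persisted to vectorstore/db_faiss by create_embeddings() below.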

def load_pdf_files(data_path):
    """Load PDF files from the specified directory."""
    loader = DirectoryLoader(data_path,
                             glob='*.pdf',
                             loader_cls=PyPDFLoader)
    documents = loader.load()
    return documents

def create_chunks(extracted_data):
    """Split documents into overlapping chunks."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
                                                   chunk_overlap=50)
    text_chunks = text_splitter.split_documents(extracted_data)
    return text_chunks

def get_embedding_model():
    """Get the sentence-transformers embedding model."""
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedding_model

def create_embeddings():
    """Create embeddings and save them to the FAISS database."""
    # Step 1: Load PDFs
    documents = load_pdf_files(data_path=DATA_PATH)
    st.info(f"Loaded {len(documents)} documents")

    # Step 2: Create chunks
    text_chunks = create_chunks(extracted_data=documents)
    st.info(f"Created {len(text_chunks)} text chunks")

    # Step 3: Get the embedding model
    embedding_model = get_embedding_model()

    # Step 4: Create and save embeddings
    os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(DB_FAISS_PATH)
    st.success("Embeddings created and saved successfully!")
    return db
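# To refresh the index after adding or changing PDFs, delete
# vectorstore/db_faiss; get_vectorstore() below will then rebuild it.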

def set_custom_prompt(custom_prompt_template):
    """Build a PromptTemplate from the given template string."""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=["context", "question"])
    return prompt
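# For illustration (hypothetical values): the QA chain fills both variables at
# query time, roughly equivalent to
#     set_custom_prompt(CUSTOM_PROMPT_TEMPLATE).format(
#         context="<retrieved chunks>", question="<user question>")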

def load_llm(huggingface_repo_id):
    """Load the Hugging Face inference endpoint as an LLM."""
    # The endpoint takes the token and generation length as top-level
    # parameters, not via model_kwargs.
    llm = HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        task="text-generation",
        temperature=0.5,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )
    return llm

def get_vectorstore():
    """Load the vector store from disk, or create it if missing."""
    if os.path.exists(DB_FAISS_PATH):
        st.info("Loading existing vector store...")
        embedding_model = get_embedding_model()
        try:
            db = FAISS.load_local(DB_FAISS_PATH, embedding_model,
                                  allow_dangerous_deserialization=True)
            return db
        except Exception as e:
            st.error(f"Error loading vector store: {e}")
            st.info("Creating new vector store...")
            return create_embeddings()
    else:
        st.info("Creating new vector store...")
        return create_embeddings()
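# Note: get_vectorstore() runs on every question. A minimal sketch of loading
# it once per process with Streamlit's resource cache (not in the original):
#
#     @st.cache_resource
#     def cached_vectorstore():
#         return get_vectorstore()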

def text_to_speech(text):
    """Convert text to speech and return an HTML snippet for playback."""
    try:
        # Create a temporary directory for audio files if it doesn't exist
        os.makedirs("temp", exist_ok=True)

        # Generate the audio file using gTTS
        tts = gTTS(text=text, lang='en', slow=False)
        audio_file_path = "temp/response.mp3"
        tts.save(audio_file_path)

        # Read the audio file and encode it to base64
        with open(audio_file_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        audio_base64 = base64.b64encode(audio_bytes).decode()

        # Create HTML with an auto-playing audio element
        audio_html = f"""
        <audio autoplay>
            <source src="data:audio/mpeg;base64,{audio_base64}" type="audio/mpeg">
            Your browser does not support the audio element.
        </audio>
        """
        return audio_html
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None
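# Caveat: temp/response.mp3 is shared, so concurrent sessions can overwrite
# each other's audio. A per-request file avoids this (a sketch using the
# standard-library tempfile module):
#
#     import tempfile
#     with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
#         tts.save(f.name)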

def main():
    st.title("BeepKart FAQ Chatbot")
    st.markdown("Ask questions about buying or selling bikes on BeepKart!")

    # Initialize session state for messages
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['content'])

    # Get user input
    prompt = st.chat_input("Ask a question about BeepKart...")

    # Custom prompt template, tuned to request concise answers
    CUSTOM_PROMPT_TEMPLATE = """
    Use the pieces of information provided in the context to answer the user's question in 1-2 sentences maximum.
    If you don't know the answer, just say that you don't know; don't try to make up an answer.
    Be extremely concise and direct. No explanations or additional information unless specifically requested.

    Context: {context}
    Question: {question}

    Start the answer directly. No small talk please.
    """
    if prompt:
        # Display user message
        st.chat_message('user').markdown(prompt)
        st.session_state.messages.append({'role': 'user', 'content': prompt})

        try:
            with st.spinner("Thinking..."):
                # Get vector store
                vectorstore = get_vectorstore()

                # Create QA chain
                qa_chain = RetrievalQA.from_chain_type(
                    llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID),
                    chain_type="stuff",
                    retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
                    return_source_documents=True,
                    chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
                )

                # Get response
                response = qa_chain.invoke({'query': prompt})

                # Extract the answer only (no sources)
                result = response["result"]

                # Keep at most the first two sentences if the response runs long
                sentences = result.split('. ')
                if len(sentences) > 2:
                    result = '. '.join(sentences[:2]) + '.'

                # Display the result
                st.chat_message('assistant').markdown(result)
                st.session_state.messages.append({'role': 'assistant', 'content': result})

                # Generate speech from the result and play it
                audio_html = text_to_speech(result)
                if audio_html:
                    st.markdown(audio_html, unsafe_allow_html=True)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            st.error(error_message)
            st.error("Please check your Hugging Face token and model access permissions.")
            st.session_state.messages.append({'role': 'assistant', 'content': error_message})


if __name__ == "__main__":
    main()