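"""BeepKart FAQ Chatbot.

A Streamlit RAG app: PDFs under data/ are chunked, embedded with
sentence-transformers/all-MiniLM-L6-v2, and indexed in a local FAISS store;
questions are answered by Mistral-7B-Instruct-v0.3 via the Hugging Face
Inference API, and each answer is also read aloud with gTTS.
"""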
import base64
import io
import os

import streamlit as st
from gtts import gTTS
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint

# Use environment variable for Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/db_faiss"

def load_pdf_files(data_path):
    """Load PDF files from the specified directory"""
    loader = DirectoryLoader(data_path,
                             glob='*.pdf',
                             loader_cls=PyPDFLoader)
    documents = loader.load()
    return documents

def create_chunks(extracted_data):
    """Split documents into chunks"""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
                                                   chunk_overlap=50)
    text_chunks = text_splitter.split_documents(extracted_data)
    return text_chunks

def get_embedding_model():
    """Get the embedding model"""
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedding_model

def create_embeddings():
    """Create embeddings and save to FAISS database"""
    # Step 1: Load PDFs
    documents = load_pdf_files(data_path=DATA_PATH)
    st.info(f"Loaded {len(documents)} documents")
    
    # Step 2: Create chunks
    text_chunks = create_chunks(extracted_data=documents)
    st.info(f"Created {len(text_chunks)} text chunks")
    
    # Step 3: Get embedding model
    embedding_model = get_embedding_model()
    
    # Step 4: Create and save embeddings
    os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(DB_FAISS_PATH)
    st.success("Embeddings created and saved successfully!")
    return db

def set_custom_prompt(custom_prompt_template):
    """Set custom prompt template"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
    return prompt

def load_llm(huggingface_repo_id):
    """Load the Hugging Face inference endpoint as a LangChain LLM"""
    # The token belongs in huggingfacehub_api_token (recent
    # langchain_huggingface versions reject credentials inside model_kwargs),
    # and the generation length parameter is max_new_tokens, not max_length.
    llm = HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        task="text-generation",
        temperature=0.5,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )
    return llm

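# Note: this reloads or rebuilds the index on every question; if that becomes
# a bottleneck, wrapping this function in st.cache_resource would keep the
# store in memory across Streamlit reruns.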
def get_vectorstore():
    """Get or create vector store"""
    if os.path.exists(DB_FAISS_PATH):
        st.info("Loading existing vector store...")
        embedding_model = get_embedding_model()
        try:
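            # allow_dangerous_deserialization is acceptable here because the
            # index was written locally by this app, not fetched from an
            # untrusted source.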
            db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
            return db
        except Exception as e:
            st.error(f"Error loading vector store: {e}")
            st.info("Creating new vector store...")
            return create_embeddings()
    else:
        st.info("Creating new vector store...")
        return create_embeddings()

def text_to_speech(text):
    """Convert text to speech and return an HTML snippet for playback"""
    try:
        # Generate the audio in memory; saving to a fixed file on disk would
        # let concurrent sessions overwrite each other's audio.
        tts = gTTS(text=text, lang='en', slow=False)
        mp3_buffer = io.BytesIO()
        tts.write_to_fp(mp3_buffer)

        # Encode the MP3 bytes to base64 so they can be embedded inline
        audio_base64 = base64.b64encode(mp3_buffer.getvalue()).decode()

        # Create HTML with an auto-playing audio element
        audio_html = f"""
        <audio autoplay>
            <source src="data:audio/mpeg;base64,{audio_base64}" type="audio/mpeg">
            Your browser does not support the audio element.
        </audio>
        """
        return audio_html

    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def main():
    st.title("BeepKart FAQ Chatbot")
    st.markdown("Ask questions about buying or selling bikes on BeepKart!")
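
    # Fail fast if the API token is missing; without it the endpoint call
    # below fails with a less obvious authentication error.
    if not HF_TOKEN:
        st.error("HF_TOKEN environment variable is not set. Set it to a valid Hugging Face API token.")
        st.stop()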
    
    # Initialize session state for messages
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    
    # Display chat history
    for message in st.session_state.messages:
        st.chat_message(message['role']).markdown(message['content'])
    
    # Get user input
    prompt = st.chat_input("Ask a question about BeepKart...")
    
    # Custom prompt template requesting concise answers
    CUSTOM_PROMPT_TEMPLATE = """
    Use the pieces of information provided in the context to answer the user's question in 1-2 sentences maximum.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    
    Be extremely concise and direct. No explanations or additional information unless specifically requested.
    
    Context: {context}
    Question: {question}
    
    Start the answer directly. No small talk please.
    """
    
    if prompt:
        # Display user message
        st.chat_message('user').markdown(prompt)
        st.session_state.messages.append({'role': 'user', 'content': prompt})
        
        try:
            with st.spinner("Thinking..."):
                # Get vector store
                vectorstore = get_vectorstore()
                
                # Build a "stuff" QA chain: the top-3 retrieved chunks are
                # inserted verbatim into the prompt as {context}
                qa_chain = RetrievalQA.from_chain_type(
                    llm=load_llm(huggingface_repo_id=HUGGINGFACE_REPO_ID),
                    chain_type="stuff",
                    retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
                    return_source_documents=True,
                    chain_type_kwargs={'prompt': set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)}
                )
                
                # Get response
                response = qa_chain.invoke({'query': prompt})
                
                # Extract result only (no sources)
                result = response["result"]
                
                # Keep only the first two sentences if the response runs long
                sentences = result.split('. ')
                if len(sentences) > 2:
                    result = '. '.join(sentences[:2]) + '.'
                
                # Display the result
                st.chat_message('assistant').markdown(result)
                st.session_state.messages.append({'role': 'assistant', 'content': result})
                
                # Generate speech from the result and play it
                audio_html = text_to_speech(result)
                if audio_html:
                    st.markdown(audio_html, unsafe_allow_html=True)
                
        except Exception as e:
            error_message = f"Error: {str(e)}"
            st.error(error_message)
            st.error("Please check your HuggingFace token and model access permissions")
            st.session_state.messages.append({'role': 'assistant', 'content': error_message})

if __name__ == "__main__":
    main()
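
# Launch with: streamlit run <path-to-this-file>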