jaafarhh commited on
Commit
4f8302d
·
verified ·
1 Parent(s): 97e38ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -3,18 +3,20 @@ import whisper
3
  import numpy as np
4
  from langchain_community.llms import HuggingFaceEndpoint
5
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
6
- from langchain.memory import ConversationBufferMemory
7
- from langchain.chains import ConversationalRetrievalChain
8
  from langchain_community.vectorstores import FAISS
9
- from langchain.prompts import PromptTemplate
 
 
 
 
10
  import os
11
  from dotenv import load_dotenv
12
  import requests
13
  from requests.adapters import HTTPAdapter
14
  from requests.packages.urllib3.util.retry import Retry
15
- import time # Imported time module
16
  from streamlit_chat import message
17
- from streamlit_audiorecorder import audiorecorder # For audio recording
18
 
19
  # Load environment variables
20
  load_dotenv()
@@ -26,6 +28,8 @@ if "audio_data" not in st.session_state:
26
  st.session_state.audio_data = None
27
  if "recording" not in st.session_state:
28
  st.session_state.recording = False
 
 
29
 
30
  # Prompt template
31
  PROMPT_TEMPLATE = """
@@ -72,7 +76,7 @@ llm = HuggingFaceEndpoint(
72
  )
73
 
74
  # Setup memory and conversation chain
75
- memory = ConversationBufferMemory(
76
  memory_key="chat_history",
77
  return_messages=True
78
  )
@@ -91,14 +95,24 @@ qa_prompt = PromptTemplate(
91
  input_variables=["context", "chat_history", "question"]
92
  )
93
 
94
- conversation_chain = ConversationalRetrievalChain.from_llm(
95
- llm=llm,
96
- retriever=vectorstore.as_retriever(),
97
- memory=memory,
98
- combine_docs_chain_kwargs={"prompt": qa_prompt},
99
- return_source_documents=True,
100
- output_key='answer' # Specify output_key to fix the error
101
- )
 
 
 
 
 
 
 
 
 
 
102
 
103
  def get_ai_response(user_input: str) -> str:
104
  max_retries = 3
 
3
  import numpy as np
4
  from langchain_community.llms import HuggingFaceEndpoint
5
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
 
6
  from langchain_community.vectorstores import FAISS
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.memory import BaseMemory
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from langchain_core.runnables import RunnablePassthrough
11
+ from langchain.chains import ConversationChain
12
  import os
13
  from dotenv import load_dotenv
14
  import requests
15
  from requests.adapters import HTTPAdapter
16
  from requests.packages.urllib3.util.retry import Retry
17
+ import time
18
  from streamlit_chat import message
19
+ from streamlit_audiorecorder import audiorecorder
20
 
21
  # Load environment variables
22
  load_dotenv()
 
28
  st.session_state.audio_data = None
29
  if "recording" not in st.session_state:
30
  st.session_state.recording = False
31
+ if "text_input" not in st.session_state:
32
+ st.session_state.text_input = ""
33
 
34
  # Prompt template
35
  PROMPT_TEMPLATE = """
 
76
  )
77
 
78
  # Setup memory and conversation chain
79
+ memory = ConversationChain(
80
  memory_key="chat_history",
81
  return_messages=True
82
  )
 
95
  input_variables=["context", "chat_history", "question"]
96
  )
97
 
98
+ def create_chain():
99
+ prompt = PromptTemplate(
100
+ template=PROMPT_TEMPLATE,
101
+ input_variables=["context", "chat_history", "question"]
102
+ )
103
+
104
+ retriever = vectorstore.as_retriever()
105
+
106
+ chain = (
107
+ {"context": retriever, "question": RunnablePassthrough()}
108
+ | prompt
109
+ | llm
110
+ | StrOutputParser()
111
+ )
112
+
113
+ return chain
114
+
115
+ conversation_chain = create_chain()
116
 
117
  def get_ai_response(user_input: str) -> str:
118
  max_retries = 3