jaafarhh committed on
Commit
deb9302
·
verified ·
1 Parent(s): 3de17cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -69
app.py CHANGED
@@ -1,22 +1,22 @@
1
  import streamlit as st
2
- import whisper
3
- import numpy as np
4
  from langchain_community.llms import HuggingFaceEndpoint
5
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
 
6
  from langchain_community.vectorstores import FAISS
7
- from langchain_core.prompts import PromptTemplate
8
- from langchain_core.memory import BaseMemory
9
- from langchain_core.output_parsers import StrOutputParser
10
- from langchain_core.runnables import RunnablePassthrough
11
- from langchain.chains import ConversationChain
12
  import os
13
  from dotenv import load_dotenv
14
  import requests
15
  from requests.adapters import HTTPAdapter
16
  from requests.packages.urllib3.util.retry import Retry
17
- import time
18
- from streamlit_chat import message
19
- from streamlit_audiorecorder import audiorecorder
 
 
20
 
21
  # Load environment variables
22
  load_dotenv()
@@ -24,12 +24,15 @@ load_dotenv()
24
  # Initialize session state
25
  if "messages" not in st.session_state:
26
  st.session_state.messages = []
27
- if "audio_data" not in st.session_state:
28
- st.session_state.audio_data = None
29
- if "recording" not in st.session_state:
 
 
30
  st.session_state.recording = False
31
- if "text_input" not in st.session_state:
32
- st.session_state.text_input = ""
 
33
 
34
  # Prompt template
35
  PROMPT_TEMPLATE = """
@@ -76,7 +79,7 @@ llm = HuggingFaceEndpoint(
76
  )
77
 
78
  # Setup memory and conversation chain
79
- memory = ConversationChain(
80
  memory_key="chat_history",
81
  return_messages=True
82
  )
@@ -95,24 +98,19 @@ qa_prompt = PromptTemplate(
95
  input_variables=["context", "chat_history", "question"]
96
  )
97
 
98
- def create_chain():
99
- prompt = PromptTemplate(
100
- template=PROMPT_TEMPLATE,
101
- input_variables=["context", "chat_history", "question"]
102
- )
103
-
104
- retriever = vectorstore.as_retriever()
105
-
106
- chain = (
107
- {"context": retriever, "question": RunnablePassthrough()}
108
- | prompt
109
- | llm
110
- | StrOutputParser()
111
- )
112
-
113
- return chain
114
-
115
- conversation_chain = create_chain()
116
 
117
  def get_ai_response(user_input: str) -> str:
118
  max_retries = 3
@@ -128,13 +126,13 @@ def get_ai_response(user_input: str) -> str:
128
  "question": user_input,
129
  "chat_history": memory.chat_memory.messages[-5:]
130
  })
131
-
132
  if not response or 'answer' not in response:
133
  if attempt < max_retries - 1:
134
  time.sleep(2 ** attempt)
135
  continue
136
  return "عذراً، كاين مشكل. حاول مرة أخرى."
137
-
138
  return response['answer']
139
 
140
  except requests.exceptions.HTTPError as e:
@@ -142,7 +140,7 @@ def get_ai_response(user_input: str) -> str:
142
  time.sleep(2 ** attempt)
143
  continue
144
  return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
145
-
146
  except Exception as e:
147
  st.error(f"Error: {str(e)}")
148
  if attempt < max_retries - 1:
@@ -152,7 +150,7 @@ def get_ai_response(user_input: str) -> str:
152
 
153
  def process_message(user_input: str) -> None:
154
  st.session_state.messages.append({"role": "user", "content": user_input})
155
-
156
  with st.spinner("جاري التفكير..."):
157
  ai_response = get_ai_response(user_input)
158
  if ai_response:
@@ -160,54 +158,70 @@ def process_message(user_input: str) -> None:
160
 
161
  def main():
162
  st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
163
-
164
  st.title("Darija AI Therapist 🧠")
165
  st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
166
-
167
- # Chat interface
168
- # Create columns for text input and mic button
169
  col1, col2 = st.columns([9, 1])
170
  with col1:
171
  user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
172
  with col2:
173
- # Mic button
174
  if st.session_state.recording:
175
- mic_label = "🛑"
176
  else:
177
- mic_label = "🎤"
178
-
179
- if st.button(mic_label):
180
  st.session_state.recording = not st.session_state.recording
181
  if st.session_state.recording:
182
- st.session_state.audio_data = audiorecorder("Click to stop recording")
 
 
 
 
 
 
 
 
 
 
 
183
  else:
184
- audio_data = st.session_state.audio_data
185
- if audio_data is not None:
186
- # Convert byte data to numpy array
187
- audio_array = np.frombuffer(audio_data.tobytes(), dtype=np.int16)
188
- # Normalize audio data
189
- audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
190
- # Transcribe audio using Whisper
191
- result = whisper_model.transcribe(audio_array, language="ar")
192
- if result["text"]:
193
- # Put transcribed text into input field
194
- st.session_state.text_input = result["text"]
 
 
 
 
 
 
195
  else:
196
- st.error("No audio data recorded.")
 
 
 
 
 
 
 
 
197
 
198
- # Handle text submission
199
  if user_input:
200
  process_message(user_input)
201
- st.session_state.text_input = "" # Clear input field after sending
202
 
203
  # Display chat history
204
- for message_data in st.session_state.messages:
205
- role = message_data["role"]
206
- content = message_data["content"]
207
- if role == "user":
208
- message(content, is_user=True)
209
- else:
210
- message(content)
211
 
212
  if __name__ == "__main__":
213
  main()
 
1
  import streamlit as st
2
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
3
+ from typing import List
4
  from langchain_community.llms import HuggingFaceEndpoint
5
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain.chains import ConversationalRetrievalChain
8
  from langchain_community.vectorstores import FAISS
9
+ from langchain.prompts import PromptTemplate
 
 
 
 
10
  import os
11
  from dotenv import load_dotenv
12
  import requests
13
  from requests.adapters import HTTPAdapter
14
  from requests.packages.urllib3.util.retry import Retry
15
+ import whisper
16
+ import numpy as np
17
+ import av
18
+ import time # Added import time
19
+ import queue
20
 
21
  # Load environment variables
22
  load_dotenv()
 
24
  # Initialize session state
25
  if "messages" not in st.session_state:
26
  st.session_state.messages = []
27
+
28
+ if "audio_buffer" not in st.session_state:
29
+ st.session_state.audio_buffer = queue.Queue()
30
+
31
+ if 'recording' not in st.session_state:
32
  st.session_state.recording = False
33
+
34
+ if 'webrtc_ctx' not in st.session_state:
35
+ st.session_state.webrtc_ctx = None
36
 
37
  # Prompt template
38
  PROMPT_TEMPLATE = """
 
79
  )
80
 
81
  # Setup memory and conversation chain
82
+ memory = ConversationBufferMemory(
83
  memory_key="chat_history",
84
  return_messages=True
85
  )
 
98
  input_variables=["context", "chat_history", "question"]
99
  )
100
 
101
+ conversation_chain = ConversationalRetrievalChain.from_llm(
102
+ llm=llm,
103
+ retriever=vectorstore.as_retriever(),
104
+ memory=memory,
105
+ combine_docs_chain_kwargs={"prompt": qa_prompt},
106
+ return_source_documents=True,
107
+ output_key='answer' # Specify output_key to fix the error
108
+ )
109
+
110
+ def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
111
+ audio = frame.to_ndarray().flatten()
112
+ st.session_state.audio_buffer.put(audio)
113
+ return frame
 
 
 
 
 
114
 
115
  def get_ai_response(user_input: str) -> str:
116
  max_retries = 3
 
126
  "question": user_input,
127
  "chat_history": memory.chat_memory.messages[-5:]
128
  })
129
+
130
  if not response or 'answer' not in response:
131
  if attempt < max_retries - 1:
132
  time.sleep(2 ** attempt)
133
  continue
134
  return "عذراً، كاين مشكل. حاول مرة أخرى."
135
+
136
  return response['answer']
137
 
138
  except requests.exceptions.HTTPError as e:
 
140
  time.sleep(2 ** attempt)
141
  continue
142
  return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
143
+
144
  except Exception as e:
145
  st.error(f"Error: {str(e)}")
146
  if attempt < max_retries - 1:
 
150
 
151
  def process_message(user_input: str) -> None:
152
  st.session_state.messages.append({"role": "user", "content": user_input})
153
+
154
  with st.spinner("جاري التفكير..."):
155
  ai_response = get_ai_response(user_input)
156
  if ai_response:
 
158
 
159
  def main():
160
  st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
161
+
162
  st.title("Darija AI Therapist 🧠")
163
  st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
164
+
 
 
165
  col1, col2 = st.columns([9, 1])
166
  with col1:
167
  user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
168
  with col2:
 
169
  if st.session_state.recording:
170
+ mic_icon = "🛑"
171
  else:
172
+ mic_icon = "🎤"
173
+
174
+ if st.button(mic_icon):
175
  st.session_state.recording = not st.session_state.recording
176
  if st.session_state.recording:
177
+ st.session_state.audio_buffer = queue.Queue()
178
+ st.session_state.webrtc_ctx = webrtc_streamer(
179
+ key="speech-to-text",
180
+ mode=WebRtcMode.SENDONLY,
181
+ audio_receiver_size=256,
182
+ rtc_configuration=RTCConfiguration(
183
+ {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
184
+ ),
185
+ media_stream_constraints={"video": False, "audio": True},
186
+ async_processing=True,
187
+ audio_frame_callback=audio_frame_callback,
188
+ )
189
  else:
190
+ st.info("🔄 Processing audio...")
191
+ audio_frames = []
192
+ while not st.session_state.audio_buffer.empty():
193
+ audio_frames.append(st.session_state.audio_buffer.get())
194
+
195
+ if audio_frames:
196
+ audio_data = np.concatenate(audio_frames, axis=0).flatten()
197
+ # Convert to 16-bit integers
198
+ audio_data_int16 = (audio_data * 32767).astype(np.int16)
199
+ # Use Whisper to transcribe
200
+ result = whisper_model.transcribe(audio_data_int16, fp16=False)
201
+ text = result.get("text", "")
202
+ if text:
203
+ process_message(text)
204
+ else:
205
+ st.warning("ما فهمتش الصوت. حاول مرة أخرى.")
206
+ st.session_state.audio_buffer = queue.Queue()
207
  else:
208
+ st.warning("ما تسجلش الصوت. حاول مرة أخرى.")
209
+ if st.session_state.webrtc_ctx:
210
+ st.session_state.webrtc_ctx.stop()
211
+ st.session_state.webrtc_ctx = None
212
+
213
+ if st.session_state.recording:
214
+ st.info("🎙️ Recording...")
215
+ else:
216
+ st.empty()
217
 
 
218
  if user_input:
219
  process_message(user_input)
 
220
 
221
  # Display chat history
222
+ for message in st.session_state.messages:
223
+ with st.chat_message(message["role"]):
224
+ st.write(message["content"])
 
 
 
 
225
 
226
  if __name__ == "__main__":
227
  main()