jaafarhh committed
Commit 699acb6 · verified · 1 Parent(s): 5f90f96

Update app.py

Files changed (1)
  1. app.py +50 -48
app.py CHANGED
@@ -1,10 +1,6 @@
 import streamlit as st
-from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
 import whisper
 import numpy as np
-import av
-from typing import List
-import queue
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain.memory import ConversationBufferMemory
@@ -16,6 +12,9 @@ from dotenv import load_dotenv
 import requests
 from requests.adapters import HTTPAdapter
 from requests.packages.urllib3.util.retry import Retry
+import time  # Imported time module
+from streamlit_chat import message
+from streamlit_audiorecorder import audiorecorder  # For audio recording
 
 # Load environment variables
 load_dotenv()
@@ -23,8 +22,10 @@ load_dotenv()
 # Initialize session state
 if "messages" not in st.session_state:
     st.session_state.messages = []
-if "audio_buffer" not in st.session_state:
-    st.session_state.audio_buffer = queue.Queue()
+if "audio_data" not in st.session_state:
+    st.session_state.audio_data = None
+if "recording" not in st.session_state:
+    st.session_state.recording = False
 
 # Prompt template
 PROMPT_TEMPLATE = """
@@ -81,7 +82,7 @@ embeddings = HuggingFaceBgeEmbeddings(
 )
 
 vectorstore = FAISS.from_texts(
-    ["Initial therapeutic context"],
+    ["Initial therapeutic context"],
     embeddings
 )
 
@@ -95,18 +96,10 @@ conversation_chain = ConversationalRetrievalChain.from_llm(
     retriever=vectorstore.as_retriever(),
     memory=memory,
     combine_docs_chain_kwargs={"prompt": qa_prompt},
-    return_source_documents=True
+    return_source_documents=True,
+    output_key='answer'  # Specify output_key to fix the error
 )
 
-def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-    return frame
-
-def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
-    if st.session_state.recording:
-        sound = frame.to_ndarray()
-        st.session_state.audio_buffer.put(sound)
-    return frame
-
 def get_ai_response(user_input: str) -> str:
     max_retries = 3
     for attempt in range(max_retries):
@@ -156,42 +149,51 @@ def main():
 
     st.title("Darija AI Therapist 🧠")
     st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
-
-    # WebRTC setup
-    webrtc_ctx = webrtc_streamer(
-        key="speech-to-text",
-        mode=WebRtcMode.SENDONLY,
-        audio_receiver_size=1024,
-        rtc_configuration=RTCConfiguration(
-            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
-        ),
-        video_frame_callback=video_frame_callback,
-        audio_frame_callback=audio_frame_callback,
-        media_stream_constraints={"video": False, "audio": True},
-    )
-
+
     # Chat interface
-    user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
+    # Create columns for text input and mic button
+    col1, col2 = st.columns([9, 1])
+    with col1:
+        user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
+    with col2:
+        # Mic button
+        if st.session_state.recording:
+            mic_label = "🛑"
+        else:
+            mic_label = "🎤"
+
+        if st.button(mic_label):
+            st.session_state.recording = not st.session_state.recording
+            if st.session_state.recording:
+                st.session_state.audio_data = audiorecorder("Click to stop recording")
+            else:
+                audio_data = st.session_state.audio_data
+                if audio_data is not None:
+                    # Convert byte data to numpy array
+                    audio_array = np.frombuffer(audio_data.tobytes(), dtype=np.int16)
+                    # Normalize audio data
+                    audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
+                    # Transcribe audio using Whisper
+                    result = whisper_model.transcribe(audio_array, language="ar")
+                    if result["text"]:
+                        # Put transcribed text into input field
+                        st.session_state.text_input = result["text"]
+                else:
+                    st.error("No audio data recorded.")
+
+    # Handle text submission
     if user_input:
         process_message(user_input)
-
-    # Process audio when recording stops
-    if webrtc_ctx.state.playing and len(st.session_state.audio_buffer) > 0:
-        audio_frames = []
-        while not st.session_state.audio_buffer.empty():
-            audio_frames.append(st.session_state.audio_buffer.get())
-
-        if audio_frames:
-            audio_data = np.concatenate(audio_frames, axis=0)
-            text = whisper_model.transcribe(audio_data)["text"]
-            if text:
-                process_message(text)
-            st.session_state.audio_buffer = queue.Queue()  # Clear buffer
+        st.session_state.text_input = ""  # Clear input field after sending
 
     # Display chat history
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.write(message["content"])
+    for message_data in st.session_state.messages:
+        role = message_data["role"]
+        content = message_data["content"]
+        if role == "user":
+            message(content, is_user=True)
+        else:
+            message(content)
 
 if __name__ == "__main__":
     main()
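
A note on the output_key='answer' addition in the ConversationalRetrievalChain hunk: once return_source_documents=True makes the chain return both an "answer" and a "source_documents" field, LangChain's buffer memory cannot tell which field to store and raises its "one output key expected, got multiple" error unless an output key is specified. Below is a minimal sketch of that pairing, not the app's actual code; llm, vectorstore, and qa_prompt stand in for objects app.py defines outside the visible hunks, and the memory_key="chat_history" / return_messages=True settings are assumptions.

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory


def build_chain(llm, vectorstore, qa_prompt):
    # Assumed memory configuration; app.py's own memory setup is outside the diff.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",  # tell the memory which output field to record
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,  # chain returns "answer" and "source_documents"
        output_key="answer",           # downstream code reads result["answer"]
    )

In many LangChain releases the memory needs the same output_key as the chain for the error to go away, which is why the sketch sets it in both places.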
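
The new microphone path hands Whisper a normalized float32 waveform instead of a file path. Here is a standalone sketch of that conversion and transcription step, assuming raw_bytes holds 16-bit mono PCM at the 16 kHz sample rate Whisper expects; the "base" model name and the helper's name are illustrative, not taken from the commit.

import numpy as np
import whisper

# Illustrative model choice; app.py loads its own whisper_model elsewhere.
model = whisper.load_model("base")


def transcribe_pcm16(raw_bytes: bytes) -> str:
    """Transcribe raw 16-bit mono PCM (assumed 16 kHz) with Whisper."""
    # Reinterpret the byte buffer as int16 samples, then scale to [-1.0, 1.0].
    samples = np.frombuffer(raw_bytes, dtype=np.int16)
    audio = samples.astype(np.float32) / np.iinfo(np.int16).max
    # openai-whisper accepts a float32 array directly and pads/trims it to 30 s;
    # the language hint matches the Arabic/Darija setting used in the diff.
    result = model.transcribe(audio, language="ar")
    return result["text"]

This mirrors the whisper_model.transcribe(audio_array, language="ar") call added in main(), which relies on the same array-input behavior.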