Update app.py
app.py
CHANGED
@@ -1,10 +1,6 @@
 import streamlit as st
-from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
 import whisper
 import numpy as np
-import av
-from typing import List
-import queue
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain.memory import ConversationBufferMemory
@@ -16,6 +12,9 @@ from dotenv import load_dotenv
 import requests
 from requests.adapters import HTTPAdapter
 from requests.packages.urllib3.util.retry import Retry
+import time  # Imported time module
+from streamlit_chat import message
+from streamlit_audiorecorder import audiorecorder  # For audio recording
 
 # Load environment variables
 load_dotenv()
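The unchanged requests / HTTPAdapter / Retry imports in this hunk are the standard ingredients for giving the outbound Hugging Face endpoint calls automatic retries. A minimal sketch of that pattern, with the retry count, backoff, and status codes chosen for illustration rather than taken from this file:

import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

def make_retrying_session() -> requests.Session:
    # Retry transient 5xx failures with exponential backoff before giving up
    retry = Retry(total=3, backoff_factor=1, status_forcelist=[502, 503, 504])
    session = requests.Session()
    session.mount("https://", HTTPAdapter(max_retries=retry))
    return session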
@@ -23,8 +22,10 @@ load_dotenv()
 # Initialize session state
 if "messages" not in st.session_state:
     st.session_state.messages = []
-if "audio_buffer" not in st.session_state:
-    st.session_state.audio_buffer = queue.Queue()
+if "audio_data" not in st.session_state:
+    st.session_state.audio_data = None
+if "recording" not in st.session_state:
+    st.session_state.recording = False
 
 # Prompt template
 PROMPT_TEMPLATE = """
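These new session-state keys matter because Streamlit reruns the whole script on every interaction; only values parked in st.session_state survive a rerun, which is what lets the mic button toggle between recording and idle. A stripped-down sketch of that toggle (the labels and wording are illustrative, not the app's exact widgets):

import streamlit as st

if "recording" not in st.session_state:
    st.session_state.recording = False

# A click triggers a rerun; the flag flipped here is still set on the next run
if st.button("🛑" if st.session_state.recording else "🎤"):
    st.session_state.recording = not st.session_state.recording

st.write("Recording…" if st.session_state.recording else "Idle")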
@@ -81,7 +82,7 @@ embeddings = HuggingFaceBgeEmbeddings(
 )
 
 vectorstore = FAISS.from_texts(
-    ["Initial therapeutic context"],
+    ["Initial therapeutic context"],
     embeddings
 )
 
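FAISS.from_texts embeds each string with the configured embeddings object and builds an in-memory index from the results, so seeding it with one placeholder document is enough to get a working retriever that can grow later. A small, self-contained illustration (the embedding model name, added text, and query are made up for the example, not taken from this app):

from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")  # illustrative model
vectorstore = FAISS.from_texts(["Initial therapeutic context"], embeddings)
vectorstore.add_texts(["The user prefers answers in Moroccan Darija"])  # illustrative extra context
docs = vectorstore.similarity_search("Darija", k=1)  # returns the closest stored Document(s)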
@@ -95,18 +96,10 @@ conversation_chain = ConversationalRetrievalChain.from_llm(
     retriever=vectorstore.as_retriever(),
     memory=memory,
     combine_docs_chain_kwargs={"prompt": qa_prompt},
-    return_source_documents=True
+    return_source_documents=True,
+    output_key='answer'  # Specify output_key to fix the error
 )
 
-def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-    return frame
-
-def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
-    if st.session_state.recording:
-        sound = frame.to_ndarray()
-        st.session_state.audio_buffer.put(sound)
-    return frame
-
 def get_ai_response(user_input: str) -> str:
     max_retries = 3
     for attempt in range(max_retries):
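The reason output_key shows up here: with return_source_documents=True the chain returns a dict holding both "answer" and "source_documents", and LangChain's buffer memory cannot tell which entry to record, typically failing with a "one output key expected" ValueError; pinning output_key='answer' is the usual way to disambiguate. A sketch of what a call looks like under this configuration, reusing the file's conversation_chain (the question text is only an example):

result = conversation_chain({"question": "شنو هو القلق؟"})
answer = result["answer"]             # the generated reply, the part stored in memory
sources = result["source_documents"]  # list of retrieved Documents backing the reply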
@@ -156,42 +149,51 @@ def main():
 
     st.title("Darija AI Therapist 🧠")
     st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
-
-    # WebRTC setup
-    webrtc_ctx = webrtc_streamer(
-        key="speech-to-text",
-        mode=WebRtcMode.SENDONLY,
-        audio_receiver_size=1024,
-        rtc_configuration=RTCConfiguration(
-            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
-        ),
-        video_frame_callback=video_frame_callback,
-        audio_frame_callback=audio_frame_callback,
-        media_stream_constraints={"video": False, "audio": True},
-    )
-
+
     # Chat interface
-
+    # Create columns for text input and mic button
+    col1, col2 = st.columns([9, 1])
+    with col1:
+        user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
+    with col2:
+        # Mic button
+        if st.session_state.recording:
+            mic_label = "🛑"
+        else:
+            mic_label = "🎤"
+
+        if st.button(mic_label):
+            st.session_state.recording = not st.session_state.recording
+            if st.session_state.recording:
+                st.session_state.audio_data = audiorecorder("Click to stop recording")
+            else:
+                audio_data = st.session_state.audio_data
+                if audio_data is not None:
+                    # Convert byte data to numpy array
+                    audio_array = np.frombuffer(audio_data.tobytes(), dtype=np.int16)
+                    # Normalize audio data
+                    audio_array = audio_array.astype(np.float32) / np.iinfo(np.int16).max
+                    # Transcribe audio using Whisper
+                    result = whisper_model.transcribe(audio_array, language="ar")
+                    if result["text"]:
+                        # Put transcribed text into input field
+                        st.session_state.text_input = result["text"]
+                else:
+                    st.error("No audio data recorded.")
+
+    # Handle text submission
     if user_input:
         process_message(user_input)
-
-    # Process audio when recording stops
-    if webrtc_ctx.state.playing and len(st.session_state.audio_buffer) > 0:
-        audio_frames = []
-        while not st.session_state.audio_buffer.empty():
-            audio_frames.append(st.session_state.audio_buffer.get())
-
-        if audio_frames:
-            audio_data = np.concatenate(audio_frames, axis=0)
-            text = whisper_model.transcribe(audio_data)["text"]
-            if text:
-                process_message(text)
-            st.session_state.audio_buffer = queue.Queue()  # Clear buffer
+        st.session_state.text_input = ""  # Clear input field after sending
 
     # Display chat history
-    for
-
-
+    for message_data in st.session_state.messages:
+        role = message_data["role"]
+        content = message_data["content"]
+        if role == "user":
+            message(content, is_user=True)
+        else:
+            message(content)
 
 if __name__ == "__main__":
     main()
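Two pieces of the new main() are worth unpacking. First, Whisper's transcribe accepts a raw NumPy array as long as it is float32 PCM scaled to [-1, 1] (Whisper assumes 16 kHz mono for raw arrays), which is why the recorded int16 samples are divided by np.iinfo(np.int16).max before transcription. A self-contained sketch of that conversion, with the model size picked only for illustration:

import numpy as np
import whisper

model = whisper.load_model("base")  # model size is illustrative

def transcribe_pcm16(pcm_bytes: bytes) -> str:
    # 16-bit PCM -> float32 in [-1.0, 1.0], the format Whisper expects (16 kHz mono)
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    audio = samples.astype(np.float32) / np.iinfo(np.int16).max
    return model.transcribe(audio, language="ar")["text"]

Second, the chat history is now drawn with streamlit_chat's message component. A minimal rendering loop in the same spirit (the key argument is an assumption added here so repeated bubbles do not collide, not something this commit sets):

import streamlit as st
from streamlit_chat import message

for i, msg in enumerate(st.session_state.get("messages", [])):
    # User turns render as user bubbles, everything else as assistant bubbles
    message(msg["content"], is_user=(msg["role"] == "user"), key=f"msg_{i}")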