Update app.py
app.py CHANGED
@@ -1,22 +1,22 @@
 import streamlit as st
-import
-
+from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
+from typing import List
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
 from langchain_community.vectorstores import FAISS
-from
-from langchain_core.memory import BaseMemory
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnablePassthrough
-from langchain.chains import ConversationChain
+from langchain.prompts import PromptTemplate
 import os
 from dotenv import load_dotenv
 import requests
 from requests.adapters import HTTPAdapter
 from requests.packages.urllib3.util.retry import Retry
-import
-
-
+import whisper
+import numpy as np
+import av
+import time # Added import time
+import queue
 
 # Load environment variables
 load_dotenv()
@@ -24,12 +24,15 @@ load_dotenv()
 # Initialize session state
 if "messages" not in st.session_state:
     st.session_state.messages = []
-
-
-
+
+if "audio_buffer" not in st.session_state:
+    st.session_state.audio_buffer = queue.Queue()
+
+if 'recording' not in st.session_state:
     st.session_state.recording = False
-
-
+
+if 'webrtc_ctx' not in st.session_state:
+    st.session_state.webrtc_ctx = None
 
 # Prompt template
 PROMPT_TEMPLATE = """
@@ -76,7 +79,7 @@ llm = HuggingFaceEndpoint(
 )
 
 # Setup memory and conversation chain
-memory =
+memory = ConversationBufferMemory(
     memory_key="chat_history",
     return_messages=True
 )
@@ -95,24 +98,19 @@ qa_prompt = PromptTemplate(
     input_variables=["context", "chat_history", "question"]
 )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-
-    return chain
-
-conversation_chain = create_chain()
+conversation_chain = ConversationalRetrievalChain.from_llm(
+    llm=llm,
+    retriever=vectorstore.as_retriever(),
+    memory=memory,
+    combine_docs_chain_kwargs={"prompt": qa_prompt},
+    return_source_documents=True,
+    output_key='answer' # Specify output_key to fix the error
+)
+
+def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
+    audio = frame.to_ndarray().flatten()
+    st.session_state.audio_buffer.put(audio)
+    return frame
 
 def get_ai_response(user_input: str) -> str:
     max_retries = 3
@@ -128,13 +126,13 @@ def get_ai_response(user_input: str) -> str:
                 "question": user_input,
                 "chat_history": memory.chat_memory.messages[-5:]
             })
-
+
             if not response or 'answer' not in response:
                 if attempt < max_retries - 1:
                     time.sleep(2 ** attempt)
                     continue
                 return "عذراً، كاين مشكل. حاول مرة أخرى."
-
+
             return response['answer']
 
         except requests.exceptions.HTTPError as e:
@@ -142,7 +140,7 @@ def get_ai_response(user_input: str) -> str:
                 time.sleep(2 ** attempt)
                 continue
             return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
-
+
         except Exception as e:
             st.error(f"Error: {str(e)}")
             if attempt < max_retries - 1:
@@ -152,7 +150,7 @@ def get_ai_response(user_input: str) -> str:
 
 def process_message(user_input: str) -> None:
     st.session_state.messages.append({"role": "user", "content": user_input})
-
+
     with st.spinner("جاري التفكير..."):
        ai_response = get_ai_response(user_input)
        if ai_response:
@@ -160,54 +158,70 @@ def process_message(user_input: str) -> None:
 
 def main():
     st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
-
+
     st.title("Darija AI Therapist 🧠")
     st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
-
-    # Chat interface
-    # Create columns for text input and mic button
+
     col1, col2 = st.columns([9, 1])
     with col1:
         user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
     with col2:
-        # Mic button
         if st.session_state.recording:
-
+            mic_icon = "🛑"
         else:
-
-
-        if st.button(
+            mic_icon = "🎤"
+
+        if st.button(mic_icon):
             st.session_state.recording = not st.session_state.recording
             if st.session_state.recording:
-                st.session_state.
+                st.session_state.audio_buffer = queue.Queue()
+                st.session_state.webrtc_ctx = webrtc_streamer(
+                    key="speech-to-text",
+                    mode=WebRtcMode.SENDONLY,
+                    audio_receiver_size=256,
+                    rtc_configuration=RTCConfiguration(
+                        {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+                    ),
+                    media_stream_constraints={"video": False, "audio": True},
+                    async_processing=True,
+                    audio_frame_callback=audio_frame_callback,
+                )
             else:
-
-
-
-
-
-
-
-
-
-
-
+                st.info("🔄 Processing audio...")
+                audio_frames = []
+                while not st.session_state.audio_buffer.empty():
+                    audio_frames.append(st.session_state.audio_buffer.get())
+
+                if audio_frames:
+                    audio_data = np.concatenate(audio_frames, axis=0).flatten()
+                    # Convert to 16-bit integers
+                    audio_data_int16 = (audio_data * 32767).astype(np.int16)
+                    # Use Whisper to transcribe
+                    result = whisper_model.transcribe(audio_data_int16, fp16=False)
+                    text = result.get("text", "")
+                    if text:
+                        process_message(text)
+                    else:
+                        st.warning("ما فهمتش الصوت. حاول مرة أخرى.")
+                    st.session_state.audio_buffer = queue.Queue()
                else:
-                    st.
+                    st.warning("ما تسجلش الصوت. حاول مرة أخرى.")
+                    if st.session_state.webrtc_ctx:
+                        st.session_state.webrtc_ctx.stop()
+                        st.session_state.webrtc_ctx = None
+
+        if st.session_state.recording:
+            st.info("🎙️ Recording...")
+        else:
+            st.empty()
 
-    # Handle text submission
     if user_input:
         process_message(user_input)
-        st.session_state.text_input = "" # Clear input field after sending
 
     # Display chat history
-    for
-
-
-        if role == "user":
-            message(content, is_user=True)
-        else:
-            message(content)
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.write(message["content"])
 
 if __name__ == "__main__":
     main()
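Note on the output_key='answer' argument added to ConversationalRetrievalChain.from_llm: with return_source_documents=True the chain returns two output keys ('answer' and 'source_documents'), and ConversationBufferMemory raises a ValueError because it cannot tell which output to save. Passing output_key='answer' resolves the ambiguity, which matches the inline comment in the diff. An equivalent fix in many LangChain versions is to set the key on the memory itself; a minimal sketch:

    from langchain.memory import ConversationBufferMemory

    # Hedged alternative: tell the memory which chain output to store,
    # so 'source_documents' is ignored when history is written.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )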
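One gap in the new code: whisper_model is called in main() but never defined anywhere in this diff, so the transcription branch will raise a NameError at runtime. A minimal sketch of the missing setup, assuming the openai-whisper package (imported above as whisper); load_whisper_model is a hypothetical helper, cached so Streamlit reruns do not reload the weights:

    import streamlit as st
    import whisper

    @st.cache_resource  # requires Streamlit >= 1.18
    def load_whisper_model(name: str = "base"):
        # Downloads the checkpoint on first call, then reuses the cached model.
        return whisper.load_model(name)

    whisper_model = load_whisper_model()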
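A related caveat: whisper_model.transcribe() expects mono float32 audio in [-1.0, 1.0] sampled at 16 kHz, while the diff passes int16 samples (audio_data * 32767) and browser WebRTC audio typically arrives as 48 kHz int16 frames. A sketch of a safer conversion under those assumptions; to_whisper_input is a hypothetical helper, and the stride-based decimation is a crude stand-in for a proper resampler such as librosa.resample:

    import numpy as np

    def to_whisper_input(frames: list, src_rate: int = 48000) -> np.ndarray:
        # int16 PCM -> float32 in [-1.0, 1.0]
        audio = np.concatenate(frames).astype(np.float32) / 32768.0
        # Crude 48 kHz -> 16 kHz decimation: keep every third sample.
        return audio[:: src_rate // 16000]

    # Usage inside the processing branch, replacing the int16 conversion:
    #     result = whisper_model.transcribe(to_whisper_input(audio_frames), fp16=False)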