Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
-
import
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
from
|
|
|
7 |
from langchain_community.llms import HuggingFaceEndpoint
|
8 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
9 |
from langchain.memory import ConversationBufferMemory
|
@@ -12,26 +13,20 @@ from langchain_community.vectorstores import FAISS
|
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
import os
|
14 |
from dotenv import load_dotenv
|
|
|
15 |
from requests.adapters import HTTPAdapter
|
16 |
from requests.packages.urllib3.util.retry import Retry
|
17 |
-
import requests
|
18 |
-
import time
|
19 |
|
20 |
# Load environment variables
|
21 |
load_dotenv()
|
22 |
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
.
|
28 |
-
.chat-message.bot { background-color: #475063; }
|
29 |
-
.avatar { margin-right: 1rem; }
|
30 |
-
.message { color: white; }
|
31 |
-
</style>
|
32 |
-
"""
|
33 |
|
34 |
-
#
|
35 |
PROMPT_TEMPLATE = """
|
36 |
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
|
37 |
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
|
@@ -46,215 +41,157 @@ Context: {context}
|
|
46 |
[/INST]
|
47 |
"""
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
input_variables=["context", "chat_history", "question"]
|
119 |
-
)
|
120 |
-
|
121 |
-
self.conversation_chain = ConversationalRetrievalChain.from_llm(
|
122 |
-
llm=self.llm,
|
123 |
-
retriever=self.vectorstore.as_retriever(),
|
124 |
-
memory=self.memory,
|
125 |
-
combine_docs_chain_kwargs={"prompt": qa_prompt},
|
126 |
-
return_source_documents=True
|
127 |
-
)
|
128 |
-
|
129 |
-
def initialize_session_state(self):
|
130 |
-
if "messages" not in st.session_state:
|
131 |
-
st.session_state.messages = []
|
132 |
-
if "recording" not in st.session_state:
|
133 |
-
st.session_state.recording = False
|
134 |
-
if "audio_buffer" not in st.session_state:
|
135 |
-
st.session_state.audio_buffer = []
|
136 |
-
|
137 |
-
def handle_audio_input(self):
|
138 |
-
if not st.session_state.recording:
|
139 |
-
return
|
140 |
-
|
141 |
-
try:
|
142 |
-
waveform, sample_rate = torchaudio.load("temp_audio.wav")
|
143 |
-
st.session_state.audio_buffer.append(waveform)
|
144 |
-
except Exception as e:
|
145 |
-
st.error(f"Error recording audio: {str(e)}")
|
146 |
-
|
147 |
-
def process_audio(self):
|
148 |
-
if not st.session_state.audio_buffer:
|
149 |
-
return None
|
150 |
-
|
151 |
try:
|
152 |
-
|
153 |
-
|
154 |
-
audio, rate = sf.read("temp_audio.wav", dtype='float32')
|
155 |
-
result = self.asr_pipe(
|
156 |
-
audio,
|
157 |
-
generate_kwargs={"task": "transcribe", "language": "ara"}
|
158 |
-
)
|
159 |
-
return result["text"]
|
160 |
-
except Exception as e:
|
161 |
-
st.error(f"Error processing audio: {str(e)}")
|
162 |
-
return None
|
163 |
-
finally:
|
164 |
-
st.session_state.audio_buffer = []
|
165 |
|
166 |
-
|
167 |
-
|
168 |
-
for attempt in range(max_retries):
|
169 |
-
try:
|
170 |
-
# Validate and clean input
|
171 |
-
if not user_input or len(user_input.strip()) == 0:
|
172 |
-
return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
response = self.conversation_chain({
|
179 |
-
"question": user_input,
|
180 |
-
"chat_history": self.memory.chat_memory.messages[-5:] # Limit context window
|
181 |
-
})
|
182 |
-
|
183 |
-
if not response or 'answer' not in response:
|
184 |
-
if attempt < max_retries - 1:
|
185 |
-
time.sleep(2 ** attempt)
|
186 |
-
continue
|
187 |
-
return "عذراً، كاين مشكل. حاول مرة أخرى."
|
188 |
-
|
189 |
-
return response['answer']
|
190 |
-
|
191 |
-
except requests.exceptions.HTTPError as e:
|
192 |
-
if e.response.status_code == 424:
|
193 |
-
if attempt < max_retries - 1:
|
194 |
-
st.warning("Model error, retrying with simplified input...")
|
195 |
-
time.sleep(2 ** attempt)
|
196 |
-
continue
|
197 |
-
return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
|
198 |
-
|
199 |
-
except requests.exceptions.ReadTimeout:
|
200 |
-
if attempt < max_retries - 1:
|
201 |
-
st.warning(f"Attempt {attempt + 1} timed out, retrying...")
|
202 |
-
time.sleep(2 ** attempt)
|
203 |
-
continue
|
204 |
-
return "عذراً، الخادم بطيء حالياً. حاول مرة أخرى."
|
205 |
|
206 |
-
|
207 |
-
st.error(f"Error: {str(e)}")
|
208 |
if attempt < max_retries - 1:
|
209 |
time.sleep(2 ** attempt)
|
210 |
continue
|
211 |
-
return "عذراً، كاين
|
212 |
-
|
213 |
-
def run(self):
|
214 |
-
st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
|
215 |
-
st.markdown(css, unsafe_allow_html=True)
|
216 |
-
|
217 |
-
st.title("Darija AI Therapist 🧠")
|
218 |
-
st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
|
219 |
-
|
220 |
-
with st.sidebar:
|
221 |
-
st.header("Settings ⚙️")
|
222 |
-
if st.button("Clear Chat History"):
|
223 |
-
st.session_state.messages = []
|
224 |
-
self.memory.clear()
|
225 |
|
226 |
-
|
227 |
-
st.info("This AI therapist speaks Darija and is here to help.")
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
if __name__ == "__main__":
|
259 |
-
|
260 |
-
app.run()
|
|
|
1 |
import streamlit as st
|
2 |
+
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
|
3 |
+
import whisper
|
4 |
+
import numpy as np
|
5 |
+
import av
|
6 |
+
from typing import List
|
7 |
+
import queue
|
8 |
from langchain_community.llms import HuggingFaceEndpoint
|
9 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
10 |
from langchain.memory import ConversationBufferMemory
|
|
|
13 |
from langchain.prompts import PromptTemplate
|
14 |
import os
|
15 |
from dotenv import load_dotenv
|
16 |
+
import requests
|
17 |
from requests.adapters import HTTPAdapter
|
18 |
from requests.packages.urllib3.util.retry import Retry
|
|
|
|
|
19 |
|
20 |
# Load environment variables
|
21 |
load_dotenv()
|
22 |
|
23 |
+
# Initialize session state
|
24 |
+
if "messages" not in st.session_state:
|
25 |
+
st.session_state.messages = []
|
26 |
+
if "audio_buffer" not in st.session_state:
|
27 |
+
st.session_state.audio_buffer = queue.Queue()
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
# Prompt template
|
30 |
PROMPT_TEMPLATE = """
|
31 |
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
|
32 |
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
|
|
|
41 |
[/INST]
|
42 |
"""
|
43 |
|
44 |
+
# Setup retry strategy
|
45 |
+
retry_strategy = Retry(
|
46 |
+
total=3,
|
47 |
+
backoff_factor=1,
|
48 |
+
status_forcelist=[429, 500, 502, 503, 504]
|
49 |
+
)
|
50 |
+
|
51 |
+
session = requests.Session()
|
52 |
+
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
|
53 |
+
|
54 |
+
# Initialize models
|
55 |
+
whisper_model = whisper.load_model("base")
|
56 |
+
llm = HuggingFaceEndpoint(
|
57 |
+
endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
|
58 |
+
task="text-generation",
|
59 |
+
temperature=0.7,
|
60 |
+
do_sample=True,
|
61 |
+
return_full_text=False,
|
62 |
+
max_new_tokens=2048,
|
63 |
+
top_p=0.9,
|
64 |
+
repetition_penalty=1.2,
|
65 |
+
model_kwargs={
|
66 |
+
"return_text": True,
|
67 |
+
"stop": ["</s>"]
|
68 |
+
},
|
69 |
+
huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
|
70 |
+
client=session
|
71 |
+
)
|
72 |
+
|
73 |
+
# Setup memory and conversation chain
|
74 |
+
memory = ConversationBufferMemory(
|
75 |
+
memory_key="chat_history",
|
76 |
+
return_messages=True
|
77 |
+
)
|
78 |
+
|
79 |
+
embeddings = HuggingFaceBgeEmbeddings(
|
80 |
+
model_name="BAAI/bge-large-en"
|
81 |
+
)
|
82 |
+
|
83 |
+
vectorstore = FAISS.from_texts(
|
84 |
+
["Initial therapeutic context"],
|
85 |
+
embeddings
|
86 |
+
)
|
87 |
+
|
88 |
+
qa_prompt = PromptTemplate(
|
89 |
+
template=PROMPT_TEMPLATE,
|
90 |
+
input_variables=["context", "chat_history", "question"]
|
91 |
+
)
|
92 |
+
|
93 |
+
conversation_chain = ConversationalRetrievalChain.from_llm(
|
94 |
+
llm=llm,
|
95 |
+
retriever=vectorstore.as_retriever(),
|
96 |
+
memory=memory,
|
97 |
+
combine_docs_chain_kwargs={"prompt": qa_prompt},
|
98 |
+
return_source_documents=True
|
99 |
+
)
|
100 |
+
|
101 |
+
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
|
102 |
+
return frame
|
103 |
+
|
104 |
+
def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
|
105 |
+
if st.session_state.recording:
|
106 |
+
sound = frame.to_ndarray()
|
107 |
+
st.session_state.audio_buffer.put(sound)
|
108 |
+
return frame
|
109 |
+
|
110 |
+
def get_ai_response(user_input: str) -> str:
|
111 |
+
max_retries = 3
|
112 |
+
for attempt in range(max_retries):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
try:
|
114 |
+
if not user_input or len(user_input.strip()) == 0:
|
115 |
+
return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
+
if len(user_input) > 512:
|
118 |
+
user_input = user_input[:512]
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
+
response = conversation_chain({
|
121 |
+
"question": user_input,
|
122 |
+
"chat_history": memory.chat_memory.messages[-5:]
|
123 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
+
if not response or 'answer' not in response:
|
|
|
126 |
if attempt < max_retries - 1:
|
127 |
time.sleep(2 ** attempt)
|
128 |
continue
|
129 |
+
return "عذراً، كاين مشكل. حاول مرة أخرى."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
+
return response['answer']
|
|
|
132 |
|
133 |
+
except requests.exceptions.HTTPError as e:
|
134 |
+
if attempt < max_retries - 1:
|
135 |
+
time.sleep(2 ** attempt)
|
136 |
+
continue
|
137 |
+
return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
|
138 |
|
139 |
+
except Exception as e:
|
140 |
+
st.error(f"Error: {str(e)}")
|
141 |
+
if attempt < max_retries - 1:
|
142 |
+
time.sleep(2 ** attempt)
|
143 |
+
continue
|
144 |
+
return "عذراً، كاين شي مشكل. حاول مرة أخرى."
|
145 |
+
|
146 |
+
def process_message(user_input: str) -> None:
|
147 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
148 |
+
|
149 |
+
with st.spinner("جاري التفكير..."):
|
150 |
+
ai_response = get_ai_response(user_input)
|
151 |
+
if ai_response:
|
152 |
+
st.session_state.messages.append({"role": "assistant", "content": ai_response})
|
153 |
+
|
154 |
+
def main():
|
155 |
+
st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
|
156 |
+
|
157 |
+
st.title("Darija AI Therapist 🧠")
|
158 |
+
st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
|
159 |
+
|
160 |
+
# WebRTC setup
|
161 |
+
webrtc_ctx = webrtc_streamer(
|
162 |
+
key="speech-to-text",
|
163 |
+
mode=WebRtcMode.SENDONLY,
|
164 |
+
audio_receiver_size=1024,
|
165 |
+
rtc_configuration=RTCConfiguration(
|
166 |
+
{"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
|
167 |
+
),
|
168 |
+
video_frame_callback=video_frame_callback,
|
169 |
+
audio_frame_callback=audio_frame_callback,
|
170 |
+
media_stream_constraints={"video": False, "audio": True},
|
171 |
+
)
|
172 |
+
|
173 |
+
# Chat interface
|
174 |
+
user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
|
175 |
+
if user_input:
|
176 |
+
process_message(user_input)
|
177 |
+
|
178 |
+
# Process audio when recording stops
|
179 |
+
if webrtc_ctx.state.playing and len(st.session_state.audio_buffer) > 0:
|
180 |
+
audio_frames = []
|
181 |
+
while not st.session_state.audio_buffer.empty():
|
182 |
+
audio_frames.append(st.session_state.audio_buffer.get())
|
183 |
|
184 |
+
if audio_frames:
|
185 |
+
audio_data = np.concatenate(audio_frames, axis=0)
|
186 |
+
text = whisper_model.transcribe(audio_data)["text"]
|
187 |
+
if text:
|
188 |
+
process_message(text)
|
189 |
+
st.session_state.audio_buffer = queue.Queue() # Clear buffer
|
190 |
+
|
191 |
+
# Display chat history
|
192 |
+
for message in st.session_state.messages:
|
193 |
+
with st.chat_message(message["role"]):
|
194 |
+
st.write(message["content"])
|
195 |
|
196 |
if __name__ == "__main__":
|
197 |
+
main()
|
|