jaafarhh commited on
Commit
e49ad3d
·
verified ·
1 Parent(s): 4dcf57d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -217
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import streamlit as st
2
- import torch
3
- import torchaudio
4
- import soundfile as sf
5
- from pathlib import Path
6
- from transformers import pipeline, AutoTokenizer
 
7
  from langchain_community.llms import HuggingFaceEndpoint
8
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain.memory import ConversationBufferMemory
@@ -12,26 +13,20 @@ from langchain_community.vectorstores import FAISS
12
  from langchain.prompts import PromptTemplate
13
  import os
14
  from dotenv import load_dotenv
 
15
  from requests.adapters import HTTPAdapter
16
  from requests.packages.urllib3.util.retry import Retry
17
- import requests
18
- import time
19
 
20
  # Load environment variables
21
  load_dotenv()
22
 
23
- # CSS styling
24
- css = """
25
- <style>
26
- .chat-message { padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex; }
27
- .chat-message.user { background-color: #2b313e; }
28
- .chat-message.bot { background-color: #475063; }
29
- .avatar { margin-right: 1rem; }
30
- .message { color: white; }
31
- </style>
32
- """
33
 
34
- # Updated prompt template for Mixtral
35
  PROMPT_TEMPLATE = """
36
  <s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
37
  Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
@@ -46,215 +41,157 @@ Context: {context}
46
  [/INST]
47
  """
48
 
49
- class DarijaTherapist:
50
- def __init__(self):
51
- self.setup_models()
52
- self.initialize_session_state()
53
- self.setup_memory()
54
-
55
- def setup_models(self):
56
- try:
57
- # Speech recognition setup
58
- tokenizer = AutoTokenizer.from_pretrained("facebook/seamless-m4t-v2-large")
59
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
60
- self.asr_pipe = pipeline(
61
- "automatic-speech-recognition",
62
- model="facebook/seamless-m4t-v2-large",
63
- tokenizer=tokenizer,
64
- device=self.device
65
- )
66
-
67
- # Configure retry strategy
68
- retry_strategy = Retry(
69
- total=3,
70
- backoff_factor=1,
71
- status_forcelist=[429, 500, 502, 503, 504]
72
- )
73
-
74
- # Create session with retry strategy
75
- session = requests.Session()
76
- session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
77
-
78
- # Updated LLM setup for Mixtral
79
- self.llm = HuggingFaceEndpoint(
80
- endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
81
- task="text-generation",
82
- temperature=0.7,
83
- do_sample=True,
84
- return_full_text=False,
85
- timeout=300,
86
- model_kwargs={
87
- "max_new_tokens": 2048,
88
- "top_p": 0.9,
89
- "repetition_penalty": 1.2,
90
- "return_text": True,
91
- "stop": ["</s>"]
92
- },
93
- huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
94
- client=session
95
- )
96
-
97
- # Embeddings setup
98
- self.embeddings = HuggingFaceBgeEmbeddings(
99
- model_name="BAAI/bge-large-en"
100
- )
101
-
102
- self.vectorstore = FAISS.from_texts(
103
- ["Initial therapeutic context"],
104
- self.embeddings
105
- )
106
- except Exception as e:
107
- st.error(f"Error setting up models: {str(e)}")
108
- st.stop()
109
-
110
- def setup_memory(self):
111
- self.memory = ConversationBufferMemory(
112
- memory_key="chat_history",
113
- return_messages=True
114
- )
115
-
116
- qa_prompt = PromptTemplate(
117
- template=PROMPT_TEMPLATE,
118
- input_variables=["context", "chat_history", "question"]
119
- )
120
-
121
- self.conversation_chain = ConversationalRetrievalChain.from_llm(
122
- llm=self.llm,
123
- retriever=self.vectorstore.as_retriever(),
124
- memory=self.memory,
125
- combine_docs_chain_kwargs={"prompt": qa_prompt},
126
- return_source_documents=True
127
- )
128
-
129
- def initialize_session_state(self):
130
- if "messages" not in st.session_state:
131
- st.session_state.messages = []
132
- if "recording" not in st.session_state:
133
- st.session_state.recording = False
134
- if "audio_buffer" not in st.session_state:
135
- st.session_state.audio_buffer = []
136
-
137
- def handle_audio_input(self):
138
- if not st.session_state.recording:
139
- return
140
-
141
- try:
142
- waveform, sample_rate = torchaudio.load("temp_audio.wav")
143
- st.session_state.audio_buffer.append(waveform)
144
- except Exception as e:
145
- st.error(f"Error recording audio: {str(e)}")
146
-
147
- def process_audio(self):
148
- if not st.session_state.audio_buffer:
149
- return None
150
-
151
  try:
152
- audio_data = torch.cat(st.session_state.audio_buffer, dim=1)
153
- torchaudio.save("temp_audio.wav", audio_data, 16000)
154
- audio, rate = sf.read("temp_audio.wav", dtype='float32')
155
- result = self.asr_pipe(
156
- audio,
157
- generate_kwargs={"task": "transcribe", "language": "ara"}
158
- )
159
- return result["text"]
160
- except Exception as e:
161
- st.error(f"Error processing audio: {str(e)}")
162
- return None
163
- finally:
164
- st.session_state.audio_buffer = []
165
 
166
- def get_ai_response(self, user_input):
167
- max_retries = 3
168
- for attempt in range(max_retries):
169
- try:
170
- # Validate and clean input
171
- if not user_input or len(user_input.strip()) == 0:
172
- return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
173
 
174
- # Limit input length to prevent tensor size issues
175
- if len(user_input) > 512:
176
- user_input = user_input[:512]
177
-
178
- response = self.conversation_chain({
179
- "question": user_input,
180
- "chat_history": self.memory.chat_memory.messages[-5:] # Limit context window
181
- })
182
-
183
- if not response or 'answer' not in response:
184
- if attempt < max_retries - 1:
185
- time.sleep(2 ** attempt)
186
- continue
187
- return "عذراً، كاين مشكل. حاول مرة أخرى."
188
-
189
- return response['answer']
190
-
191
- except requests.exceptions.HTTPError as e:
192
- if e.response.status_code == 424:
193
- if attempt < max_retries - 1:
194
- st.warning("Model error, retrying with simplified input...")
195
- time.sleep(2 ** attempt)
196
- continue
197
- return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
198
-
199
- except requests.exceptions.ReadTimeout:
200
- if attempt < max_retries - 1:
201
- st.warning(f"Attempt {attempt + 1} timed out, retrying...")
202
- time.sleep(2 ** attempt)
203
- continue
204
- return "عذراً، الخادم بطيء حالياً. حاول مرة أخرى."
205
 
206
- except Exception as e:
207
- st.error(f"Error: {str(e)}")
208
  if attempt < max_retries - 1:
209
  time.sleep(2 ** attempt)
210
  continue
211
- return "عذراً، كاين شي مشكل. حاول مرة أخرى."
212
-
213
- def run(self):
214
- st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
215
- st.markdown(css, unsafe_allow_html=True)
216
-
217
- st.title("Darija AI Therapist 🧠")
218
- st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
219
-
220
- with st.sidebar:
221
- st.header("Settings ⚙️")
222
- if st.button("Clear Chat History"):
223
- st.session_state.messages = []
224
- self.memory.clear()
225
 
226
- st.markdown("### About")
227
- st.info("This AI therapist speaks Darija and is here to help.")
228
 
229
- col1, col2 = st.columns(2)
230
- with col1:
231
- if st.button("🎤 Start Recording", disabled=st.session_state.recording):
232
- st.session_state.recording = True
233
- st.session_state.audio_buffer = []
234
 
235
- with col2:
236
- if st.button("⏹️ Stop Recording", disabled=not st.session_state.recording):
237
- st.session_state.recording = False
238
- transcription = self.process_audio()
239
- if transcription:
240
- self.process_message(transcription)
241
-
242
- user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
243
- if user_input:
244
- self.process_message(user_input)
245
-
246
- for message in st.session_state.messages:
247
- with st.chat_message(message["role"]):
248
- st.write(message["content"])
249
-
250
- def process_message(self, user_input):
251
- st.session_state.messages.append({"role": "user", "content": user_input})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- with st.spinner("جاري التفكير..."):
254
- ai_response = self.get_ai_response(user_input)
255
- if ai_response:
256
- st.session_state.messages.append({"role": "assistant", "content": ai_response})
 
 
 
 
 
 
 
257
 
258
  if __name__ == "__main__":
259
- app = DarijaTherapist()
260
- app.run()
 
1
  import streamlit as st
2
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
3
+ import whisper
4
+ import numpy as np
5
+ import av
6
+ from typing import List
7
+ import queue
8
  from langchain_community.llms import HuggingFaceEndpoint
9
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
10
  from langchain.memory import ConversationBufferMemory
 
13
  from langchain.prompts import PromptTemplate
14
  import os
15
  from dotenv import load_dotenv
16
+ import requests
17
  from requests.adapters import HTTPAdapter
18
  from requests.packages.urllib3.util.retry import Retry
 
 
19
 
20
  # Load environment variables
21
  load_dotenv()
22
 
23
+ # Initialize session state
24
+ if "messages" not in st.session_state:
25
+ st.session_state.messages = []
26
+ if "audio_buffer" not in st.session_state:
27
+ st.session_state.audio_buffer = queue.Queue()
 
 
 
 
 
28
 
29
+ # Prompt template
30
  PROMPT_TEMPLATE = """
31
  <s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
32
  Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
 
41
  [/INST]
42
  """
43
 
44
+ # Setup retry strategy
45
+ retry_strategy = Retry(
46
+ total=3,
47
+ backoff_factor=1,
48
+ status_forcelist=[429, 500, 502, 503, 504]
49
+ )
50
+
51
+ session = requests.Session()
52
+ session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
53
+
54
+ # Initialize models
55
+ whisper_model = whisper.load_model("base")
56
+ llm = HuggingFaceEndpoint(
57
+ endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
58
+ task="text-generation",
59
+ temperature=0.7,
60
+ do_sample=True,
61
+ return_full_text=False,
62
+ max_new_tokens=2048,
63
+ top_p=0.9,
64
+ repetition_penalty=1.2,
65
+ model_kwargs={
66
+ "return_text": True,
67
+ "stop": ["</s>"]
68
+ },
69
+ huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
70
+ client=session
71
+ )
72
+
73
+ # Setup memory and conversation chain
74
+ memory = ConversationBufferMemory(
75
+ memory_key="chat_history",
76
+ return_messages=True
77
+ )
78
+
79
+ embeddings = HuggingFaceBgeEmbeddings(
80
+ model_name="BAAI/bge-large-en"
81
+ )
82
+
83
+ vectorstore = FAISS.from_texts(
84
+ ["Initial therapeutic context"],
85
+ embeddings
86
+ )
87
+
88
+ qa_prompt = PromptTemplate(
89
+ template=PROMPT_TEMPLATE,
90
+ input_variables=["context", "chat_history", "question"]
91
+ )
92
+
93
+ conversation_chain = ConversationalRetrievalChain.from_llm(
94
+ llm=llm,
95
+ retriever=vectorstore.as_retriever(),
96
+ memory=memory,
97
+ combine_docs_chain_kwargs={"prompt": qa_prompt},
98
+ return_source_documents=True
99
+ )
100
+
101
+ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
102
+ return frame
103
+
104
+ def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
105
+ if st.session_state.recording:
106
+ sound = frame.to_ndarray()
107
+ st.session_state.audio_buffer.put(sound)
108
+ return frame
109
+
110
+ def get_ai_response(user_input: str) -> str:
111
+ max_retries = 3
112
+ for attempt in range(max_retries):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  try:
114
+ if not user_input or len(user_input.strip()) == 0:
115
+ return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ if len(user_input) > 512:
118
+ user_input = user_input[:512]
 
 
 
 
 
119
 
120
+ response = conversation_chain({
121
+ "question": user_input,
122
+ "chat_history": memory.chat_memory.messages[-5:]
123
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ if not response or 'answer' not in response:
 
126
  if attempt < max_retries - 1:
127
  time.sleep(2 ** attempt)
128
  continue
129
+ return "عذراً، كاين مشكل. حاول مرة أخرى."
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ return response['answer']
 
132
 
133
+ except requests.exceptions.HTTPError as e:
134
+ if attempt < max_retries - 1:
135
+ time.sleep(2 ** attempt)
136
+ continue
137
+ return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
138
 
139
+ except Exception as e:
140
+ st.error(f"Error: {str(e)}")
141
+ if attempt < max_retries - 1:
142
+ time.sleep(2 ** attempt)
143
+ continue
144
+ return "عذراً، كاين شي مشكل. حاول مرة أخرى."
145
+
146
+ def process_message(user_input: str) -> None:
147
+ st.session_state.messages.append({"role": "user", "content": user_input})
148
+
149
+ with st.spinner("جاري التفكير..."):
150
+ ai_response = get_ai_response(user_input)
151
+ if ai_response:
152
+ st.session_state.messages.append({"role": "assistant", "content": ai_response})
153
+
154
+ def main():
155
+ st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
156
+
157
+ st.title("Darija AI Therapist 🧠")
158
+ st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
159
+
160
+ # WebRTC setup
161
+ webrtc_ctx = webrtc_streamer(
162
+ key="speech-to-text",
163
+ mode=WebRtcMode.SENDONLY,
164
+ audio_receiver_size=1024,
165
+ rtc_configuration=RTCConfiguration(
166
+ {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
167
+ ),
168
+ video_frame_callback=video_frame_callback,
169
+ audio_frame_callback=audio_frame_callback,
170
+ media_stream_constraints={"video": False, "audio": True},
171
+ )
172
+
173
+ # Chat interface
174
+ user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
175
+ if user_input:
176
+ process_message(user_input)
177
+
178
+ # Process audio when recording stops
179
+ if webrtc_ctx.state.playing and len(st.session_state.audio_buffer) > 0:
180
+ audio_frames = []
181
+ while not st.session_state.audio_buffer.empty():
182
+ audio_frames.append(st.session_state.audio_buffer.get())
183
 
184
+ if audio_frames:
185
+ audio_data = np.concatenate(audio_frames, axis=0)
186
+ text = whisper_model.transcribe(audio_data)["text"]
187
+ if text:
188
+ process_message(text)
189
+ st.session_state.audio_buffer = queue.Queue() # Clear buffer
190
+
191
+ # Display chat history
192
+ for message in st.session_state.messages:
193
+ with st.chat_message(message["role"]):
194
+ st.write(message["content"])
195
 
196
  if __name__ == "__main__":
197
+ main()