jaafarhh commited on
Commit
cb6b5d0
·
verified ·
1 Parent(s): cdb5aa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -94
app.py CHANGED
@@ -1,109 +1,188 @@
1
  import streamlit as st
2
  import torch
3
- from audio_handler import AudioHandler
4
- from chat_handler import ChatHandler
 
5
  from transformers import pipeline
6
- import base64
7
- import time
8
-
9
- # Page configuration
10
- st.set_page_config(
11
- page_title="Darija Mental Health Assistant",
12
- page_icon="🤗",
13
- layout="centered"
14
- )
15
-
16
- # Custom CSS
17
- st.markdown("""
 
 
18
  <style>
19
- .stButton > button {
20
- background-color: #4CAF50;
21
- color: white;
22
- padding: 10px 20px;
23
- border-radius: 5px;
24
- }
25
- .main-header {
26
- text-align: center;
27
- color: #2e7d32;
28
- }
29
- .chat-container {
30
- padding: 20px;
31
- border-radius: 10px;
32
- background-color: #f5f5f5;
33
- margin: 10px 0;
34
- }
35
  </style>
36
- """, unsafe_allow_html=True)
37
-
38
- # Initialize session state
39
- if 'messages' not in st.session_state:
40
- st.session_state.messages = []
41
- if 'recording' not in st.session_state:
42
- st.session_state.recording = False
43
- if 'audio_handler' not in st.session_state:
44
- st.session_state.audio_handler = AudioHandler()
45
- if 'chat_handler' not in st.session_state:
46
- st.session_state.chat_handler = ChatHandler()
47
-
48
- def main():
49
- # Header
50
- st.markdown("<h1 class='main-header'>Darija Mental Health Assistant 🤗</h1>", unsafe_allow_html=True)
51
- st.markdown("<h3 style='text-align: center;'>تحدث معي بالدارجة عن مشاعرك</h3>", unsafe_allow_html=True)
52
-
53
- # Sidebar for settings
54
- with st.sidebar:
55
- st.header("Settings ⚙️")
56
- voice_enabled = st.toggle("Enable Voice Input 🎤", True)
57
- st.divider()
58
- st.markdown("### About")
59
- st.info("This assistant provides mental health support in Moroccan Arabic (Darija). Feel free to speak or type your thoughts.")
60
-
61
- # Main chat interface
62
- st.markdown("<div class='chat-container'>", unsafe_allow_html=True)
63
-
64
- # Voice input section
65
- if voice_enabled:
66
- cols = st.columns([1, 1])
67
- with cols[0]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  if st.button("🎤 Start Recording", disabled=st.session_state.recording):
69
  st.session_state.recording = True
70
- st.session_state.audio_handler.start_recording()
71
- st.experimental_rerun()
72
 
73
- with cols[1]:
74
  if st.button("⏹️ Stop Recording", disabled=not st.session_state.recording):
75
- if st.session_state.recording:
76
- audio_file = st.session_state.audio_handler.stop_recording()
77
- with st.spinner("Processing your message..."):
78
- transcription = st.session_state.audio_handler.transcribe_audio(audio_file)
79
- process_input(transcription)
80
- st.session_state.recording = False
81
- st.experimental_rerun()
82
-
83
- # Text input
84
- user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
85
- if user_input:
86
- process_input(user_input)
87
-
88
- # Display chat history
89
- display_chat_history()
90
-
91
- st.markdown("</div>", unsafe_allow_html=True)
92
-
93
- def process_input(user_input):
94
- if user_input:
95
  # Add user message
96
  st.session_state.messages.append({"role": "user", "content": user_input})
97
 
98
- # Get bot response
99
  with st.spinner("جاري التفكير..."):
100
- response = st.session_state.chat_handler.get_response(user_input)
101
- st.session_state.messages.append({"role": "assistant", "content": response})
102
-
103
- def display_chat_history():
104
- for message in st.session_state.messages:
105
- with st.chat_message(message["role"]):
106
- st.write(message["content"])
107
 
108
  if __name__ == "__main__":
109
- main()
 
 
1
  import streamlit as st
2
  import torch
3
+ import torchaudio
4
+ import soundfile as sf
5
+ from pathlib import Path
6
  from transformers import pipeline
7
+ from langchain.memory import ConversationBufferMemory
8
+ from langchain.chains import ConversationalRetrievalChain
9
+ from langchain.llms import HuggingFaceHub
10
+ from langchain.embeddings import HuggingFaceEmbeddings
11
+ from langchain.vectorstores import FAISS
12
+ from langchain import PromptTemplate
13
+ import os
14
+ from dotenv import load_dotenv
15
+
16
+ # Load environment variables
17
+ load_dotenv()
18
+
19
+ # CSS Styling
20
+ css = """
21
  <style>
22
+ .chat-message { padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex; }
23
+ .chat-message.user { background-color: #2b313e; }
24
+ .chat-message.bot { background-color: #475063; }
25
+ .avatar { margin-right: 1rem; }
26
+ .message { color: white; }
 
 
 
 
 
 
 
 
 
 
 
27
  </style>
28
+ """
29
+
30
+ # Prompt template
31
+ PROMPT_TEMPLATE = """
32
+ You are a professional therapist who speaks Moroccan Arabic (Darija).
33
+ Respond with empathy and use therapeutic techniques.
34
+ Always respond in Darija unless specifically asked to use another language.
35
+
36
+ Previous conversation context:
37
+ {chat_history}
38
+
39
+ Current message: {question}
40
+
41
+ Therapeutic response:
42
+ """
43
+
44
+ class DarijaTherapist:
45
+ def __init__(self):
46
+ self.setup_models()
47
+ self.initialize_session_state()
48
+ self.setup_memory()
49
+
50
+ def setup_models(self):
51
+ # Speech recognition setup
52
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ self.asr_pipe = pipeline(
54
+ "automatic-speech-recognition",
55
+ model="facebook/seamless-m4t-v2-large",
56
+ device=self.device
57
+ )
58
+
59
+ # LLM and conversation chain setup
60
+ self.llm = HuggingFaceHub(
61
+ repo_id="MBZUAI-Paris/Atlas-Chat-27B",
62
+ model_kwargs={"temperature": 0.7, "max_length": 512},
63
+ )
64
+
65
+ # Create embeddings and vector store
66
+ self.embeddings = HuggingFaceEmbeddings()
67
+ self.vectorstore = FAISS.from_texts(
68
+ ["Initial therapeutic context"],
69
+ self.embeddings
70
+ )
71
+
72
+ def setup_memory(self):
73
+ self.memory = ConversationBufferMemory(
74
+ memory_key="chat_history",
75
+ return_messages=True
76
+ )
77
+
78
+ self.conversation_chain = ConversationalRetrievalChain.from_llm(
79
+ llm=self.llm,
80
+ retriever=self.vectorstore.as_retriever(),
81
+ memory=self.memory,
82
+ combine_docs_chain_kwargs={"prompt": PromptTemplate.from_template(PROMPT_TEMPLATE)}
83
+ )
84
+
85
+ def initialize_session_state(self):
86
+ if "messages" not in st.session_state:
87
+ st.session_state.messages = []
88
+ if "recording" not in st.session_state:
89
+ st.session_state.recording = False
90
+ if "audio_buffer" not in st.session_state:
91
+ st.session_state.audio_buffer = []
92
+
93
+ def handle_audio_input(self):
94
+ if not st.session_state.recording:
95
+ return
96
+
97
+ try:
98
+ # Record audio using torchaudio
99
+ waveform, sample_rate = torchaudio.load("temp_audio.wav")
100
+ st.session_state.audio_buffer.append(waveform)
101
+ except Exception as e:
102
+ st.error(f"Error recording audio: {str(e)}")
103
+
104
+ def process_audio(self):
105
+ if not st.session_state.audio_buffer:
106
+ return None
107
+
108
+ try:
109
+ # Concatenate audio buffer
110
+ audio_data = torch.cat(st.session_state.audio_buffer, dim=1)
111
+ # Save temporary file
112
+ torchaudio.save("temp_audio.wav", audio_data, 16000)
113
+ # Transcribe
114
+ audio, rate = sf.read("temp_audio.wav", dtype='float32')
115
+ result = self.asr_pipe(
116
+ audio,
117
+ generate_kwargs={"task": "transcribe", "language": "ara"}
118
+ )
119
+ return result["text"]
120
+ except Exception as e:
121
+ st.error(f"Error processing audio: {str(e)}")
122
+ return None
123
+ finally:
124
+ # Clear buffer
125
+ st.session_state.audio_buffer = []
126
+
127
+ def get_ai_response(self, user_input):
128
+ try:
129
+ response = self.conversation_chain({"question": user_input})
130
+ return response['answer']
131
+ except Exception as e:
132
+ st.error(f"Error getting AI response: {str(e)}")
133
+ return "عذراً، كاين شي مشكل. حاول مرة أخرى."
134
+
135
+ def run(self):
136
+ st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
137
+ st.markdown(css, unsafe_allow_html=True)
138
+
139
+ st.title("Darija AI Therapist 🧠")
140
+ st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")
141
+
142
+ # Sidebar
143
+ with st.sidebar:
144
+ st.header("Settings ⚙️")
145
+ if st.button("Clear Chat History"):
146
+ st.session_state.messages = []
147
+ self.memory.clear()
148
+
149
+ st.markdown("### About")
150
+ st.info("This AI therapist speaks Darija and is here to help. "
151
+ "You can either type or speak your messages.")
152
+
153
+ # Audio input
154
+ col1, col2 = st.columns(2)
155
+ with col1:
156
  if st.button("🎤 Start Recording", disabled=st.session_state.recording):
157
  st.session_state.recording = True
158
+ st.session_state.audio_buffer = []
 
159
 
160
+ with col2:
161
  if st.button("⏹️ Stop Recording", disabled=not st.session_state.recording):
162
+ st.session_state.recording = False
163
+ transcription = self.process_audio()
164
+ if transcription:
165
+ self.process_message(transcription)
166
+
167
+ # Text input
168
+ user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
169
+ if user_input:
170
+ self.process_message(user_input)
171
+
172
+ # Display chat history
173
+ for message in st.session_state.messages:
174
+ with st.chat_message(message["role"]):
175
+ st.write(message["content"])
176
+
177
+ def process_message(self, user_input):
 
 
 
 
178
  # Add user message
179
  st.session_state.messages.append({"role": "user", "content": user_input})
180
 
181
+ # Get and add AI response
182
  with st.spinner("جاري التفكير..."):
183
+ ai_response = self.get_ai_response(user_input)
184
+ st.session_state.messages.append({"role": "assistant", "content": ai_response})
 
 
 
 
 
185
 
186
  if __name__ == "__main__":
187
+ app = DarijaTherapist()
188
+ app.run()