Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,109 +1,188 @@
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
-
|
4 |
-
|
|
|
5 |
from transformers import pipeline
|
6 |
-
import
|
7 |
-
import
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
<style>
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
}
|
25 |
-
.main-header {
|
26 |
-
text-align: center;
|
27 |
-
color: #2e7d32;
|
28 |
-
}
|
29 |
-
.chat-container {
|
30 |
-
padding: 20px;
|
31 |
-
border-radius: 10px;
|
32 |
-
background-color: #f5f5f5;
|
33 |
-
margin: 10px 0;
|
34 |
-
}
|
35 |
</style>
|
36 |
-
"""
|
37 |
-
|
38 |
-
#
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
if st.button("🎤 Start Recording", disabled=st.session_state.recording):
|
69 |
st.session_state.recording = True
|
70 |
-
st.session_state.
|
71 |
-
st.experimental_rerun()
|
72 |
|
73 |
-
with
|
74 |
if st.button("⏹️ Stop Recording", disabled=not st.session_state.recording):
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
st.markdown("</div>", unsafe_allow_html=True)
|
92 |
-
|
93 |
-
def process_input(user_input):
|
94 |
-
if user_input:
|
95 |
# Add user message
|
96 |
st.session_state.messages.append({"role": "user", "content": user_input})
|
97 |
|
98 |
-
# Get
|
99 |
with st.spinner("جاري التفكير..."):
|
100 |
-
|
101 |
-
st.session_state.messages.append({"role": "assistant", "content":
|
102 |
-
|
103 |
-
def display_chat_history():
|
104 |
-
for message in st.session_state.messages:
|
105 |
-
with st.chat_message(message["role"]):
|
106 |
-
st.write(message["content"])
|
107 |
|
108 |
if __name__ == "__main__":
|
109 |
-
|
|
|
|
1 |
import streamlit as st
|
2 |
import torch
|
3 |
+
import torchaudio
|
4 |
+
import soundfile as sf
|
5 |
+
from pathlib import Path
|
6 |
from transformers import pipeline
|
7 |
+
from langchain.memory import ConversationBufferMemory
|
8 |
+
from langchain.chains import ConversationalRetrievalChain
|
9 |
+
from langchain.llms import HuggingFaceHub
|
10 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
11 |
+
from langchain.vectorstores import FAISS
|
12 |
+
from langchain import PromptTemplate
|
13 |
+
import os
|
14 |
+
from dotenv import load_dotenv
|
15 |
+
|
16 |
+
# Load environment variables
|
17 |
+
load_dotenv()
|
18 |
+
|
19 |
+
# CSS Styling
|
20 |
+
css = """
|
21 |
<style>
|
22 |
+
.chat-message { padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex; }
|
23 |
+
.chat-message.user { background-color: #2b313e; }
|
24 |
+
.chat-message.bot { background-color: #475063; }
|
25 |
+
.avatar { margin-right: 1rem; }
|
26 |
+
.message { color: white; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
</style>
|
28 |
+
"""
|
29 |
+
|
30 |
+
# Prompt used for the answer-generation step of the retrieval chain.
# Placeholders:
#   {context}      - text of the documents returned by the retriever; the
#                    combine-documents ("stuff") chain injects them under
#                    this name, so the placeholder must be present.
#   {chat_history} - earlier turns of the conversation.
#   {question}     - the user's current message.
PROMPT_TEMPLATE = """
You are a professional therapist who speaks Moroccan Arabic (Darija).
Respond with empathy and use therapeutic techniques.
Always respond in Darija unless specifically asked to use another language.

Relevant background:
{context}

Previous conversation context:
{chat_history}

Current message: {question}

Therapeutic response:
"""
|
43 |
+
|
44 |
+
class DarijaTherapist:
    """Streamlit chat application that replies in Moroccan Arabic (Darija).

    The object is rebuilt on every Streamlit script run, so anything that
    has to outlive a single run is kept elsewhere:

    * model objects (ASR pipeline, LLM client, embeddings, vector store)
      are cached with ``st.cache_resource``;
    * per-user data (chat messages, recording flags, audio chunks and the
      LangChain conversation memory) lives in ``st.session_state``.
    """

    # Sample rate assumed for buffered audio when none has been recorded.
    DEFAULT_SAMPLE_RATE = 16000

    def __init__(self):
        self.setup_models()
        self.initialize_session_state()
        self.setup_memory()

    @staticmethod
    def _load_resources(device):
        """Build the model objects once; returns (asr, llm, embeddings, store)."""
        asr_pipe = pipeline(
            "automatic-speech-recognition",
            model="facebook/seamless-m4t-v2-large",
            device=device,
        )
        llm = HuggingFaceHub(
            repo_id="MBZUAI-Paris/Atlas-Chat-27B",
            model_kwargs={"temperature": 0.7, "max_length": 512},
        )
        embeddings = HuggingFaceEmbeddings()
        vectorstore = FAISS.from_texts(
            ["Initial therapeutic context"],
            embeddings,
        )
        return asr_pipe, llm, embeddings, vectorstore

    @staticmethod
    def _to_mono(waveforms):
        """Join ``(channels, frames)`` tensors in time and average the channels.

        Returns a 1-D NumPy array, which is the layout the speech pipeline
        is given in ``process_audio``.
        """
        joined = torch.cat(list(waveforms), dim=1)
        return joined.mean(dim=0).numpy()

    @staticmethod
    def _start_recording():
        """Button callback: mark recording as active and drop old chunks."""
        st.session_state.recording = True
        st.session_state.audio_buffer = []

    @staticmethod
    def _stop_recording():
        """Button callback: mark recording as stopped and request a transcription."""
        st.session_state.recording = False
        st.session_state.transcription_pending = True

    def setup_models(self):
        """Attach the (cached) speech, language and retrieval models to ``self``."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Cache the heavy objects so a rerun triggered by any widget does not
        # rebuild them. The spinner is disabled because run() calls
        # st.set_page_config() later, which must come before any rendered
        # element.
        cached_loader = st.cache_resource(show_spinner=False)(self._load_resources)
        (
            self.asr_pipe,
            self.llm,
            self.embeddings,
            self.vectorstore,
        ) = cached_loader(self.device)

    def setup_memory(self):
        """Create the conversation chain around a memory kept in session state."""
        # A memory created fresh on each script run would always be empty,
        # so it is stored per session and reused.
        if "memory" not in st.session_state:
            st.session_state.memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
            )
        self.memory = st.session_state.memory

        self.conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PromptTemplate.from_template(PROMPT_TEMPLATE)
            },
        )

    def initialize_session_state(self):
        """Give every session-state key used by the app its default value."""
        defaults = {
            "messages": list,
            "recording": bool,
            "audio_buffer": list,
            "transcription_pending": bool,
        }
        for key, factory in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = factory()
        if "audio_sample_rate" not in st.session_state:
            st.session_state.audio_sample_rate = self.DEFAULT_SAMPLE_RATE

    def handle_audio_input(self):
        """Append the contents of ``temp_audio.wav`` to the audio buffer.

        Does nothing unless a recording is marked as active.

        NOTE(review): this loads an existing file rather than capturing from
        a microphone, and nothing in this class calls it -- confirm how audio
        is expected to reach the buffer.
        """
        if not st.session_state.recording:
            return

        try:
            waveform, sample_rate = torchaudio.load("temp_audio.wav")
        except Exception as e:
            st.error(f"Error recording audio: {str(e)}")
            return

        st.session_state.audio_buffer.append(waveform)
        # Remember the real rate so transcription does not have to guess it.
        st.session_state.audio_sample_rate = sample_rate

    def process_audio(self):
        """Transcribe the buffered audio.

        Returns the recognised text, or ``None`` when the buffer is empty or
        transcription fails (the error is shown in the UI). The buffer is
        emptied in every case.
        """
        chunks = st.session_state.audio_buffer
        if not chunks:
            return None

        try:
            speech = {
                "raw": self._to_mono(chunks),
                "sampling_rate": st.session_state.audio_sample_rate,
            }
            # NOTE(review): "task"/"language" look like Whisper-style options;
            # SeamlessM4T checkpoints are documented to pick the output
            # language via `tgt_lang` -- confirm these are accepted.
            result = self.asr_pipe(
                speech,
                generate_kwargs={"task": "transcribe", "language": "ara"},
            )
            return result["text"]
        except Exception as e:
            st.error(f"Error processing audio: {str(e)}")
            return None
        finally:
            # Clear buffer
            st.session_state.audio_buffer = []

    def get_ai_response(self, user_input):
        """Ask the conversation chain for a reply to ``user_input``.

        On any failure the error is shown in the UI and a fixed apology
        message is returned instead of raising.
        """
        try:
            response = self.conversation_chain({"question": user_input})
            return response['answer']
        except Exception as e:
            st.error(f"Error getting AI response: {str(e)}")
            return "عذراً، كاين شي مشكل. حاول مرة أخرى."

    def run(self):
        """Render the page: header, sidebar, audio controls, input and history."""
        st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
        st.markdown(css, unsafe_allow_html=True)

        st.title("Darija AI Therapist 🧠")
        st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

        # Sidebar
        with st.sidebar:
            st.header("Settings ⚙️")
            if st.button("Clear Chat History"):
                st.session_state.messages = []
                self.memory.clear()

            st.markdown("### About")
            st.info("This AI therapist speaks Darija and is here to help. "
                    "You can either type or speak your messages.")

        # Audio input. State changes happen in on_click callbacks, which run
        # before the script body, so both buttons are drawn with an
        # up-to-date `disabled` flag.
        col1, col2 = st.columns(2)
        with col1:
            st.button(
                "🎤 Start Recording",
                disabled=st.session_state.recording,
                on_click=self._start_recording,
            )
        with col2:
            st.button(
                "⏹️ Stop Recording",
                disabled=not st.session_state.recording,
                on_click=self._stop_recording,
            )

        # The stop callback only raises a flag; the slow work is done here so
        # its spinner and messages appear in the normal page flow.
        if st.session_state.transcription_pending:
            st.session_state.transcription_pending = False
            if not st.session_state.audio_buffer:
                st.warning("No audio was captured.")
            else:
                transcription = self.process_audio()
                if transcription:
                    self.process_message(transcription)

        # Text input. chat_input yields the text only on the run in which it
        # was submitted, so unrelated reruns do not resend the same message.
        user_input = st.chat_input("اكتب رسالتك هنا:", key="text_input")
        if user_input:
            self.process_message(user_input)

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.write(message["content"])

    def process_message(self, user_input):
        """Record the user's message, obtain a reply and record that too."""
        # Add user message
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Get and add AI response
        with st.spinner("جاري التفكير..."):
            ai_response = self.get_ai_response(user_input)
            st.session_state.messages.append({"role": "assistant", "content": ai_response})
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
if __name__ == "__main__":
    # Build the application object and render the page in one step.
    DarijaTherapist().run()
|