# DarijaTherapy / app.py
# jaafarhh's picture
# Update app.py
# c3b2e36 verified
import streamlit as st
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
from typing import List
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
import os
from dotenv import load_dotenv
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import whisper
import numpy as np
import av
import time # Added import time
import queue
# Load environment variables
load_dotenv()

# Initialize session state. Each entry pairs a session key with a factory for
# its initial value; a key is only created when the session lacks it, so
# values already stored in the session are left untouched.
_SESSION_DEFAULTS = (
    ("messages", list),            # chat transcript shown in the UI
    ("audio_buffer", queue.Queue),  # audio chunks waiting for transcription
    ("recording", bool),           # bool() -> False: not recording
    ("webrtc_ctx", lambda: None),  # no WebRTC streamer context yet
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Prompt template for the answer-generation step.
# Placeholders filled in at run time:
#   {chat_history} - the previous conversation turns
#   {question}     - the current user message
#   {context}      - the retrieved context text
# NOTE(review): the <s>[INST] ... [/INST] wrapper presumably follows the
# instruct format expected by the Mixtral endpoint configured in this file -
# confirm against the model card before changing it.
PROMPT_TEMPLATE = """
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
Always respond in Darija unless specifically asked otherwise.
Previous conversation:
{chat_history}
User message: {question}
Context: {context}
[/INST]
"""
# HTTP session whose HTTPS requests are retried automatically.
session = requests.Session()

# Up to 3 retries with a back-off factor of 1, triggered by rate-limit (429)
# and server-side (5xx) status codes.
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
)

# Only the "https://" prefix gets the retrying adapter.
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
# Initialize models


@st.cache_resource(show_spinner=False)
def _load_whisper_model():
    """Load the Whisper "base" speech-to-text model once and reuse it.

    Streamlit re-executes this script on every user interaction, so loading
    the model directly at module level would reload it on each rerun.
    ``st.cache_resource`` keeps a single loaded instance for later reruns.
    The spinner is disabled so that no UI element is emitted before
    ``st.set_page_config`` runs in ``main``.
    """
    return whisper.load_model("base")


whisper_model = _load_whisper_model()

# Text-generation endpoint (hosted Mixtral instruct model).
# NOTE(review): `client=session` hands the retrying requests session to the
# endpoint wrapper; the wrapper may replace `client` with its own inference
# client during validation, in which case the retry policy is not applied -
# confirm against the installed langchain_community version.
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    temperature=0.7,
    do_sample=True,
    return_full_text=False,
    max_new_tokens=2048,
    top_p=0.9,
    repetition_penalty=1.2,
    model_kwargs={
        "return_text": True,
        "stop": ["</s>"]
    },
    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
    client=session
)
# Setup memory and conversation chain

# Streamlit re-runs this script on every interaction. A memory object created
# directly at module level would therefore start empty on each rerun and the
# chain would never see earlier turns. Keeping it in the session state makes
# the conversation history persist for the lifetime of the user's session.
if "memory" not in st.session_state:
    st.session_state["memory"] = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
memory = st.session_state["memory"]


@st.cache_resource(show_spinner=False)
def _load_embeddings():
    """Load the BGE embedding model once instead of on every rerun."""
    return HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en")


@st.cache_resource(show_spinner=False)
def _build_vectorstore():
    """Build the FAISS index over the seed text once and reuse it.

    The index content is a fixed placeholder text and is only read by the
    retriever, so a single cached instance is shared across reruns.
    """
    return FAISS.from_texts(
        ["Initial therapeutic context"],
        _load_embeddings(),
    )


embeddings = _load_embeddings()
vectorstore = _build_vectorstore()

qa_prompt = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "chat_history", "question"],
)

# The chain itself is rebuilt on each rerun, bound to the session's
# persistent memory and the cached retriever.
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(),
    memory=memory,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    return_source_documents=False,  # only the answer text is returned
    chain_type="stuff",
)
def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
    """Queue the samples of one incoming audio frame and pass the frame on.

    The frame is converted to a NumPy array, flattened to one dimension and
    put on the queue stored under ``st.session_state.audio_buffer``. The
    frame itself is returned unchanged.

    NOTE(review): streamlit-webrtc is documented to invoke frame callbacks
    outside the Streamlit script thread, where ``st.session_state`` may not
    be reachable - confirm this lookup works in the deployed versions.
    """
    st.session_state.audio_buffer.put(frame.to_ndarray().flatten())
    return frame
def get_ai_response(user_input: str) -> str:
    """Return the assistant's reply to ``user_input``, retrying on failure.

    Empty or whitespace-only input is answered immediately with a Darija
    "please repeat" message. Otherwise the input is truncated to 512
    characters and sent to the conversation chain, with up to three attempts
    and exponential back-off (1 s, then 2 s) between them.

    Args:
        user_input: The user's message.

    Returns:
        The chain's ``answer`` text on success, or a Darija apology matching
        the kind of failure seen on the final attempt.
    """
    # Input validation does not depend on the attempt, so do it once.
    if not user_input or not user_input.strip():
        return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."

    # Cap the prompt size sent to the model.
    user_input = user_input[:512]

    max_retries = 3
    fallback = "عذراً، كاين مشكل. حاول مرة أخرى."
    for attempt in range(max_retries):
        try:
            response = conversation_chain({"question": user_input})
            if response:
                return response['answer']
            # Empty response: retry, generic apology if it keeps happening.
            fallback = "عذراً، كاين مشكل. حاول مرة أخرى."
        except requests.exceptions.HTTPError:
            fallback = "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as e:
            # UI boundary: surface the error to the user, then retry.
            st.error(f"Error: {str(e)}")
            fallback = "عذراً، كاين شي مشكل. حاول مرة أخرى."

        # Back off before the next attempt; no sleep after the last one.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)

    return fallback
def process_message(user_input: str) -> None:
    """Add a user message to the chat history and append the AI's reply.

    The user message is recorded first; a spinner is shown while the reply
    is generated, and the reply is recorded only when it is non-empty.
    """
    history = st.session_state.messages
    history.append({"role": "user", "content": user_input})

    with st.spinner("جاري التفكير..."):
        reply = get_ai_response(user_input)

    if not reply:
        return
    history.append({"role": "assistant", "content": reply})
def _queue_text_submission() -> None:
    """Hand the submitted text to the next script run and clear the box.

    Registered as the text input's ``on_change`` callback. Reading the
    widget's value directly in ``main`` would return the same text on every
    rerun (for example when the mic button is clicked), re-sending the
    message each time; stashing it once avoids those duplicates.
    """
    submitted = st.session_state.get("text_input", "")
    if submitted:
        st.session_state["pending_text"] = submitted
    st.session_state["text_input"] = ""


def _frames_to_whisper_audio(frames) -> np.ndarray:
    """Join buffered sample chunks into one float32 waveform for Whisper.

    Integer PCM samples are scaled into [-1.0, 1.0); floating-point samples
    are only cast to float32.

    NOTE(review): Whisper expects 16 kHz mono audio; nothing here resamples
    or down-mixes the captured frames - confirm the incoming frame format.
    """
    samples = np.concatenate([np.asarray(chunk).ravel() for chunk in frames])
    if np.issubdtype(samples.dtype, np.integer):
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        return samples.astype(np.float32) / full_scale
    return samples.astype(np.float32)


def _start_recording() -> None:
    """Create a fresh audio buffer and start the WebRTC audio streamer."""
    buffer = queue.Queue()
    st.session_state.audio_buffer = buffer

    def _enqueue_frame(frame):
        # Write to the captured queue object rather than going through
        # st.session_state: streamlit-webrtc runs frame callbacks outside
        # the script thread.
        buffer.put(frame.to_ndarray().flatten())
        return frame

    st.session_state.webrtc_ctx = webrtc_streamer(
        key="speech-to-text",
        mode=WebRtcMode.SENDONLY,
        audio_receiver_size=256,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        media_stream_constraints={"video": False, "audio": True},
        async_processing=True,
        audio_frame_callback=_enqueue_frame,
    )


def _finish_recording() -> None:
    """Transcribe the buffered audio, process the text, release the stream."""
    st.info("🔄 Processing audio...")

    # Drain everything that was captured so far.
    buffer = st.session_state.audio_buffer
    audio_frames = []
    while True:
        try:
            audio_frames.append(buffer.get_nowait())
        except queue.Empty:
            break

    if audio_frames:
        audio = _frames_to_whisper_audio(audio_frames)
        result = whisper_model.transcribe(audio, fp16=False)
        text = result.get("text", "").strip()
        if text:
            process_message(text)
        else:
            st.warning("ما فهمتش الصوت. حاول مرة أخرى.")
        st.session_state.audio_buffer = queue.Queue()
    else:
        st.warning("ما تسجلش الصوت. حاول مرة أخرى.")

    ctx = st.session_state.webrtc_ctx
    if ctx:
        # NOTE(review): not confirmed that the streamer context exposes a
        # public stop(); call it only when it is present.
        stop = getattr(ctx, "stop", None)
        if callable(stop):
            stop()
    st.session_state.webrtc_ctx = None


def main():
    """Render the chat UI: text input, mic toggle, and message history."""
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    col1, col2 = st.columns([9, 1])
    with col1:
        st.text_input(
            "اكتب رسالتك هنا:",
            key="text_input",
            on_change=_queue_text_submission,
        )
    with col2:
        mic_icon = "🛑" if st.session_state.recording else "🎤"
        if st.button(mic_icon):
            # Each click toggles between recording and processing.
            st.session_state.recording = not st.session_state.recording
            if st.session_state.recording:
                _start_recording()
            else:
                _finish_recording()

    if st.session_state.recording:
        st.info("🎙️ Recording...")
    else:
        # Placeholder so this slot holds one element in both states.
        st.empty()

    # Process a text submission exactly once, then forget it.
    user_input = st.session_state.get("pending_text", "")
    if user_input:
        st.session_state["pending_text"] = ""
        process_message(user_input)

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])


if __name__ == "__main__":
    main()