# DarijaTherapy / app.py (author: jaafarhh)
# Last change: commit 4f8302d "Update app.py" (verified), 6.89 kB
import os
import time
from operator import itemgetter

import numpy as np
import requests
import streamlit as st
import whisper
from dotenv import load_dotenv
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain_core.memory import BaseMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from streamlit_audiorecorder import audiorecorder
from streamlit_chat import message
# Load environment variables (e.g. HUGGINGFACE_API_TOKEN) from a local .env file.
load_dotenv()


def _init_session_state() -> None:
    """Seed ``st.session_state`` with a default for every key not yet present.

    Keys that already exist are left untouched, so values survive Streamlit's
    re-execution of this script on each interaction.
    """
    defaults = {
        "messages": [],       # chat log: list of {"role", "content"} dicts
        "audio_data": None,   # last value returned by the audio recorder
        "recording": False,   # whether the mic toggle is currently on
        "text_input": "",     # state of the keyed text-input widget
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
# Prompt template in the Mixtral-Instruct "[INST] ... [/INST]" format (the model
# served by the endpoint configured below). Placeholders filled by PromptTemplate:
#   {chat_history} - earlier conversation turns supplied by the caller
#   {question}     - the current user message
#   {context}      - text returned by the FAISS-backed retriever
# NOTE(review): the literal "<s>" may duplicate a BOS token if the inference
# server already prepends one — confirm.
PROMPT_TEMPLATE = """
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
Always respond in Darija unless specifically asked otherwise.
Previous conversation:
{chat_history}
User message: {question}
Context: {context}
[/INST]
"""
# Retry policy for the shared HTTP session: up to 3 retries with exponential
# backoff (backoff_factor=1) on rate-limit (429) and transient 5xx responses.
# urllib3's Retry only retries idempotent verbs by default (POST excluded);
# inference calls are presumably POSTs, so POST is listed explicitly or the
# status-based retries would never fire for them.
# NOTE(review): confirm this session is actually used by the LLM client.
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=frozenset(
        {"HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST"}
    ),
)
session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
# Initialize models.
# Streamlit re-executes this script on every interaction; without caching, the
# Whisper checkpoint would be reloaded from disk on each rerun.
# show_spinner=False keeps the cache from drawing a spinner element before
# main() has called st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_whisper_model(model_name: str = "base"):
    """Load the Whisper speech-to-text model once per server process and reuse it."""
    return whisper.load_model(model_name)


whisper_model = _load_whisper_model()

# Hosted Mixtral-8x7B-Instruct text-generation endpoint that produces replies.
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    temperature=0.7,
    do_sample=True,
    return_full_text=False,  # return only the generated continuation, not the prompt
    max_new_tokens=2048,
    top_p=0.9,
    repetition_penalty=1.2,
    # NOTE(review): HuggingFaceEndpoint has a dedicated ``stop_sequences`` field
    # and "return_text" is not a documented text-generation parameter — confirm
    # the endpoint accepts these model_kwargs.
    model_kwargs={
        "return_text": True,
        "stop": ["</s>"]
    },
    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
    # NOTE(review): some langchain versions replace ``client`` with their own
    # InferenceClient, in which case this retrying session is unused — confirm.
    client=session
)
# Conversation memory.
# ConversationChain is a chain (it requires an ``llm`` and rejects
# ``memory_key``/``return_messages``), so constructing it here failed at startup.
# ConversationBufferMemory is the memory object that exposes
# ``chat_memory.messages``. It is kept in session state so it survives
# Streamlit's re-execution of this script and stays private to each browser
# session.
if "memory" not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
memory = st.session_state.memory
# Retrieval setup. Loading the BGE embedding model and embedding the seed text
# is expensive, and Streamlit re-runs this script on every interaction, so both
# are built once per server process and shared.
# NOTE(review): the cached store is shared by all sessions — do not add
# per-user texts to it without revisiting this.
# NOTE(review): "bge-large-en" appears to be an English model while queries are
# in Darija — confirm it is the intended embedder.
@st.cache_resource(show_spinner=False)
def _build_retrieval_store():
    """Create the BGE embeddings and a FAISS index seeded with one placeholder text."""
    bge_embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en"
    )
    store = FAISS.from_texts(
        ["Initial therapeutic context"],
        bge_embeddings
    )
    return bge_embeddings, store


embeddings, vectorstore = _build_retrieval_store()
# Prompt object built from PROMPT_TEMPLATE with its three placeholders.
# NOTE(review): not referenced in the code visible here; create_chain builds
# its own equivalent PromptTemplate.
qa_prompt = PromptTemplate(
    input_variables=["context", "chat_history", "question"],
    template=PROMPT_TEMPLATE,
)
def _format_docs(docs) -> str:
    """Join the ``page_content`` of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


def _format_history(messages) -> str:
    """Render chat-history entries as ``role: content`` lines for the prompt.

    Accepts a string (returned unchanged), ``{"role", "content"}`` dicts as
    stored in ``st.session_state.messages``, or message objects exposing
    ``.content`` (and optionally ``.type``). ``None`` yields an empty string.
    """
    if isinstance(messages, str):
        return messages
    lines = []
    for entry in messages or []:
        if isinstance(entry, dict):
            role, content = entry.get("role", "user"), entry.get("content", "")
        elif hasattr(entry, "content"):
            role, content = getattr(entry, "type", "message"), entry.content
        else:
            role, content = "user", str(entry)
        lines.append(f"{role}: {content}")
    return "\n".join(lines)


def _normalize_chain_input(value) -> dict:
    """Coerce chain input into a ``{"question", "chat_history"}`` dict.

    A bare string is treated as the question with an empty history, so the
    chain still accepts the plain-string input the original design used.
    """
    if isinstance(value, str):
        return {"question": value, "chat_history": []}
    return {
        "question": value.get("question", ""),
        "chat_history": value.get("chat_history", []),
    }


def create_chain():
    """Build the retrieval-augmented LCEL chain that answers the user.

    Input: a question string, or a dict with ``question`` and optional
    ``chat_history`` keys. The chain fills every variable PROMPT_TEMPLATE
    needs (``context``, ``chat_history``, ``question``), calls ``llm`` and
    returns the reply as a plain string (via StrOutputParser).
    """
    prompt = PromptTemplate(
        template=PROMPT_TEMPLATE,
        input_variables=["context", "chat_history", "question"]
    )
    retriever = vectorstore.as_retriever()
    chain = (
        RunnableLambda(_normalize_chain_input)
        | {
            # Only the question text goes to the retriever, not the whole dict.
            "context": itemgetter("question") | retriever | _format_docs,
            "question": itemgetter("question"),
            "chat_history": lambda inputs: _format_history(inputs["chat_history"]),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return chain


conversation_chain = create_chain()


def get_ai_response(user_input: str, history_window: int = 5) -> str:
    """Return the assistant's Darija reply to *user_input*.

    Args:
        user_input: The user's message. Only the first 512 characters are sent.
        history_window: How many earlier entries of ``st.session_state.messages``
            are passed to the prompt as chat history (0 disables history).

    Returns:
        The model's reply with surrounding whitespace stripped. Otherwise a short
        Darija apology when the input is blank or all 3 attempts fail (with
        exponential backoff between attempts). Unexpected errors are shown with
        ``st.error`` on the final attempt instead of being raised.
    """
    if not user_input or not user_input.strip():
        return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."

    # process_message appends the current message to the log before calling
    # this function; drop that trailing entry so the question isn't duplicated
    # in the history section of the prompt.
    log = list(st.session_state.get("messages", []))
    if log and log[-1] == {"role": "user", "content": user_input}:
        log.pop()
    history = log[-history_window:] if history_window > 0 else []

    question = user_input[:512]
    max_retries = 3
    for attempt in range(max_retries):
        is_last = attempt == max_retries - 1
        try:
            response = conversation_chain.invoke(
                {"question": question, "chat_history": history}
            )
        except requests.exceptions.HTTPError:
            if is_last:
                return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as e:
            if is_last:
                st.error(f"Error: {str(e)}")
                return "عذراً، كاين شي مشكل. حاول مرة أخرى."
        else:
            # The chain ends in StrOutputParser, so the reply is a plain string.
            if response and response.strip():
                return response.strip()
            if is_last:
                return "عذراً، كاين مشكل. حاول مرة أخرى."
        time.sleep(2 ** attempt)
    # Not reached (the final attempt always returns); kept as a safety net.
    return "عذراً، كاين شي مشكل. حاول مرة أخرى."
def process_message(user_input: str) -> None:
    """Log the user's message, fetch the assistant's reply, and log that too.

    Both entries go to ``st.session_state.messages`` as ``{"role", "content"}``
    dicts; the reply is recorded only when it is truthy.
    """

    def _log(role: str, text: str) -> None:
        st.session_state.messages.append({"role": role, "content": text})

    _log("user", user_input)
    with st.spinner("جاري التفكير..."):
        reply = get_ai_response(user_input)
    if not reply:
        return
    _log("assistant", reply)
def _audio_to_waveform(audio) -> np.ndarray:
    """Convert a recorder result into the float32 waveform passed to Whisper.

    Whisper's ``transcribe`` takes float32 samples of 16 kHz mono audio. If
    *audio* behaves like a pydub ``AudioSegment`` (it has ``set_frame_rate``
    and ``get_array_of_samples``), it is first converted to 16 kHz, mono,
    16-bit. Otherwise the original conversion is kept: ``audio.tobytes()`` is
    reinterpreted as int16 PCM.
    NOTE(review): that fallback assumes headerless 16 kHz mono data — confirm
    against the installed streamlit-audiorecorder version.
    """
    if hasattr(audio, "set_frame_rate") and hasattr(audio, "get_array_of_samples"):
        pcm = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
        samples = np.array(pcm.get_array_of_samples(), dtype=np.float32)
    else:
        samples = np.frombuffer(audio.tobytes(), dtype=np.int16).astype(np.float32)
    # Scale int16 sample values into roughly [-1.0, 1.0].
    return samples / np.iinfo(np.int16).max


def _transcribe_audio(audio) -> str:
    """Transcribe recorded audio as Arabic with Whisper; return the stripped text."""
    result = whisper_model.transcribe(_audio_to_waveform(audio), language="ar")
    return (result.get("text") or "").strip()


def main():
    """Render the Darija AI Therapist page.

    Shows a text-input form with a mic toggle beside it, sends each submitted
    message once via ``process_message``, and renders the chat log.
    """
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    # Chat interface: wide column for the text input, narrow one for the mic.
    col1, col2 = st.columns([9, 1])

    # The mic column is filled first so a new transcript can be written to
    # st.session_state.text_input *before* the text widget exists in this run;
    # Streamlit raises if a widget's state is assigned after it is instantiated.
    with col2:
        mic_label = "🛑" if st.session_state.recording else "🎤"
        stop_clicked = False
        if st.button(mic_label):
            st.session_state.recording = not st.session_state.recording
            stop_clicked = not st.session_state.recording
        if st.session_state.recording:
            st.session_state.audio_data = audiorecorder("Click to stop recording")
        elif stop_clicked:
            # Transcribe only on the run where recording was just stopped, then
            # clear the buffer so later reruns neither re-transcribe the same
            # audio nor report a missing recording.
            recorded = st.session_state.audio_data
            st.session_state.audio_data = None
            if recorded is None or len(recorded) == 0:
                st.error("No audio data recorded.")
            else:
                transcript = _transcribe_audio(recorded)
                if transcript:
                    # Put transcribed text into the input field.
                    st.session_state.text_input = transcript

    with col1:
        # A form with clear_on_submit empties the box after sending, so each
        # message is submitted exactly once (a bare text_input keeps its value
        # and would be resent on every rerun).
        with st.form("chat_form", clear_on_submit=True):
            user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
            submitted = st.form_submit_button("➤")

    if submitted and user_input:
        process_message(user_input)

    # Display chat history. Each streamlit_chat component gets a unique key:
    # without one, two identical messages (e.g. the same apology twice) would
    # produce identical component IDs and a DuplicateWidgetID error.
    for index, message_data in enumerate(st.session_state.messages):
        message(
            message_data["content"],
            is_user=message_data["role"] == "user",
            key=f"chat_message_{index}",
        )


if __name__ == "__main__":
    main()