Spaces:
Sleeping
Sleeping
import streamlit as st | |
import whisper | |
import numpy as np | |
from langchain_community.llms import HuggingFaceEndpoint | |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.memory import BaseMemory | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain.chains import ConversationChain | |
import os | |
from dotenv import load_dotenv | |
import requests | |
from requests.adapters import HTTPAdapter | |
from requests.packages.urllib3.util.retry import Retry | |
import time | |
from streamlit_chat import message | |
from streamlit_audiorecorder import audiorecorder | |
# Read variables from a local .env file (if one exists) into os.environ,
# e.g. the Hugging Face API token.
load_dotenv()

# Seed per-session state. st.session_state survives Streamlit's reruns, so a
# key is only initialised the first time it is missing; existing values stay.
for _state_key, _initial_value in (
    ("messages", []),
    ("audio_data", None),
    ("recording", False),
    ("text_input", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
del _state_key, _initial_value
# Instruction prompt in Mixtral's [INST] ... [/INST] format. PromptTemplate
# fills the {chat_history}, {question} and {context} placeholders.
PROMPT_TEMPLATE = (
    "\n"
    "<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).\n"
    "Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.\n"
    "Always respond in Darija unless specifically asked otherwise.\n"
    "Previous conversation:\n"
    "{chat_history}\n"
    "User message: {question}\n"
    "Context: {context}\n"
    "[/INST]\n"
)
# Shared requests session. Its HTTPS adapter retries up to 3 times, with
# exponential backoff (backoff_factor=1), on 429 and common 5xx statuses.
# NOTE(review): urllib3's Retry only retries idempotent methods by default
# (POST is not among them), and HuggingFaceEndpoint builds its own HTTP
# client, so this session may not be used by any request — confirm.
session = requests.Session()
retry_strategy = Retry(
    status_forcelist=[429, 500, 502, 503, 504],
    backoff_factor=1,
    total=3,
)
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
# Initialize models.
# Streamlit re-executes this whole module on every widget interaction, so the
# expensive objects are built once per server process via st.cache_resource.
# show_spinner=False because these calls happen before st.set_page_config()
# in main(), and a cache spinner would be a Streamlit element emitted first.
@st.cache_resource(show_spinner=False)
def _load_whisper_model(name: str = "base"):
    """Load the Whisper speech-to-text checkpoint *name* (once per process)."""
    return whisper.load_model(name)


@st.cache_resource(show_spinner=False)
def _load_llm() -> HuggingFaceEndpoint:
    """Build the Mixtral-8x7B-Instruct text-generation client (once per process)."""
    return HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
        task="text-generation",
        temperature=0.7,
        do_sample=True,
        return_full_text=False,
        max_new_tokens=2048,
        top_p=0.9,
        repetition_penalty=1.2,
        # stop_sequences is the field HuggingFaceEndpoint forwards as the
        # request's stop words; a "stop" entry in model_kwargs can be
        # overwritten by it. The old "return_text" kwarg is not a
        # text-generation parameter (return_full_text covers it) and is dropped.
        stop_sequences=["</s>"],
        huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
        # No client=...: HuggingFaceEndpoint creates its own InferenceClient
        # during validation, and a requests.Session does not provide that API.
    )


whisper_model = _load_whisper_model()
llm = _load_llm()
# Conversation memory for the LLM prompt.
# ConversationChain is a chain, not a memory: it requires an llm and has no
# .chat_memory, which get_ai_response reads. ConversationBufferMemory is the
# buffer type with that attribute. Streamlit re-runs this script on every
# interaction, so the instance is kept in st.session_state; a plain
# module-level object would start empty on each rerun.
from langchain.memory import ConversationBufferMemory

if "memory" not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
memory = st.session_state.memory
# Embedding model and FAISS index, built once per server process. Without
# caching, Streamlit's per-interaction rerun would reload the bge-large model
# and rebuild the index every time. show_spinner=False keeps the cache from
# emitting a spinner element before st.set_page_config() runs in main().
# NOTE(review): "bge-large-en" looks like an English-only model while user
# queries are in Darija, and the index holds a single placeholder text, so
# retrieval can only ever return that placeholder — confirm intent.
@st.cache_resource(show_spinner=False)
def _load_embeddings() -> HuggingFaceBgeEmbeddings:
    """Load the BGE sentence-embedding model (once per process)."""
    return HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en")


@st.cache_resource(show_spinner=False)
def _build_vectorstore(_embeddings) -> FAISS:
    """Index the seed therapeutic context with FAISS (once per process).

    The leading underscore tells st.cache_resource not to hash the argument.
    """
    return FAISS.from_texts(["Initial therapeutic context"], _embeddings)


embeddings = _load_embeddings()
vectorstore = _build_vectorstore(embeddings)
# Prompt object over PROMPT_TEMPLATE, declaring its three placeholders.
qa_prompt = PromptTemplate(
    input_variables=["context", "chat_history", "question"],
    template=PROMPT_TEMPLATE,
)
def _format_docs(docs) -> str:
    """Join retrieved documents' page_content into one string for {context}."""
    return "\n\n".join(doc.page_content for doc in docs)


def create_chain(prompt=None, retriever=None):
    """Build the LCEL pipeline: retrieve context -> fill prompt -> LLM -> text.

    The returned runnable must be called with ``.invoke`` on a dict
    ``{"question": str, "chat_history": str}``. It returns the model's reply
    as a plain string, because the pipeline ends in ``StrOutputParser``.

    Args:
        prompt: PromptTemplate using the ``context``, ``chat_history`` and
            ``question`` variables. Defaults to one built from PROMPT_TEMPLATE.
        retriever: Retriever queried with the question text. Defaults to the
            module-level FAISS ``vectorstore``.
    """
    from operator import itemgetter

    if prompt is None:
        prompt = PromptTemplate(
            template=PROMPT_TEMPLATE,
            input_variables=["context", "chat_history", "question"],
        )
    if retriever is None:
        retriever = vectorstore.as_retriever()

    # Every prompt variable needs an entry here; chat_history was missing
    # before, so prompt formatting failed. The retriever must receive only
    # the question string, not the whole input dict. Its Document list is
    # flattened to text before it reaches the prompt.
    return (
        {
            "context": itemgetter("question") | retriever | _format_docs,
            "question": itemgetter("question"),
            "chat_history": itemgetter("chat_history"),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


conversation_chain = create_chain()
def _format_chat_history(messages) -> str:
    """Render LangChain chat messages as 'User: ...' / 'Therapist: ...' lines.

    A message's ``type`` is "human" for user turns and "ai" for model turns;
    any other type is printed unchanged.
    """
    speakers = {"human": "User", "ai": "Therapist"}
    return "\n".join(
        f"{speakers.get(msg.type, msg.type)}: {msg.content}" for msg in messages
    )


def get_ai_response(
    user_input: str,
    *,
    max_retries: int = 3,
    history_window: int = 5,
    max_input_chars: int = 512,
) -> str:
    """Return the therapist model's reply to *user_input*.

    The question (truncated to *max_input_chars*) and the last
    *history_window* messages held in ``memory`` are sent through
    ``conversation_chain``. On success, the exchange is appended to ``memory``
    so later turns see it.

    The call is attempted up to *max_retries* times, sleeping 1 s, 2 s, ...
    between attempts. If every attempt fails, a Darija apology string is
    returned instead of raising. Blank input is rejected up front without
    calling the model.
    """
    if not user_input or not user_input.strip():
        return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
    user_input = user_input[:max_input_chars]

    attempts = max(1, max_retries)
    fallback = "عذراً، كاين شي مشكل. حاول مرة أخرى."
    for attempt in range(attempts):
        is_last = attempt == attempts - 1
        try:
            # messages[-0:] would be the whole list, so a non-positive window
            # means "no history" rather than "all history".
            recent = (
                memory.chat_memory.messages[-history_window:]
                if history_window > 0
                else []
            )
            # The chain is a Runnable (call .invoke, not ()), and it ends in
            # StrOutputParser, so the result is the reply string itself.
            answer = conversation_chain.invoke(
                {
                    "question": user_input,
                    "chat_history": _format_chat_history(recent),
                }
            )
        except requests.exceptions.HTTPError:
            fallback = "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as e:
            # Surface the error once, not on every retry.
            if is_last:
                st.error(f"Error: {str(e)}")
            fallback = "عذراً، كاين شي مشكل. حاول مرة أخرى."
        else:
            if answer and answer.strip():
                memory.chat_memory.add_user_message(user_input)
                memory.chat_memory.add_ai_message(answer)
                return answer
            fallback = "عذراً، كاين مشكل. حاول مرة أخرى."
        if not is_last:
            time.sleep(2 ** attempt)
    return fallback
def process_message(user_input: str) -> None:
    """Record the user's turn, fetch the assistant reply, and record it too.

    Both turns are appended to ``st.session_state.messages`` as
    ``{"role", "content"}`` dicts. A falsy reply is not stored.
    """
    history = st.session_state.messages
    history.append({"role": "user", "content": user_input})

    with st.spinner("جاري التفكير..."):
        reply = get_ai_response(user_input)

    if not reply:
        return
    history.append({"role": "assistant", "content": reply})
def _audio_to_bytes(audio) -> bytes:
    """Return the recorder's clip as encoded audio-file bytes.

    Recent streamlit-audiorecorder releases return a pydub AudioSegment
    (which has ``export``). Older ones returned a numpy array holding the
    encoded file, which is what the previous ``tobytes()`` call assumed.
    """
    if hasattr(audio, "export"):
        return audio.export(format="wav").read()
    return audio.tobytes()


def _transcribe_clip(audio_bytes: bytes) -> str:
    """Transcribe an encoded audio clip with Whisper (language forced to Arabic).

    The clip goes through a temporary file so Whisper's own loader decodes it
    and resamples it to the 16 kHz mono input the model expects. Reading the
    encoded bytes as raw int16 samples misreads the header and sample rate.
    """
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(audio_bytes)
        result = whisper_model.transcribe(path, language="ar")
    finally:
        os.remove(path)
    return (result.get("text") or "").strip()


def main():
    """Render one run of the Darija therapist chat page.

    Streamlit re-executes this on every interaction. The recorder is handled
    before the text box is created because Streamlit raises if
    ``st.session_state.text_input`` is assigned after the widget with that key
    has been instantiated in the same run.
    """
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    # Text input (wide) and microphone recorder (narrow), side by side.
    col1, col2 = st.columns([9, 1])

    with col2:
        # The component draws its own start/stop control and must stay
        # rendered, or the finished clip is never returned. It returns an
        # empty clip until something has been recorded.
        audio = audiorecorder("🎤", "🛑")
    if audio is not None and len(audio) > 0:
        clip = _audio_to_bytes(audio)
        clip_id = hash(clip)
        # The recorder keeps returning its last clip on every rerun. Transcribe
        # each clip once so the user's edits to the text are not overwritten.
        if st.session_state.get("last_clip_id") != clip_id:
            st.session_state.last_clip_id = clip_id
            st.session_state.audio_data = clip
            transcript = _transcribe_clip(clip)
            if transcript:
                # Prefill the text box (legal: the widget is not created yet).
                st.session_state.text_input = transcript
            else:
                st.error("No speech was recognised in the recording.")

    with col1:
        # clear_on_submit empties the box after sending. It also means a
        # message is only sent on an explicit submit, not re-sent on every
        # unrelated rerun.
        with st.form("chat_form", clear_on_submit=True):
            user_input = st.text_input("اكتب رسالتك هنا:", key="text_input")
            submitted = st.form_submit_button("إرسال")

    if submitted and user_input and user_input.strip():
        process_message(user_input)

    # Display chat history. Each streamlit_chat message gets a unique key;
    # identical messages rendered without one collide as duplicate elements.
    for index, message_data in enumerate(st.session_state.messages):
        message(
            message_data["content"],
            is_user=message_data["role"] == "user",
            key=f"chat_message_{index}",
        )


if __name__ == "__main__":
    main()