Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration | |
from typing import List | |
from langchain_community.llms import HuggingFaceEndpoint | |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain_community.vectorstores import FAISS | |
from langchain.prompts import PromptTemplate | |
import os | |
from dotenv import load_dotenv | |
import requests | |
from requests.adapters import HTTPAdapter | |
from requests.packages.urllib3.util.retry import Retry | |
import whisper | |
import numpy as np | |
import av | |
import time # Added import time | |
import queue | |
# Load environment variables (e.g. HUGGINGFACE_API_TOKEN) from a local .env file.
load_dotenv()

# Initialize per-session state.  Streamlit re-executes this script from top to
# bottom on every user interaction, so anything that has to survive a rerun
# must live in st.session_state rather than in a plain module-level variable.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "audio_buffer" not in st.session_state:
    st.session_state.audio_buffer = queue.Queue()
if "recording" not in st.session_state:
    st.session_state.recording = False
if "webrtc_ctx" not in st.session_state:
    st.session_state.webrtc_ctx = None
if "memory" not in st.session_state:
    # The conversation memory is kept per browser session; creating it at
    # module level would hand the chain an empty history on every rerun.
    st.session_state.memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )

# Prompt template (Mixtral instruction format) filled in by the QA chain with
# the running chat history, the user's question and the retrieved context.
PROMPT_TEMPLATE = """
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
Always respond in Darija unless specifically asked otherwise.
Previous conversation:
{chat_history}
User message: {question}
Context: {context}
[/INST]
"""

# Retry strategy: up to 3 retries with backoff on rate-limit / server errors.
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
)


@st.cache_resource(show_spinner=False)
def _load_whisper_model():
    """Load the Whisper "base" model once and reuse it across reruns."""
    return whisper.load_model("base")


@st.cache_resource(show_spinner=False)
def _build_llm():
    """Create the retrying HTTP session and the Mixtral endpoint wrapper once.

    Returns:
        A ``(session, llm)`` tuple.
    """
    http_session = requests.Session()
    http_session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
    endpoint = HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
        task="text-generation",
        temperature=0.7,
        do_sample=True,
        return_full_text=False,
        max_new_tokens=2048,
        top_p=0.9,
        repetition_penalty=1.2,
        model_kwargs={
            "return_text": True,
            "stop": ["</s>"],
        },
        huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
        # NOTE(review): confirm that HuggingFaceEndpoint actually uses a
        # requests.Session passed as ``client``; kept as in the original.
        client=http_session,
    )
    return http_session, endpoint


@st.cache_resource(show_spinner=False)
def _build_vectorstore():
    """Create the BGE embeddings and the seed FAISS index once.

    Returns:
        An ``(embeddings, vectorstore)`` tuple.
    """
    bge = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-large-en")
    store = FAISS.from_texts(["Initial therapeutic context"], bge)
    return bge, store


# Initialize models.  These are cached because rebuilding them on every rerun
# (i.e. on every keystroke-submit or button click) is slow and unnecessary.
whisper_model = _load_whisper_model()
session, llm = _build_llm()
embeddings, vectorstore = _build_vectorstore()

# Setup memory and conversation chain.  The memory object comes from the
# session state so the chain sees earlier turns of this user's conversation.
memory = st.session_state.memory
qa_prompt = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "chat_history", "question"],
)
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(),
    memory=memory,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    return_source_documents=False,
    chain_type="stuff",
)
def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
    """Queue the samples of an incoming audio frame and pass the frame through.

    The frame is converted to a NumPy array, flattened to one dimension and
    put on ``st.session_state.audio_buffer``.

    NOTE(review): this is registered as the ``webrtc_streamer`` callback;
    verify that ``st.session_state`` is reachable from wherever that callback
    is executed.

    Args:
        frame: The audio frame delivered by the WebRTC stream.

    Returns:
        The same frame, unmodified.
    """
    samples = frame.to_ndarray()
    st.session_state.audio_buffer.put(samples.flatten())
    return frame
def get_ai_response(user_input: str) -> str:
    """Ask the conversation chain for a reply, retrying on failure.

    Empty or whitespace-only input is rejected up front with a Darija
    "please repeat" message.  Input longer than 512 characters is truncated.
    The chain is then called up to three times, sleeping ``2 ** attempt``
    seconds between attempts.  An error is shown in the UI only when the
    final attempt fails, so a transient failure that a retry recovers from
    stays invisible to the user.

    Args:
        user_input: The user's message.

    Returns:
        The chain's answer, or a Darija fallback message when the input is
        empty or every attempt failed.
    """
    max_retries = 3
    max_chars = 512
    empty_reply_message = "عذراً، كاين مشكل. حاول مرة أخرى."

    # Validate once, before the retry loop: retrying cannot fix empty input.
    if not user_input or not user_input.strip():
        return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."

    question = user_input[:max_chars]

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = conversation_chain({"question": question})
        except requests.exceptions.HTTPError:
            if is_last_attempt:
                return "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as e:
            if is_last_attempt:
                st.error(f"Error: {str(e)}")
                return "عذراً، كاين شي مشكل. حاول مرة أخرى."
        else:
            # A missing or empty "answer" is handled like an empty response.
            answer = response.get("answer") if response else None
            if answer:
                return answer
            if is_last_attempt:
                return empty_reply_message
        # Exponential backoff before the next attempt.
        time.sleep(2 ** attempt)

    # Not reached (the last attempt always returns); kept as a safety net.
    return empty_reply_message
def process_message(user_input: str) -> None:
    """Record a user message and, if one is produced, the assistant's reply.

    The user message is appended to ``st.session_state.messages`` first; the
    reply is fetched from ``get_ai_response`` while a spinner is displayed and
    is appended only when it is non-empty.

    Args:
        user_input: The text to send to the assistant.
    """
    history = st.session_state.messages
    history.append({"role": "user", "content": user_input})

    with st.spinner("جاري التفكير..."):
        reply = get_ai_response(user_input)

    if not reply:
        return
    history.append({"role": "assistant", "content": reply})
def _queue_text_message() -> None:
    """Widget callback: stash the submitted text and clear the text box.

    Streamlit keeps a text input's value across reruns, so reading the widget
    value directly would re-send the same message on every rerun.  Moving it
    into ``pending_text`` and clearing the box makes each submission count
    exactly once.
    """
    text = st.session_state.get("text_input", "")
    if text:
        st.session_state["pending_text"] = text
    st.session_state["text_input"] = ""


def _toggle_recording() -> None:
    """Widget callback: flip the recording flag before the script reruns."""
    st.session_state.recording = not st.session_state.recording
    if st.session_state.recording:
        # Start every recording with an empty buffer.
        st.session_state.audio_buffer = queue.Queue()
    else:
        # Tell the next script run to transcribe what was captured.
        st.session_state["stop_requested"] = True


def _transcribe_recording() -> None:
    """Drain the audio buffer, transcribe it and send the text to the chat."""
    buffer = st.session_state.audio_buffer
    audio_frames = []
    while not buffer.empty():
        audio_frames.append(buffer.get())

    if not audio_frames:
        st.warning("ما تسجلش الصوت. حاول مرة أخرى.")
        return

    audio_data = np.concatenate(audio_frames, axis=0).flatten()
    # Whisper's own audio loader yields float32 samples scaled to [-1, 1], so
    # integer PCM is rescaled to that range instead of being cast to int16.
    if np.issubdtype(audio_data.dtype, np.integer):
        scale = float(np.iinfo(audio_data.dtype).max) + 1.0
        audio_data = audio_data.astype(np.float32) / scale
    else:
        audio_data = audio_data.astype(np.float32)
    # NOTE(review): Whisper expects 16 kHz mono audio; the buffered frames do
    # not carry their sample rate or channel layout, so no resampling or
    # down-mixing is done here -- confirm the incoming stream format.
    result = whisper_model.transcribe(audio_data, fp16=False)
    text = result.get("text", "")
    if text:
        process_message(text)
    else:
        st.warning("ما فهمتش الصوت. حاول مرة أخرى.")
    st.session_state.audio_buffer = queue.Queue()


def main():
    """Render the chat UI: text box, microphone toggle and chat history."""
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")
    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    col1, col2 = st.columns([9, 1])
    with col1:
        st.text_input(
            "اكتب رسالتك هنا:",
            key="text_input",
            on_change=_queue_text_message,
        )
    with col2:
        mic_icon = "🛑" if st.session_state.recording else "🎤"
        # The toggle happens in the callback, i.e. before this run started,
        # so the icon above already reflects the new state.
        st.button(mic_icon, on_click=_toggle_recording)

        if st.session_state.recording:
            # Render the streamer on every run while recording; an element
            # that is not re-created during a rerun is removed from the page.
            st.session_state.webrtc_ctx = webrtc_streamer(
                key="speech-to-text",
                mode=WebRtcMode.SENDONLY,
                audio_receiver_size=256,
                rtc_configuration=RTCConfiguration(
                    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
                ),
                media_stream_constraints={"video": False, "audio": True},
                async_processing=True,
                audio_frame_callback=audio_frame_callback,
            )
        elif st.session_state.get("stop_requested"):
            st.session_state["stop_requested"] = False
            st.info("🔄 Processing audio...")
            _transcribe_recording()
            if st.session_state.webrtc_ctx:
                # NOTE(review): kept from the original; confirm the streamer
                # context object really exposes stop().
                st.session_state.webrtc_ctx.stop()
                st.session_state.webrtc_ctx = None

    if st.session_state.recording:
        st.info("🎙️ Recording...")
    else:
        st.empty()

    # Handle a text message queued by the text box callback (at most once).
    pending_text = st.session_state.get("pending_text")
    if pending_text:
        st.session_state["pending_text"] = None
        process_message(pending_text)

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])


if __name__ == "__main__":
    main()