File size: 7,530 Bytes
77c006e
deb9302
 
09618ca
 
deb9302
 
93b5d0e
deb9302
cb6b5d0
 
e49ad3d
9525ef8
 
deb9302
 
 
 
 
cb6b5d0
 
 
 
e49ad3d
 
 
deb9302
 
 
 
 
699acb6
deb9302
 
 
cb6b5d0
e49ad3d
cb6b5d0
4dcf57d
 
 
cb6b5d0
09618ca
 
cb6b5d0
09618ca
 
4dcf57d
 
cb6b5d0
 
e49ad3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb9302
e49ad3d
 
 
 
 
 
 
 
 
699acb6
e49ad3d
 
 
 
 
 
 
 
deb9302
 
 
 
 
c3b2e36
 
deb9302
 
c3b2e36
deb9302
 
 
 
e49ad3d
 
 
 
cb6b5d0
e49ad3d
 
cb6b5d0
e49ad3d
 
8133539
c3b2e36
 
 
 
4dcf57d
 
 
e49ad3d
deb9302
e49ad3d
cb6b5d0
e49ad3d
 
 
 
 
deb9302
e49ad3d
 
 
 
 
 
 
 
 
deb9302
e49ad3d
 
 
 
 
 
 
deb9302
e49ad3d
 
deb9302
699acb6
 
 
 
 
deb9302
699acb6
deb9302
 
 
699acb6
 
deb9302
 
 
 
 
 
 
 
 
 
 
 
699acb6
deb9302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699acb6
deb9302
 
 
 
 
 
 
 
 
699acb6
e49ad3d
 
 
 
deb9302
 
 
77c006e
 
e49ad3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import streamlit as st
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
from typing import List
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
import os
from dotenv import load_dotenv
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import whisper
import numpy as np
import av
import time  # Added import time
import queue

# Read variables from a local .env file into the process environment
# (for example the HUGGINGFACE_API_TOKEN looked up further down in this file).
load_dotenv()

# Seed the per-session defaults. Keys that already exist are left alone, so
# values written during an earlier run of the script are preserved.
# Each entry maps a session key to a zero-argument factory so that mutable
# defaults (the list and the queue) are created fresh rather than shared.
_SESSION_DEFAULTS = {
    "messages": list,             # chat transcript: [{"role": ..., "content": ...}]
    "audio_buffer": queue.Queue,  # audio chunks pushed by the frame callback
    "recording": lambda: False,   # whether the microphone toggle is on
    "webrtc_ctx": lambda: None,   # handle returned by webrtc_streamer, if any
}
for _key, _make_default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()

# Prompt template used for every turn of the conversation.
# It contains three placeholders -- {chat_history}, {question} and {context} --
# which match the input variables declared on the PromptTemplate built from
# this string later in the file.
# NOTE(review): the <s>[INST] ... [/INST] wrapper is presumably the instruction
# format expected by the Mixtral instruct model configured below -- confirm
# before switching models.
PROMPT_TEMPLATE = """
<s>[INST] You are a professional therapist who speaks Moroccan Arabic (Darija).
Act as a compassionate therapist and provide empathetic responses using therapeutic techniques.
Always respond in Darija unless specifically asked otherwise.

Previous conversation:
{chat_history}

User message: {question}

Context: {context}
[/INST]
"""

# Shared HTTP session whose HTTPS requests are retried (at most 3 times, with
# back-off) when the server answers with rate limiting (429) or a 5xx status.
# NOTE(review): urllib3's Retry only retries idempotent methods by default, so
# POST requests may not be retried on these status codes -- confirm whether
# that matters for the calls made through this session.
retry_strategy = Retry(
    status_forcelist=[429, 500, 502, 503, 504],
    backoff_factor=1,
    total=3,
)

_https_adapter = HTTPAdapter(max_retries=retry_strategy)
session = requests.Session()
session.mount("https://", _https_adapter)

# Initialize models
#
# Streamlit re-executes this script from the top on every user interaction, so
# anything expensive created at module level must be cached explicitly.


@st.cache_resource(show_spinner=False)
def _load_whisper_model():
    """Load the Whisper "base" checkpoint once and reuse it across reruns.

    ``show_spinner=False`` keeps the cache from rendering an element here:
    this code runs before ``st.set_page_config`` is called in ``main``, and
    that call has to be the first Streamlit command that renders anything.
    """
    return whisper.load_model("base")


# Speech-to-text model used to transcribe recorded audio.
whisper_model = _load_whisper_model()

# Remote text-generation endpoint (Mixtral instruct on the Hugging Face
# Inference API). Constructing the client is cheap, so it is not cached.
# NOTE(review): `client=session` passes the retrying requests.Session defined
# above; whether HuggingFaceEndpoint actually routes its calls through it is
# not visible from this file -- confirm.
llm = HuggingFaceEndpoint(
    endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    temperature=0.7,
    do_sample=True,
    return_full_text=False,
    max_new_tokens=2048,
    top_p=0.9,
    repetition_penalty=1.2,
    model_kwargs={
        "return_text": True,
        "stop": ["</s>"]
    },
    huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN"),
    client=session
)

# Setup memory and conversation chain
#
# Streamlit re-executes this script on every interaction. Two consequences are
# handled here:
#   * the conversation memory must live in st.session_state, otherwise it is
#     re-created empty on each rerun and {chat_history} never has any content;
#   * the embedding model and the FAISS index are costly to build, so they are
#     created once and cached.


@st.cache_resource(show_spinner=False)
def _load_embeddings():
    """Create the BGE embedding model once and reuse it across reruns.

    ``show_spinner=False`` prevents the cache from rendering an element before
    ``st.set_page_config`` runs in ``main``.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en"
    )


@st.cache_resource(show_spinner=False)
def _build_vectorstore():
    """Build the FAISS index over the single seed text once and reuse it."""
    return FAISS.from_texts(
        ["Initial therapeutic context"],
        _load_embeddings()
    )


# Memory is kept per browser session (st.session_state), not in the shared
# resource cache, so one user's conversation is never visible to another.
if "chain_memory" not in st.session_state:
    st.session_state.chain_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True
    )
memory = st.session_state.chain_memory

embeddings = _load_embeddings()

vectorstore = _build_vectorstore()

qa_prompt = PromptTemplate(
    template=PROMPT_TEMPLATE,
    input_variables=["context", "chat_history", "question"]
)

# The chain object itself is cheap to assemble, so it is rebuilt on every run
# around the persisted memory and the cached retriever.
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=vectorstore.as_retriever(),
    memory=memory,
    combine_docs_chain_kwargs={"prompt": qa_prompt},
    return_source_documents=False,
    chain_type="stuff"
)


def audio_frame_callback(frame: av.AudioFrame) -> av.AudioFrame:
    """Buffer the samples of one incoming audio frame and pass the frame through.

    The frame is converted to a NumPy array, flattened to one dimension and
    put on the queue stored in ``st.session_state.audio_buffer``.

    NOTE(review): streamlit-webrtc is documented to invoke frame callbacks on a
    worker thread, where ``st.session_state`` may not be usable -- confirm that
    the buffer is actually reachable from here.

    Args:
        frame: The audio frame delivered by the WebRTC stream.

    Returns:
        The same frame object, unmodified.
    """
    samples = frame.to_ndarray().flatten()
    buffer = st.session_state.audio_buffer
    buffer.put(samples)
    return frame

def get_ai_response(user_input: str) -> str:
    """Ask the conversation chain for a reply, retrying on failure.

    Empty or whitespace-only input is rejected up front without calling the
    model. Longer input is truncated to 512 characters. The chain is then
    called up to three times, sleeping ``2 ** attempt`` seconds between
    attempts; if every attempt fails, a user-facing apology matching the last
    failure is returned instead of raising.

    Args:
        user_input: The user's message.

    Returns:
        The chain's ``answer`` on success, otherwise a Darija error message.
    """
    # Validation and truncation do not depend on the attempt number, so they
    # are done once, before the retry loop.
    if not user_input or not user_input.strip():
        return "عذراً، ما فهمتش السؤال ديالك. عاود من فضلك."
    user_input = user_input[:512]

    max_retries = 3
    for attempt in range(max_retries):
        try:
            response = conversation_chain({"question": user_input})
            if response:
                # Kept inside the try so a missing 'answer' key is handled by
                # the generic branch below rather than propagating.
                return response['answer']
            failure_message = "عذراً، كاين مشكل. حاول مرة أخرى."
        except requests.exceptions.HTTPError:
            failure_message = "عذراً، كاين مشكل مع النموذج. جرب سؤال أقصر."
        except Exception as exc:
            # Top-level boundary for the UI: surface the error, never crash.
            st.error(f"Error: {str(exc)}")
            failure_message = "عذراً، كاين شي مشكل. حاول مرة أخرى."

        if attempt == max_retries - 1:
            return failure_message
        # Exponential back-off before the next attempt: 1s, then 2s.
        time.sleep(2 ** attempt)

    # Only reachable if max_retries is ever set to 0; keeps the return type str.
    return "عذراً، كاين شي مشكل. حاول مرة أخرى."

def process_message(user_input: str) -> None:
    """Record a user message and, if the model replies, record the reply.

    The user message is always appended to ``st.session_state.messages``.
    While a spinner is shown, ``get_ai_response`` is called; a non-empty reply
    is appended as an assistant message, an empty one is dropped.

    Args:
        user_input: The text to add to the transcript and send to the model.
    """
    history = st.session_state.messages
    history.append({"role": "user", "content": user_input})

    with st.spinner("جاري التفكير..."):
        reply = get_ai_response(user_input)
        if not reply:
            return
        history.append({"role": "assistant", "content": reply})

def _stash_submitted_text() -> None:
    """``on_change`` handler for the text box: queue the text once, then clear it.

    Reading the widget's return value on every rerun would re-submit the same
    text each time anything else on the page triggers a rerun. Moving it to a
    one-shot slot and clearing the box ensures each submission is handled once.
    """
    st.session_state.pending_text = st.session_state.text_input
    st.session_state.text_input = ""


def _drain_audio_buffer(buffer: "queue.Queue") -> list:
    """Remove and return everything currently queued, without blocking."""
    chunks = []
    while True:
        try:
            # get_nowait avoids the empty()/get() race with the producer thread.
            chunks.append(buffer.get_nowait())
        except queue.Empty:
            return chunks


def _to_whisper_waveform(chunks: list) -> np.ndarray:
    """Join buffered chunks into a 1-D float32 waveform scaled to [-1, 1).

    Whisper's ``transcribe`` expects floating-point samples; integer PCM is
    therefore divided by its full-scale value (32768 for int16), which mirrors
    what Whisper's own audio loader does. Floating-point input is only cast.

    NOTE(review): Whisper's loader produces 16 kHz mono audio. The buffered
    chunks carry no sample-rate or channel information, so confirm the stream
    matches, or resample/downmix where the frames are captured -- TODO.
    """
    samples = np.concatenate(chunks, axis=0).flatten()
    if np.issubdtype(samples.dtype, np.signedinteger):
        full_scale = float(np.iinfo(samples.dtype).max) + 1.0
        return samples.astype(np.float32) / full_scale
    return samples.astype(np.float32)


def _transcribe_and_submit(buffer: "queue.Queue") -> None:
    """Transcribe the recorded audio in ``buffer`` and submit it as a message."""
    st.info("🔄 Processing audio...")
    chunks = _drain_audio_buffer(buffer)
    if not chunks:
        st.warning("ما تسجلش الصوت. حاول مرة أخرى.")
        return

    result = whisper_model.transcribe(_to_whisper_waveform(chunks), fp16=False)
    # Strip so that a whitespace-only transcription counts as "nothing heard".
    text = result.get("text", "").strip()
    if text:
        process_message(text)
    else:
        st.warning("ما فهمتش الصوت. حاول مرة أخرى.")


def main():
    """Render the chat page: text input, microphone toggle and transcript.

    Flow on each script run:
      1. Draw the text box and the microphone toggle button.
      2. If the toggle was clicked, either start a fresh recording buffer or
         transcribe what was recorded and submit it as a message.
      3. While recording is on, render the WebRTC component so the stream
         stays alive across reruns.
      4. Handle a text submission queued by the text box, if any.
      5. Draw the whole transcript.
    """
    st.set_page_config(page_title="Darija AI Therapist", page_icon="🧠")

    st.title("Darija AI Therapist 🧠")
    st.subheader("تكلم معايا بالدارجة على اللي كيجول فبالك")

    col1, col2 = st.columns([9, 1])
    with col1:
        st.text_input(
            "اكتب رسالتك هنا:",
            key="text_input",
            on_change=_stash_submitted_text,
        )
    with col2:
        mic_icon = "🛑" if st.session_state.recording else "🎤"

        if st.button(mic_icon):
            st.session_state.recording = not st.session_state.recording
            if st.session_state.recording:
                # Start every recording with an empty buffer.
                st.session_state.audio_buffer = queue.Queue()
            else:
                _transcribe_and_submit(st.session_state.audio_buffer)
                st.session_state.audio_buffer = queue.Queue()

                # Not rendering the component below ends the stream. stop() is
                # called only if the context object provides it.
                # NOTE(review): confirm whether the context returned by
                # webrtc_streamer exposes stop() at all.
                stop = getattr(st.session_state.webrtc_ctx, "stop", None)
                if callable(stop):
                    stop()
                st.session_state.webrtc_ctx = None

        if st.session_state.recording:
            # The component must be rendered on every run while recording;
            # rendering it only in the run where the button was clicked makes
            # it disappear on the very next rerun.
            #
            # The queue is looked up here, on the script thread, and captured
            # by the closure: streamlit-webrtc runs frame callbacks on its own
            # thread, where st.session_state should not be relied upon.
            audio_buffer = st.session_state.audio_buffer

            def _buffer_frame(frame):
                audio_buffer.put(frame.to_ndarray().flatten())
                return frame

            st.session_state.webrtc_ctx = webrtc_streamer(
                key="speech-to-text",
                mode=WebRtcMode.SENDONLY,
                audio_receiver_size=256,
                rtc_configuration=RTCConfiguration(
                    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
                ),
                media_stream_constraints={"video": False, "audio": True},
                async_processing=True,
                audio_frame_callback=_buffer_frame,
            )

    if st.session_state.recording:
        st.info("🎙️ Recording...")
    else:
        # Placeholder occupying the same slot as the banner above; presumably
        # kept so later elements stay at the same position in both states.
        st.empty()

    # Handle a text submission exactly once, then clear the one-shot slot.
    pending_text = st.session_state.get("pending_text", "")
    st.session_state.pending_text = ""
    if pending_text:
        process_message(pending_text)

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()