jaafarhh committed on
Commit
09618ca
·
verified ·
1 Parent(s): 719f101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -22
app.py CHANGED
@@ -4,19 +4,19 @@ import torchaudio
4
  import soundfile as sf
5
  from pathlib import Path
6
  from transformers import pipeline, AutoTokenizer
 
 
7
  from langchain.memory import ConversationBufferMemory
8
  from langchain.chains import ConversationalRetrievalChain
9
- from langchain_community.llms import HuggingFaceHub
10
- from langchain_community.embeddings import HuggingFaceEmbeddings
11
  from langchain_community.vectorstores import FAISS
12
- from langchain_core.prompts import PromptTemplate
13
  import os
14
  from dotenv import load_dotenv
15
 
16
  # Load environment variables
17
  load_dotenv()
18
 
19
- # CSS remains the same
20
  css = """
21
  <style>
22
  .chat-message { padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex; }
@@ -27,17 +27,20 @@ css = """
27
  </style>
28
  """
29
 
30
- # Updated prompt template with correct variables
31
  PROMPT_TEMPLATE = """
32
  You are a professional therapist who speaks Moroccan Arabic (Darija).
33
  Respond with empathy and use therapeutic techniques.
34
  Always respond in Darija unless specifically asked to use another language.
35
 
36
- Context: {context}
37
- Chat History: {chat_history}
38
- Current Question: {question}
39
 
40
- Therapeutic response:
 
 
 
 
41
  """
42
 
43
  class DarijaTherapist:
@@ -48,8 +51,8 @@ class DarijaTherapist:
48
 
49
  def setup_models(self):
50
  try:
 
51
  tokenizer = AutoTokenizer.from_pretrained("facebook/seamless-m4t-v2-large")
52
-
53
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
54
  self.asr_pipe = pipeline(
55
  "automatic-speech-recognition",
@@ -58,17 +61,25 @@ class DarijaTherapist:
58
  device=self.device
59
  )
60
 
61
- self.llm = HuggingFaceHub(
62
- repo_id="MBZUAI-Paris/Atlas-Chat-9B",
63
- model_kwargs={"temperature": 0.7, "max_length": 512, "torch_dtype": torch.bfloat16, "do_sample": True},
64
- device="cuda",
 
 
 
 
 
 
65
  huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN")
66
  )
67
 
68
- self.embeddings = HuggingFaceEmbeddings(
69
- model_name="sentence-transformers/all-mpnet-base-v2"
 
70
  )
71
 
 
72
  self.vectorstore = FAISS.from_texts(
73
  ["Initial therapeutic context"],
74
  self.embeddings
@@ -76,14 +87,13 @@ class DarijaTherapist:
76
  except Exception as e:
77
  st.error(f"Error setting up models: {str(e)}")
78
  st.stop()
79
-
80
  def setup_memory(self):
81
  self.memory = ConversationBufferMemory(
82
  memory_key="chat_history",
83
  return_messages=True
84
  )
85
 
86
- # Updated chain creation with correct prompt
87
  qa_prompt = PromptTemplate(
88
  template=PROMPT_TEMPLATE,
89
  input_variables=["context", "chat_history", "question"]
@@ -94,11 +104,9 @@ class DarijaTherapist:
94
  retriever=self.vectorstore.as_retriever(),
95
  memory=self.memory,
96
  combine_docs_chain_kwargs={"prompt": qa_prompt},
97
- return_source_documents=True,
98
- chain_type="stuff"
99
  )
100
 
101
- # Rest of the methods remain the same
102
  def initialize_session_state(self):
103
  if "messages" not in st.session_state:
104
  st.session_state.messages = []
@@ -138,7 +146,10 @@ class DarijaTherapist:
138
 
139
  def get_ai_response(self, user_input):
140
  try:
141
- response = self.conversation_chain({"question": user_input})
 
 
 
142
  return response['answer']
143
  except Exception as e:
144
  st.error(f"Error getting AI response: {str(e)}")
 
4
  import soundfile as sf
5
  from pathlib import Path
6
  from transformers import pipeline, AutoTokenizer
7
+ from langchain_community.llms import HuggingFaceEndpoint
8
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
9
  from langchain.memory import ConversationBufferMemory
10
  from langchain.chains import ConversationalRetrievalChain
 
 
11
  from langchain_community.vectorstores import FAISS
12
+ from langchain.prompts import PromptTemplate
13
  import os
14
  from dotenv import load_dotenv
15
 
16
  # Load environment variables
17
  load_dotenv()
18
 
19
+ # CSS styling
20
  css = """
21
  <style>
22
  .chat-message { padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex; }
 
27
  </style>
28
  """
29
 
30
+ # Prompt template
31
  PROMPT_TEMPLATE = """
32
  You are a professional therapist who speaks Moroccan Arabic (Darija).
33
  Respond with empathy and use therapeutic techniques.
34
  Always respond in Darija unless specifically asked to use another language.
35
 
36
+ Previous conversation:
37
+ {chat_history}
 
38
 
39
+ User message: {question}
40
+
41
+ Additional context: {context}
42
+
43
+ Therapeutic response in Darija:
44
  """
45
 
46
  class DarijaTherapist:
 
51
 
52
  def setup_models(self):
53
  try:
54
+ # Speech recognition setup
55
  tokenizer = AutoTokenizer.from_pretrained("facebook/seamless-m4t-v2-large")
 
56
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
57
  self.asr_pipe = pipeline(
58
  "automatic-speech-recognition",
 
61
  device=self.device
62
  )
63
 
64
+ # LLM setup
65
+ self.llm = HuggingFaceEndpoint(
66
+ endpoint_url="https://api-inference.huggingface.co/models/MBZUAI-Paris/Atlas-Chat-9B",
67
+ task="text-generation",
68
+ model_kwargs={
69
+ "temperature": 0.7,
70
+ "max_length": 512,
71
+ "do_sample": True,
72
+ "return_full_text": False
73
+ },
74
  huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_TOKEN")
75
  )
76
 
77
+ # Embeddings setup
78
+ self.embeddings = HuggingFaceBgeEmbeddings(
79
+ model_name="BAAI/bge-large-en"
80
  )
81
 
82
+ # Vector store setup
83
  self.vectorstore = FAISS.from_texts(
84
  ["Initial therapeutic context"],
85
  self.embeddings
 
87
  except Exception as e:
88
  st.error(f"Error setting up models: {str(e)}")
89
  st.stop()
90
+
91
  def setup_memory(self):
92
  self.memory = ConversationBufferMemory(
93
  memory_key="chat_history",
94
  return_messages=True
95
  )
96
 
 
97
  qa_prompt = PromptTemplate(
98
  template=PROMPT_TEMPLATE,
99
  input_variables=["context", "chat_history", "question"]
 
104
  retriever=self.vectorstore.as_retriever(),
105
  memory=self.memory,
106
  combine_docs_chain_kwargs={"prompt": qa_prompt},
107
+ return_source_documents=True
 
108
  )
109
 
 
110
  def initialize_session_state(self):
111
  if "messages" not in st.session_state:
112
  st.session_state.messages = []
 
146
 
147
  def get_ai_response(self, user_input):
148
  try:
149
+ response = self.conversation_chain({
150
+ "question": user_input,
151
+ "chat_history": self.memory.chat_memory.messages
152
+ })
153
  return response['answer']
154
  except Exception as e:
155
  st.error(f"Error getting AI response: {str(e)}")