Spaces:
Sleeping
Sleeping
Delete app.py
Browse files
app.py
DELETED
@@ -1,272 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import streamlit as st
|
3 |
-
from pydub import AudioSegment
|
4 |
-
import numpy as np
|
5 |
-
from transformers import pipeline
|
6 |
-
from langchain_huggingface import HuggingFaceEndpoint
|
7 |
-
from langchain_core.prompts import PromptTemplate
|
8 |
-
from langchain_core.output_parsers import StrOutputParser
|
9 |
-
|
10 |
-
# Model IDs
|
11 |
-
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
12 |
-
model2_id = "meta-llama/Llama-3.2-1B-Instruct"
|
13 |
-
whisper_model = "openai/whisper-small" # Using Whisper model for audio transcription
|
14 |
-
|
15 |
-
def get_llm_hf_inference(model_id, max_new_tokens=128, temperature=0.1):
|
16 |
-
"""Returns a language model for HuggingFace inference."""
|
17 |
-
try:
|
18 |
-
llm = HuggingFaceEndpoint(
|
19 |
-
repo_id=model_id,
|
20 |
-
max_new_tokens=max_new_tokens,
|
21 |
-
temperature=temperature,
|
22 |
-
token=os.getenv("HF_TOKEN")
|
23 |
-
)
|
24 |
-
return llm
|
25 |
-
except Exception as e:
|
26 |
-
st.error(f"Error initializing model: {e}")
|
27 |
-
return None
|
28 |
-
|
29 |
-
# Initialize Whisper transcription model
|
30 |
-
def load_transcription_model():
|
31 |
-
try:
|
32 |
-
transcriber = pipeline("automatic-speech-recognition", model=whisper_model)
|
33 |
-
return transcriber
|
34 |
-
except Exception as e:
|
35 |
-
st.error(f"Error loading Whisper model: {e}")
|
36 |
-
return None
|
37 |
-
|
38 |
-
# Preprocess audio to 16kHz mono
|
39 |
-
def preprocess_audio(file):
|
40 |
-
audio = AudioSegment.from_file(file).set_frame_rate(16000).set_channels(1)
|
41 |
-
audio_samples = np.array(audio.get_array_of_samples()).astype(np.float32) / (2**15)
|
42 |
-
return audio_samples
|
43 |
-
|
44 |
-
# Function to transcribe audio with preprocessing
|
45 |
-
def transcribe_audio(file, transcriber):
|
46 |
-
audio = preprocess_audio(file)
|
47 |
-
transcription = transcriber(audio)["text"]
|
48 |
-
return transcription
|
49 |
-
|
50 |
-
# Configure the Streamlit app
|
51 |
-
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
|
52 |
-
|
53 |
-
# Initialize session state
|
54 |
-
if "avatars" not in st.session_state:
|
55 |
-
st.session_state.avatars = {'user': None, 'assistant': None}
|
56 |
-
|
57 |
-
if 'user_text' not in st.session_state:
|
58 |
-
st.session_state.user_text = None
|
59 |
-
|
60 |
-
if "max_response_length" not in st.session_state:
|
61 |
-
st.session_state.max_response_length = 1000
|
62 |
-
|
63 |
-
if "system_message" not in st.session_state:
|
64 |
-
st.session_state.system_message = "friendly AI conversing with a human user"
|
65 |
-
|
66 |
-
if "starter_message" not in st.session_state:
|
67 |
-
st.session_state.starter_message = "Hello, there! How can I help you today?"
|
68 |
-
|
69 |
-
if "chat_history" not in st.session_state:
|
70 |
-
st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
|
71 |
-
|
72 |
-
# Homepage content
|
73 |
-
def display_homepage():
|
74 |
-
st.markdown(
|
75 |
-
"""
|
76 |
-
<div style="text-align: center; margin-top: 50px; font-size: 60px; font-weight: bold; color: #2c3e50; margin-left: auto; margin-right: auto;">
|
77 |
-
Your personal AI Therapist
|
78 |
-
</div>
|
79 |
-
""", unsafe_allow_html=True
|
80 |
-
)
|
81 |
-
|
82 |
-
st.markdown(
|
83 |
-
"""
|
84 |
-
<div style="text-align: center; font-size: 20px; color: #000000; margin-top: 20px; max-width: 700px; margin-left: auto; margin-right: auto;">
|
85 |
-
This is your healthcare chatbot that streamlines outpatient care, solves routine queries 24/7, and effortlessly automates appointment bookings, prescriptions, and reports. Let AI help you with your mental health journey.
|
86 |
-
</div>
|
87 |
-
""", unsafe_allow_html=True
|
88 |
-
)
|
89 |
-
|
90 |
-
st.markdown(
|
91 |
-
"""
|
92 |
-
<style>
|
93 |
-
.stApp {
|
94 |
-
background-image: url('https://images.pexels.com/photos/2680270/pexels-photo-2680270.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1');
|
95 |
-
background-size: cover;
|
96 |
-
background-position: center;
|
97 |
-
background-repeat: no-repeat;
|
98 |
-
height: 100vh;
|
99 |
-
color: white;
|
100 |
-
}
|
101 |
-
|
102 |
-
.stButton>button {
|
103 |
-
background-color: #2980b9;
|
104 |
-
color: white;
|
105 |
-
font-size: 18px;
|
106 |
-
font-weight: bold;
|
107 |
-
padding: 15px 30px;
|
108 |
-
border-radius: 8px;
|
109 |
-
border: none;
|
110 |
-
box-shadow: 0px 5px 15px rgba(0, 0, 0, 0.1);
|
111 |
-
cursor: pointer;
|
112 |
-
transition: background-color 0.3s ease;
|
113 |
-
}
|
114 |
-
|
115 |
-
.stButton>button:hover {
|
116 |
-
background-color: #3498db;
|
117 |
-
}
|
118 |
-
</style>
|
119 |
-
""", unsafe_allow_html=True
|
120 |
-
)
|
121 |
-
|
122 |
-
if st.button("Start Chat", key="start_chat_button"):
|
123 |
-
st.session_state.chat_started = True
|
124 |
-
# No need for rerun; session state change will trigger re-rendering
|
125 |
-
|
126 |
-
# Chatbot page content
|
127 |
-
def display_chatbot():
|
128 |
-
st.title("Personal Psychologist Chatbot")
|
129 |
-
st.markdown(f"*This is a simple chatbot that acts as a psychologist and gives solutions to your psychological problems. It uses the {model_id}.*")
|
130 |
-
|
131 |
-
# Sidebar for settings
|
132 |
-
with st.sidebar:
|
133 |
-
# Reset Chat History
|
134 |
-
reset_history = st.button("Reset Chat History")
|
135 |
-
go_home = st.button("Back to Home")
|
136 |
-
if go_home:
|
137 |
-
st.session_state.chat_started = False
|
138 |
-
st.experimental_rerun() # This will reload the app to show the homepage
|
139 |
-
|
140 |
-
# Initialize or reset chat history
|
141 |
-
if reset_history:
|
142 |
-
st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]
|
143 |
-
|
144 |
-
def get_response(system_message, chat_history, user_text, model_id, max_new_tokens=256):
|
145 |
-
"""Generates a response from the chatbot model."""
|
146 |
-
hf = get_llm_hf_inference(model_id=model_id, max_new_tokens=max_new_tokens)
|
147 |
-
if hf is None:
|
148 |
-
return "Error: Model not initialized.", chat_history
|
149 |
-
|
150 |
-
# Create the prompt template
|
151 |
-
prompt = PromptTemplate.from_template(
|
152 |
-
(
|
153 |
-
"[INST] {system_message}"
|
154 |
-
"\nCurrent Conversation:\n{chat_history}\n\n"
|
155 |
-
"\nUser: {user_text}.\n [/INST]"
|
156 |
-
"\nAI:"
|
157 |
-
)
|
158 |
-
)
|
159 |
-
chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
|
160 |
-
|
161 |
-
# Generate the response
|
162 |
-
response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
|
163 |
-
response = response.split("AI:")[-1].strip()
|
164 |
-
|
165 |
-
# Enhanced end-of-conversation detection:
|
166 |
-
# Check if response contains low engagement or end-of-conversation patterns
|
167 |
-
low_engagement_threshold = 3 # Threshold for short responses
|
168 |
-
end_keywords = ["thank you", "thanks", "goodbye", "bye", "that's all", "done"]
|
169 |
-
|
170 |
-
# Check for short responses over multiple turns
|
171 |
-
short_responses = len(user_text.split()) <= low_engagement_threshold
|
172 |
-
end_pattern_match = any(keyword in user_text.lower() for keyword in end_keywords)
|
173 |
-
|
174 |
-
# Recent short responses pattern
|
175 |
-
recent_short_responses = all(len(msg["content"].split()) <= low_engagement_threshold for msg in chat_history[-2:])
|
176 |
-
response_is_acknowledgment = user_text.lower() in ["yes", "okay", "alright"]
|
177 |
-
|
178 |
-
# Trigger health report prompt based on combination of patterns
|
179 |
-
if (end_pattern_match or (short_responses and recent_short_responses)) and not response_is_acknowledgment:
|
180 |
-
follow_up_question = "Would you like to have a report of your current health? Yes/No"
|
181 |
-
response += f"\n\n{follow_up_question}"
|
182 |
-
|
183 |
-
# Update the chat history
|
184 |
-
chat_history.append({'role': 'user', 'content': user_text})
|
185 |
-
chat_history.append({'role': 'assistant', 'content': response})
|
186 |
-
return response, chat_history
|
187 |
-
|
188 |
-
def get_summary_of_chat_history(chat_history, model2_id):
|
189 |
-
"""Generates a comprehensive summary of the chat history and a health report."""
|
190 |
-
hf = get_llm_hf_inference(model_id=model2_id, max_new_tokens=256)
|
191 |
-
if hf is None:
|
192 |
-
return "Error: Model not initialized."
|
193 |
-
|
194 |
-
# Format the chat content
|
195 |
-
chat_content = "\n".join([f"{message['role']}: {message['content']}" for message in chat_history])
|
196 |
-
|
197 |
-
# Improved summary prompt
|
198 |
-
prompt = PromptTemplate.from_template(
|
199 |
-
(
|
200 |
-
"Please analyze the following conversation to determine the user's emotional state and well-being."
|
201 |
-
"\nConsider the user's tone, engagement, and specific phrases they used. Based on the overall mood, "
|
202 |
-
"patterns of responses, and any indicators of stress, anxiety, positivity, or concern, generate a detailed "
|
203 |
-
"health summary that reflects the user's mental and emotional condition."
|
204 |
-
"\nAdditionally, provide recommendations or reassurance if needed.\n\nConversation:\n{chat_content}"
|
205 |
-
)
|
206 |
-
)
|
207 |
-
|
208 |
-
summary = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
|
209 |
-
summary_response = summary.invoke(input={"chat_content": chat_content})
|
210 |
-
|
211 |
-
return summary_response
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
# Load Whisper model for transcription
|
216 |
-
transcriber = load_transcription_model()
|
217 |
-
|
218 |
-
# User input for audio and text
|
219 |
-
st.markdown("### Choose your input:")
|
220 |
-
audio_file = st.file_uploader("Upload an audio file for transcription", type=["mp3", "wav", "m4a"])
|
221 |
-
st.session_state.user_text = st.chat_input(placeholder="Or enter your text here.")
|
222 |
-
|
223 |
-
# Check if audio file is uploaded and transcribe if available
|
224 |
-
if audio_file is not None and transcriber:
|
225 |
-
with st.spinner("Transcribing audio..."):
|
226 |
-
try:
|
227 |
-
st.session_state.user_text = transcribe_audio(audio_file, transcriber)
|
228 |
-
st.success("Audio transcribed successfully!")
|
229 |
-
except Exception as e:
|
230 |
-
st.error(f"Error transcribing audio: {e}")
|
231 |
-
|
232 |
-
# Chat interface
|
233 |
-
output_container = st.container()
|
234 |
-
|
235 |
-
# Display chat messages
|
236 |
-
with output_container:
|
237 |
-
for message in st.session_state.chat_history:
|
238 |
-
if message['role'] == 'system':
|
239 |
-
continue
|
240 |
-
with st.chat_message(message['role'], avatar=st.session_state['avatars'][message['role']]):
|
241 |
-
st.markdown(message['content'])
|
242 |
-
|
243 |
-
# Process text input for chatbot response
|
244 |
-
if st.session_state.user_text:
|
245 |
-
with st.chat_message("user", avatar=st.session_state.avatars['user']):
|
246 |
-
st.markdown(st.session_state.user_text)
|
247 |
-
|
248 |
-
with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
|
249 |
-
with st.spinner("Addressing your concerns..."):
|
250 |
-
response, st.session_state.chat_history = get_response(
|
251 |
-
system_message=st.session_state.system_message,
|
252 |
-
user_text=st.session_state.user_text,
|
253 |
-
chat_history=st.session_state.chat_history,
|
254 |
-
model_id=model_id,
|
255 |
-
max_new_tokens=st.session_state.max_response_length,
|
256 |
-
)
|
257 |
-
st.markdown(response)
|
258 |
-
|
259 |
-
# Check if the user has agreed to the report
|
260 |
-
if "yes" in st.session_state.user_text.lower() and "Would you like to have a report of your current health?" in response:
|
261 |
-
with st.spinner("Generating your health report..."):
|
262 |
-
report = get_summary_of_chat_history(st.session_state.chat_history, model2_id)
|
263 |
-
st.markdown(report)
|
264 |
-
|
265 |
-
# Main logic to switch between homepage and chatbot
|
266 |
-
if 'chat_started' not in st.session_state or not st.session_state.chat_started:
|
267 |
-
display_homepage()
|
268 |
-
else:
|
269 |
-
display_chatbot()
|
270 |
-
|
271 |
-
# Adjust the layout to reduce white space
|
272 |
-
st.markdown("<style>div.stContainer {padding-top: 0;}</style>", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|