Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -63,9 +63,13 @@ emotion_prediction_tokenizer = AutoTokenizer.from_pretrained("bhadresh-savani/di
|
|
63 |
# Load pre-trained large language model and tokenizer for response generation with increased context window
|
64 |
response_model_name = "gpt2-xl"
|
65 |
response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
|
|
|
|
|
66 |
with init_empty_weights():
|
67 |
response_model = AutoModelForCausalLM.from_pretrained(response_model_name)
|
68 |
response_model.tie_weights()
|
|
|
|
|
69 |
# Set the pad token
|
70 |
response_tokenizer.pad_token = response_tokenizer.eos_token
|
71 |
|
@@ -183,149 +187,108 @@ def generate_response(input_text, ai_emotion, conversation_history):
|
|
183 |
for entry in conversation_history[-100:]: # Use last 100 entries for context
|
184 |
prompt = f"Human: {entry['user']}\nAI: {entry['response']}\n" + prompt
|
185 |
|
186 |
-
inputs = response_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=8192)
|
187 |
|
188 |
# Adjust generation parameters based on emotion
|
189 |
temperature = 0.7
|
190 |
-
if ai_emotion == 'anger':
|
191 |
-
temperature = 0.9 #
|
192 |
-
elif ai_emotion == '
|
193 |
-
temperature = 0.5 #
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
pad_token_id=response_tokenizer.eos_token_id
|
207 |
-
)
|
208 |
-
response = response_tokenizer.decode(response_ids[0], skip_special_tokens=True)
|
209 |
-
|
210 |
-
# Extract only the AI's response
|
211 |
-
response = response.split("AI:")[-1].strip()
|
212 |
return response
|
213 |
|
214 |
-
def
|
215 |
-
|
|
|
216 |
with torch.no_grad():
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
sentiment_scores = sia.polarity_scores(text)
|
226 |
-
return sentiment_scores
|
227 |
-
|
228 |
-
def extract_entities(text):
|
229 |
-
chunked = ne_chunk(pos_tag(word_tokenize(text)))
|
230 |
-
entities = []
|
231 |
-
for chunk in chunked:
|
232 |
-
if hasattr(chunk, 'label'):
|
233 |
-
entities.append(((' '.join(c[0] for c in chunk)), chunk.label()))
|
234 |
-
return entities
|
235 |
-
|
236 |
-
def analyze_text_complexity(text):
|
237 |
-
blob = TextBlob(text)
|
238 |
-
return {
|
239 |
-
'word_count': len(blob.words),
|
240 |
-
'sentence_count': len(blob.sentences),
|
241 |
-
'average_sentence_length': len(blob.words) / len(blob.sentences) if len(blob.sentences) > 0 else 0,
|
242 |
-
'polarity': blob.sentiment.polarity,
|
243 |
-
'subjectivity': blob.sentiment.subjectivity
|
244 |
-
}
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
|
|
|
|
249 |
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
-
|
|
|
|
|
259 |
|
260 |
-
|
|
|
|
|
|
|
261 |
global conversation_history
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
}
|
302 |
-
|
303 |
-
return analysis_result
|
304 |
-
except Exception as e:
|
305 |
-
print(f"Error: {e}")
|
306 |
-
return {
|
307 |
-
'predicted_user_emotion': 'unknown',
|
308 |
-
'ai_emotion': 'neutral',
|
309 |
-
'sentiment_scores': {'compound': 0, 'neg': 0, 'neu': 1, 'pos': 0},
|
310 |
-
'entities': [],
|
311 |
-
'text_complexity': {'word_count': 0, 'sentence_count': 0, 'average_sentence_length': 0, 'polarity': 0, 'subjectivity': 0},
|
312 |
-
'response': "I'm sorry, but I encountered an error and was unable to generate a response.",
|
313 |
-
'emotion_visualization': 'emotional_state.png'
|
314 |
-
}
|
315 |
-
|
316 |
-
# Create a Gradio interface
|
317 |
-
gr.Interface(
|
318 |
-
fn=interactive_interface,
|
319 |
-
inputs=gr.Textbox(label="Your Message"),
|
320 |
-
outputs=[
|
321 |
-
gr.Textbox(label="Predicted User Emotion"),
|
322 |
-
gr.Textbox(label="AI Emotion"),
|
323 |
-
gr.Textbox(label="Sentiment Scores"),
|
324 |
-
gr.Textbox(label="Extracted Entities"),
|
325 |
-
gr.Textbox(label="Text Complexity"),
|
326 |
-
gr.Textbox(label="AI Response"),
|
327 |
-
gr.Image(label="Emotional State Visualization")
|
328 |
-
],
|
329 |
-
title="Emotion-Aware AI Assistant by Sephfox",
|
330 |
-
description="Interact with an AI assistant created by Sephfox that responds based on its emotional state.",
|
331 |
-
).launch()
|
|
|
63 |
# Load pre-trained large language model and tokenizer for response generation with increased context window
|
64 |
response_model_name = "gpt2-xl"
|
65 |
response_tokenizer = AutoTokenizer.from_pretrained(response_model_name)
|
66 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
67 |
+
|
68 |
with init_empty_weights():
|
69 |
response_model = AutoModelForCausalLM.from_pretrained(response_model_name)
|
70 |
response_model.tie_weights()
|
71 |
+
response_model.to(device)
|
72 |
+
|
73 |
# Set the pad token
|
74 |
response_tokenizer.pad_token = response_tokenizer.eos_token
|
75 |
|
|
|
187 |
for entry in conversation_history[-100:]: # Use last 100 entries for context
|
188 |
prompt = f"Human: {entry['user']}\nAI: {entry['response']}\n" + prompt
|
189 |
|
190 |
+
inputs = response_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=8192).to(device)
|
191 |
|
192 |
# Adjust generation parameters based on emotion
|
193 |
temperature = 0.7
|
194 |
+
if (ai_emotion == 'anger'):
|
195 |
+
temperature = 0.9 # more intense
|
196 |
+
elif (ai_emotion == 'calmness'):
|
197 |
+
temperature = 0.5 # more composed
|
198 |
+
|
199 |
+
outputs = response_model.generate(
|
200 |
+
inputs['input_ids'],
|
201 |
+
max_length=500,
|
202 |
+
num_return_sequences=1,
|
203 |
+
temperature=temperature,
|
204 |
+
pad_token_id=response_tokenizer.eos_token_id
|
205 |
+
)
|
206 |
+
|
207 |
+
response = response_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
208 |
+
response = response.replace(prompt, "").strip()
|
209 |
+
conversation_history.append({'user': input_text, 'response': response})
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
return response
|
211 |
|
212 |
+
def process_input(input_text):
|
213 |
+
# Predict emotion of the input text
|
214 |
+
inputs = emotion_prediction_tokenizer(input_text, return_tensors='pt', truncation=True, padding=True).to(device)
|
215 |
with torch.no_grad():
|
216 |
+
logits = emotion_prediction_model(**inputs).logits
|
217 |
+
|
218 |
+
predicted_class_id = torch.argmax(logits, dim=1).item()
|
219 |
+
predicted_emotion = emotion_classes[predicted_class_id]
|
220 |
+
|
221 |
+
# Update emotion percentages and intensities based on predicted emotion
|
222 |
+
update_emotion(predicted_emotion, 5, 5) # Example increment values
|
223 |
+
update_emotion_history(predicted_emotion, emotions[predicted_emotion]['percentage'], emotions[predicted_emotion]['intensity'], input_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
+
# Evolve emotions
|
226 |
+
evolve_emotions()
|
227 |
+
|
228 |
+
# Generate response
|
229 |
+
response = generate_response(input_text, predicted_emotion, conversation_history)
|
230 |
|
231 |
+
# Feature transformations
|
232 |
+
feature_transformations()
|
233 |
+
|
234 |
+
return response
|
235 |
+
|
236 |
+
def plot_emotion_distribution():
|
237 |
+
emotion_labels = list(emotions.keys())
|
238 |
+
emotion_percentages = [emotions[emotion]['percentage'] for emotion in emotion_labels]
|
239 |
+
emotion_intensities = [emotions[emotion]['intensity'] for emotion in emotion_labels]
|
240 |
+
|
241 |
+
fig, ax1 = plt.subplots(figsize=(10, 6))
|
242 |
+
|
243 |
+
ax2 = ax1.twinx()
|
244 |
+
ax1.bar(emotion_labels, emotion_percentages, color='b', alpha=0.6)
|
245 |
+
ax2.plot(emotion_labels, emotion_intensities, color='r', marker='o', linestyle='dashed', linewidth=2)
|
246 |
|
247 |
+
ax1.set_xlabel('Emotion')
|
248 |
+
ax1.set_ylabel('Percentage', color='b')
|
249 |
+
ax2.set_ylabel('Intensity', color='r')
|
250 |
|
251 |
+
plt.title('Emotion Distribution and Intensities')
|
252 |
+
plt.show()
|
253 |
+
|
254 |
+
def clear_conversation_history():
|
255 |
global conversation_history
|
256 |
+
conversation_history = []
|
257 |
+
|
258 |
+
# Function to display the history of the 10 most recent conversations
|
259 |
+
def display_recent_conversations():
|
260 |
+
num_conversations = min(len(conversation_history), 10)
|
261 |
+
recent_conversations = conversation_history[-num_conversations:]
|
262 |
+
|
263 |
+
conversation_text = ""
|
264 |
+
for i, conversation in enumerate(recent_conversations, start=1):
|
265 |
+
conversation_text += f"Conversation {i}:\n"
|
266 |
+
conversation_text += f"User: {conversation['user']}\n"
|
267 |
+
conversation_text += f"AI: {conversation['response']}\n\n"
|
268 |
+
|
269 |
+
return conversation_text.strip()
|
270 |
+
|
271 |
+
with gr.Blocks() as chatbot:
|
272 |
+
gr.Markdown("# AI Chatbot with Enhanced Emotions")
|
273 |
+
|
274 |
+
with gr.Row():
|
275 |
+
with gr.Column():
|
276 |
+
input_text = gr.Textbox(label="Input Text")
|
277 |
+
response_text = gr.Textbox(label="Response", interactive=False)
|
278 |
+
send_button = gr.Button("Send")
|
279 |
+
clear_button = gr.Button("Clear Conversation History")
|
280 |
+
|
281 |
+
with gr.Row():
|
282 |
+
recent_conversations = gr.Textbox(label="Recent Conversations", interactive=False)
|
283 |
+
update_button = gr.Button("Update Recent Conversations")
|
284 |
+
|
285 |
+
with gr.Row():
|
286 |
+
emotion_plot = gr.Plot(label="Emotion Distribution and Intensities")
|
287 |
+
update_plot_button = gr.Button("Update Emotion Plot")
|
288 |
+
|
289 |
+
send_button.click(fn=process_input, inputs=input_text, outputs=response_text)
|
290 |
+
clear_button.click(fn=clear_conversation_history)
|
291 |
+
update_button.click(fn=display_recent_conversations, outputs=recent_conversations)
|
292 |
+
update_plot_button.click(fn=plot_emotion_distribution, outputs=emotion_plot)
|
293 |
+
|
294 |
+
chatbot.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|