import torch
import pickle

import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from textblob import TextBlob

# ---- Constants and Setup ----
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

# Ensure the tokenizer has a pad token (GPT-2 ships without one)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.clean_up_tokenization_spaces = True

# Set device for model and tensors
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# ---- Memory Management ----
def save_memory(memory, filename='chat_memory.pkl'):
    with open(filename, 'wb') as f:
        pickle.dump(memory, f)

def load_memory(filename='chat_memory.pkl'):
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return []

session_memory = load_memory()
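
# Note: session_memory is persisted above but never folded back into the
# prompt. A minimal sketch of how recent turns could be included
# (hypothetical helper, not wired into the app; max_turns is an assumed
# parameter):
def build_prompt_with_memory(user_input, max_turns=3):
    recent = session_memory[-max_turns:]
    history = "\n".join(f"User: {turn['input']}" for turn in recent)
    return f"{history}\nUser: {user_input}\nAutistic-Gertrude:"
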
# ---- Sentiment Analysis ----
def analyze_sentiment(text):
    blob = TextBlob(text)
    return blob.sentiment.polarity  # Ranges from -1 (negative) to 1 (positive)

def adjust_for_emotion(response, sentiment):
    if sentiment > 0.2:
        return f"That's wonderful! I'm glad you're feeling good: {response}"
    elif sentiment < -0.2:
        return f"I'm sorry to hear that: {response}. How can I assist you further?"
    return response
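
# Illustrative behavior (approximate; TextBlob's pattern lexicon sets the
# exact scores): analyze_sentiment("I love this!") comes out strongly
# positive, analyze_sentiment("This is terrible.") strongly negative.
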
# ---- Response Generation ----
def generate_response(prompt, max_length=512):
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    pad_token_id = tokenizer.pad_token_id

    with torch.no_grad():
        # Note: max_length counts prompt tokens plus generated tokens, so a
        # long prompt leaves correspondingly less room for the reply.
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            pad_token_id=pad_token_id
        )
    response = tokenizer.decode(output[0], skip_special_tokens=True)

    # Split the response at the first newline and color the two parts
    parts = response.split("\n", 1)
    if len(parts) > 1:
        before_indent = f'<span style="color: orange;">{parts[0].strip()}</span>'
        after_indent = f'<span style="color: blue;">Inner Thoughts: {parts[1].strip()}</span>'
        colored_response = before_indent + '\n' + after_indent
    else:
        colored_response = f'<span style="color: orange;">{response.strip()}</span>'

    return colored_response
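
# If only the reply length should be capped (independent of prompt length),
# generate() also accepts max_new_tokens in place of max_length, e.g.
# model.generate(input_ids, max_new_tokens=256, ...) — shown here as an
# alternative, not what the app currently does.
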
# ---- Interactive Chat Function ----
def advanced_agi_chat(user_input):
    session_memory.append({"input": user_input})
    save_memory(session_memory)

    # Sentiment analysis of user input
    user_sentiment = analyze_sentiment(user_input)

    # Generate the response based on the prompt
    prompt = f"User: {user_input}\nAutistic-Gertrude:"
    response = generate_response(prompt)

    # Adjust the response based on sentiment
    adjusted_response = adjust_for_emotion(response, user_sentiment)
    return adjusted_response
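
# Example round trip (illustrative): advanced_agi_chat("Hello!") appends the
# turn to session_memory, persists it to chat_memory.pkl, and returns an
# HTML-colored response string.
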
# ---- Gradio Interface ----
def chat_interface(user_input):
    return advanced_agi_chat(user_input)

# ---- Gradio App Setup ----
with gr.Blocks() as app:
    gr.Markdown("# **Autistic Assistant vß Edition 2024 Ultra: Gertrude's Autistic Experience**")

    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What will you say to Gertrude?", placeholder="Type something here... Expect 1-2 Minute Response Times...")
            submit_button = gr.Button("Send")
        with gr.Column(scale=1):
            # A Textbox displays the <span> color tags as literal text; swap
            # in gr.HTML if the colored styling should actually render
            chatbot = gr.Textbox(label="Gertrude's Response", interactive=False)

    # Adding custom styling for the UI
    gr.HTML("""
    <style>
        .gradio-container {
            background-color: #F4F8FF;
            padding: 20px;
            border-radius: 15px;
            font-family: 'Comic Sans MS';
        }
        .gradio-row {
            display: flex;
            justify-content: space-between;
        }
    </style>
    """)

    # Setting the button click event
    submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)
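
    # Optionally, pressing Enter in the input box could trigger the same
    # handler (an assumed nicety, not part of the original wiring):
    # user_input.submit(chat_interface, inputs=user_input, outputs=chatbot)
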
# Launch the Gradio app
app.launch()