import torch
import pickle
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# ---- Constants and Setup ----
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()
# GPT-2 ships without a pad token, which would break tokenizer(..., padding=True);
# reusing the EOS token as padding is the standard workaround
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.clean_up_tokenization_spaces = True
# Set device for model and tensors
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# ---- Memory Management ----
session_memory = []
def save_memory(memory, filename='chat_memory.pkl'):
with open(filename, 'wb') as f:
pickle.dump(memory, f)
def load_memory(filename='chat_memory.pkl'):
try:
with open(filename, 'rb') as f:
return pickle.load(f)
    except (FileNotFoundError, EOFError):
        # No saved memory yet (or an empty file): start with a fresh history
        return []
session_memory = load_memory()
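# Example (illustrative only): the helpers round-trip any picklable object, so a
# quick sanity check looks like this (commented out to avoid side effects):
#   save_memory([{"input": "hello"}], filename='demo_memory.pkl')
#   assert load_memory('demo_memory.pkl') == [{"input": "hello"}]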
# ---- Response Generation ----
def generate_response(prompt, max_new_tokens=25):
    # Truncate overly long prompts so they still fit in GPT-2's context window
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,  # counts only newly generated tokens, so the prompt never eats the budget
            num_return_sequences=1,
            no_repeat_ngram_size=2,  # block verbatim bigram repetition
            do_sample=True,          # sample with temperature/top-k/top-p rather than decoding greedily
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id
        )
response = tokenizer.decode(output[0], skip_special_tokens=True)
    # Split the response at the first newline; everything after it is prefixed
    # with the "vß Gertrude" tag and treated as Gertrude's "inner thoughts"
    parts = response.split("\n", 1)
    if len(parts) > 1:
        before_indent = parts[0].strip()
        after_indent = "vß Gertrude " + parts[1].strip()
        final_response = before_indent + '\n' + after_indent
    else:
        final_response = response.strip()
    return final_response
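# Example (illustrative only): calling the generator directly with a chat-style
# prompt; the decoded output echoes the prompt followed by the sampled reply.
#   generate_response("User: How are you today?\nResponse:")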
# ---- Interactive Chat Function ----
def advanced_agi_chat(user_input):
    # Persist the user input (only inputs are stored, not Gertrude's replies)
    session_memory.append({"input": user_input})
    save_memory(session_memory)
# Generate the response based on the prompt
prompt = f"User: {user_input}\nResponse:"
response = generate_response(prompt)
return response
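# Example (illustrative only): one complete chat turn, which appends the input
# to the on-disk session memory and then samples Gertrude's reply.
#   reply = advanced_agi_chat("Hi Gertrude!")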
# ---- Gradio Interface ----
def chat_interface(user_input):
response = advanced_agi_chat(user_input)
return response
# ---- Gradio App Setup ----
# Login credentials as (username, password) pairs for the Gradio login screen;
# the pairing of the original flat tuple is assumed, and the list is passed to
# app.launch() below
auth = [("Tej", "186281mps"), ("ACC", "HIPE")]
with gr.Blocks() as app:
gr.Markdown("# **Autistic Assistant vß Edition 2024 Ultra: Gertrude's Autistic Experience**")
with gr.Row():
with gr.Column(scale=1):
user_input = gr.Textbox(label="What will you say to Gertrude?", placeholder="Type something here...")
submit_button = gr.Button("Send")
with gr.Column(scale=1):
            chatbot = gr.Textbox(label="Gertrude's Response", interactive=False)  # read-only Textbox used for output
# Adding custom styling for the UI
gr.HTML("""
<style>
.gradio-container {
background-color: #B3D9FF;
padding: 20px;
border-radius: 15px;
font-family: 'Comic Sans MS';
}
.gradio-row {
display: flex;
justify-content: space-between;
}
</style>
""")
# Setting the button click event
submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)
# Launch the Gradio app
app.launch(auth=auth)