import pickle

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the base GPT-2 checkpoint and put it in inference mode.
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
model.eval()

# GPT-2 ships without a pad token; reuse the EOS token so padded inputs work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.clean_up_tokenization_spaces = True

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


def save_memory(memory, filename='chat_memory.pkl'):
    with open(filename, 'wb') as f:
        pickle.dump(memory, f)


def load_memory(filename='chat_memory.pkl'):
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return []


# Restore any memory persisted by a previous session; load_memory() already
# returns an empty list when no file exists.
session_memory = load_memory()

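# pickle.load() will execute arbitrary code from a tampered file, so
# chat_memory.pkl should be treated as trusted local state only. A minimal
# sketch of a JSON-based alternative, assuming every memory entry stays
# JSON-serializable (the *_json names are illustrative, not from the
# original script):
#
#   import json
#
#   def save_memory_json(memory, filename='chat_memory.json'):
#       with open(filename, 'w', encoding='utf-8') as f:
#           json.dump(memory, f)
#
#   def load_memory_json(filename='chat_memory.json'):
#       try:
#           with open(filename, encoding='utf-8') as f:
#               return json.load(f)
#       except FileNotFoundError:
#           return []
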

def generate_response(prompt, max_new_tokens=25):
    # Truncate the prompt to the model's context window; the reply length is
    # controlled separately by max_new_tokens below.
    inputs = tokenizer(prompt, return_tensors='pt', padding=True,
                       truncation=True, max_length=tokenizer.model_max_length)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,  # counts only generated tokens, not the prompt
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id
        )

    response = tokenizer.decode(output[0], skip_special_tokens=True)

    # generate() echoes the prompt, so split off the first ("User: ...") line
    # and prefix Gertrude's signature onto the rest.
    parts = response.split("\n", 1)
    if len(parts) > 1:
        before_indent = parts[0].strip()
        after_indent = "vß Gertrude" + parts[1].strip()
        final_response = before_indent + '\n' + after_indent
    else:
        final_response = response.strip()

    return final_response

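# Example call (output varies from run to run because do_sample=True):
#   print(generate_response("User: Hello\nResponse:"))
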

def advanced_agi_chat(user_input):
    # Record the user's turn and persist it before generating a reply.
    session_memory.append({"input": user_input})
    save_memory(session_memory)

    prompt = f"User: {user_input}\nResponse:"
    return generate_response(prompt)


def chat_interface(user_input):
    # Thin wrapper so Gradio has a single callable to wire up.
    return advanced_agi_chat(user_input)


# Gradio login credentials: launch() accepts a (username, password) tuple or a
# list of such pairs.
auth = [("Tej", "186281mps"), ("ACC", "HIPE")]

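# A sketch of loading these from the environment instead of hardcoding them
# (the GRADIO_USER/GRADIO_PASS variable names are assumptions, not part of
# the original script):
#
#   import os
#   auth = [(os.environ["GRADIO_USER"], os.environ["GRADIO_PASS"])]
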
with gr.Blocks() as app:
    gr.Markdown("# **Autistic Assistant vß Edition 2024 Ultra: Gertrude's Autistic Experience**")

    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What will you say to Gertrude?", placeholder="Type something here...")
            submit_button = gr.Button("Send")
        with gr.Column(scale=1):
            chatbot = gr.Textbox(label="Gertrude's Response", interactive=False)

    # Inline CSS theming for the app container.
    gr.HTML("""
    <style>
        .gradio-container {
            background-color: #B3D9FF;
            padding: 20px;
            border-radius: 15px;
            font-family: 'Comic Sans MS';
        }
        .gradio-row {
            display: flex;
            justify-content: space-between;
        }
    </style>
    """)

    # Wire the Send button to the chat pipeline.
    submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)

# Require the credentials defined above before serving the UI.
app.launch(auth=auth)