Update app.py
Browse files
app.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1 |
-
# Hugging Face Space Adaptation for Autistic Assistant 2024 Ultra
|
2 |
-
|
3 |
# Install necessary libraries (if running locally)
|
4 |
# !pip install transformers torch textblob numpy gradio
|
5 |
|
6 |
# Import necessary libraries
|
7 |
import torch
|
8 |
-
import random
|
9 |
import torch.nn as nn
|
|
|
10 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
11 |
from textblob import TextBlob
|
12 |
import gradio as gr
|
13 |
import pickle
|
14 |
import numpy as np
|
|
|
15 |
|
16 |
# ---- Constants and Setup ----
|
17 |
model_name = 'gpt2'
|
@@ -51,11 +50,63 @@ def analyze_sentiment(text):
|
|
51 |
|
52 |
def adjust_for_emotion(response, sentiment):
|
53 |
if sentiment > 0.2:
|
54 |
-
return f"That's wonderful! I'm
|
55 |
elif sentiment < -0.2:
|
56 |
-
return f"I'm
|
57 |
return response
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# ---- Response Generation ----
|
60 |
def generate_response(prompt, max_length=1024):
|
61 |
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
|
@@ -105,13 +156,16 @@ def chat_interface(user_input):
|
|
105 |
return response
|
106 |
|
107 |
with gr.Blocks() as app:
|
108 |
-
gr.Markdown("# Autistic Assistant vß Edition 2024 Gertrude")
|
109 |
with gr.Row():
|
110 |
with gr.Column():
|
111 |
-
user_input = gr.Textbox(label="What will you say to Gertrude?", placeholder="Type something here...
|
112 |
submit_button = gr.Button("Send")
|
113 |
with gr.Column():
|
114 |
chatbot = gr.Textbox(label="Gertrude's Response", interactive=False)
|
|
|
|
|
|
|
115 |
|
116 |
submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)
|
117 |
|
|
|
|
|
|
|
1 |
# Install necessary libraries (if running locally)
|
2 |
# !pip install transformers torch textblob numpy gradio
|
3 |
|
4 |
# Import necessary libraries
|
5 |
import torch
|
|
|
6 |
import torch.nn as nn
|
7 |
+
import random
|
8 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
9 |
from textblob import TextBlob
|
10 |
import gradio as gr
|
11 |
import pickle
|
12 |
import numpy as np
|
13 |
+
import torch.nn.functional as F
|
14 |
|
15 |
# ---- Constants and Setup ----
|
16 |
model_name = 'gpt2'
|
|
|
50 |
|
51 |
def adjust_for_emotion(response, sentiment):
|
52 |
if sentiment > 0.2:
|
53 |
+
return f"That's wonderful! I'm glad you're feeling good: {response}"
|
54 |
elif sentiment < -0.2:
|
55 |
+
return f"I'm sorry to hear that: {response}. How can I assist you further?"
|
56 |
return response
|
57 |
|
58 |
+
# ---- Neural Network Models ----
|
59 |
+
|
60 |
+
# 1. RNN Model for Sentiment Classification
|
61 |
+
class RNN(nn.Module):
|
62 |
+
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
|
63 |
+
super(RNN, self).__init__()
|
64 |
+
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
65 |
+
self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
|
66 |
+
self.fc = nn.Linear(hidden_dim, output_dim)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
embedded = self.embedding(x)
|
70 |
+
out, _ = self.rnn(embedded)
|
71 |
+
out = out[:, -1, :] # Get the last hidden state
|
72 |
+
out = self.fc(out)
|
73 |
+
return out
|
74 |
+
|
75 |
+
# 2. CNN Model for Text Classification
|
76 |
+
class CNN(nn.Module):
|
77 |
+
def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, output_dim):
|
78 |
+
super(CNN, self).__init__()
|
79 |
+
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
80 |
+
self.convs = nn.ModuleList([
|
81 |
+
nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes
|
82 |
+
])
|
83 |
+
self.fc = nn.Linear(num_filters * len(filter_sizes), output_dim)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
embedded = self.embedding(x).unsqueeze(1) # Add channel dimension
|
87 |
+
conv_results = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
|
88 |
+
pooled_results = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in conv_results]
|
89 |
+
cat_results = torch.cat(pooled_results, dim=1)
|
90 |
+
out = self.fc(cat_results)
|
91 |
+
return out
|
92 |
+
|
93 |
+
# 3. Simple Feed-Forward Neural Network (NN) for additional processing
|
94 |
+
class FFNN(nn.Module):
|
95 |
+
def __init__(self, input_dim, hidden_dim, output_dim):
|
96 |
+
super(FFNN, self).__init__()
|
97 |
+
self.fc1 = nn.Linear(input_dim, hidden_dim)
|
98 |
+
self.fc2 = nn.Linear(hidden_dim, output_dim)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
x = F.relu(self.fc1(x))
|
102 |
+
x = self.fc2(x)
|
103 |
+
return x
|
104 |
+
|
105 |
+
# Initialize models
|
106 |
+
rnn_model = RNN(vocab_size=len(tokenizer), embedding_dim=100, hidden_dim=128, output_dim=2).to(device)
|
107 |
+
cnn_model = CNN(vocab_size=len(tokenizer), embedding_dim=100, num_filters=64, filter_sizes=[3, 4, 5], output_dim=2).to(device)
|
108 |
+
ffnn_model = FFNN(input_dim=100, hidden_dim=50, output_dim=1).to(device)
|
109 |
+
|
110 |
# ---- Response Generation ----
|
111 |
def generate_response(prompt, max_length=1024):
|
112 |
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
|
|
|
156 |
return response
|
157 |
|
158 |
with gr.Blocks() as app:
|
159 |
+
gr.Markdown("# **Autistic Assistant vß Edition 2024 Ultra: Gertrude's Autistic Experience**")
|
160 |
with gr.Row():
|
161 |
with gr.Column():
|
162 |
+
user_input = gr.Textbox(label="What will you say to Gertrude?", placeholder="Type something here... Expect 1-2 Minute Response Times...")
|
163 |
submit_button = gr.Button("Send")
|
164 |
with gr.Column():
|
165 |
chatbot = gr.Textbox(label="Gertrude's Response", interactive=False)
|
166 |
+
|
167 |
+
# Theme the UI with colors and a pattern
|
168 |
+
gr.HTML("<style>.gradio-container { background-color: #F4F8FF; padding: 20px; border-radius: 15px; font-family: 'Comic Sans MS'; }</style>")
|
169 |
|
170 |
submit_button.click(chat_interface, inputs=user_input, outputs=chatbot)
|
171 |
|