Update app.py
app.py CHANGED
@@ -88,12 +88,12 @@ criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 train_dataset = MemoryEfficientDataset(X_train, y_train, batch_size=32)
-train_loader = DataLoader(train_dataset, batch_size=None)
+train_loader = DataLoader(train_dataset, batch_size=None, num_workers=4, pin_memory=True)
 
 num_epochs = 100
 for epoch in range(num_epochs):
     for batch_X, batch_y in train_loader:
-        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
+        batch_X, batch_y = batch_X.to(device, non_blocking=True), batch_y.to(device, non_blocking=True)
         outputs = model(batch_X)
         loss = criterion(outputs, batch_y)
         optimizer.zero_grad()
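Note on this hunk: pin_memory=True on the loader and non_blocking=True on the copies only pay off together. A minimal sketch of the pairing, assuming a CUDA device is available; ToyIterableDataset is a hypothetical stand-in for MemoryEfficientDataset:

import torch
from torch.utils.data import DataLoader, IterableDataset

class ToyIterableDataset(IterableDataset):
    # Hypothetical stand-in for MemoryEfficientDataset: it yields
    # ready-made (X, y) batches, which is why the DataLoader is created
    # with batch_size=None (no extra collation).
    def __iter__(self):
        for _ in range(4):
            yield torch.randn(32, 10), torch.randint(0, 2, (32,))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# pin_memory=True stages each batch in page-locked host RAM ...
loader = DataLoader(ToyIterableDataset(), batch_size=None, num_workers=0, pin_memory=True)

for batch_X, batch_y in loader:
    # ... so that non_blocking=True can overlap the host-to-device copy
    # with GPU work; without pinned memory the copy silently stays synchronous.
    batch_X = batch_X.to(device, non_blocking=True)
    batch_y = batch_y.to(device, non_blocking=True)

One caveat on the committed num_workers=4: if MemoryEfficientDataset is an IterableDataset, each worker iterates the full dataset, so every batch is produced four times unless the dataset shards itself via torch.utils.data.get_worker_info(); if it is map-style, this does not apply.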
@@ -130,8 +130,11 @@ emotions = {
     'optimism': {'percentage': 10, 'motivation': 'hopeful', 'intensity': 0},
     'pessimism': {'percentage': 10, 'motivation': 'doubtful', 'intensity': 0},
     'boredom': {'percentage': 10, 'motivation': 'indifferent', 'intensity': 0},
-    'envy': {'percentage': 10, 'motivation': 'jealous', 'intensity': 0}
-}
+    'envy': {'percentage': 10, 'motivation': 'jealous', 'intensity': 0}
+}
+total_percentage = 200
+default_percentage = total_percentage / len(emotions)
+for emotion in emotion_history_file = 'emotion_history.json'
 
 def load_historical_data(file_path=emotion_history_file):
     if os.path.exists(file_path):
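The last added line above is not valid Python as the diff renders it ("for emotion in" cannot take an assignment), so either the commit is broken at this point or the diff view truncated a longer block. A hypothetical reading consistent with the total_percentage/default_percentage additions:

# Hypothetical reconstruction, not the committed text as rendered:
for emotion in emotions:
    emotions[emotion]['percentage'] = default_percentage  # sums back to total_percentage

emotion_history_file = 'emotion_history.json'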
@@ -199,6 +202,16 @@ def evolve_emotions():
     emotions['ideal_state']['percentage'] = ideal_state
 
 # Lazy loading for the language models
+_distilgpt3_tokenizer = None
+_distilgpt3_lm_model = None
+def get_distilgpt3_model():
+    global _distilgpt3_tokenizer, _distilgpt3_lm_model
+    if _distilgpt3_tokenizer is None or _distilgpt3_lm_model is None:
+        distilgpt3_model_name = 'distilgpt2'  # Replace with the fine-tuned DistilGPT-3 model name
+        _distilgpt3_tokenizer = AutoTokenizer.from_pretrained(distilgpt3_model_name)
+        _distilgpt3_lm_model = AutoModelForCausalLM.from_pretrained(distilgpt3_model_name, device_map="auto", low_cpu_mem_usage=True)
+    return _distilgpt3_tokenizer, _distilgpt3_lm_model
+
 _bloom_tokenizer = None
 _bloom_lm_model = None
 def get_bloom_model():
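The commit duplicates the lazy-getter shape once per model. A hedged alternative that collapses both getters into one memoized helper (get_causal_lm is hypothetical, not part of the commit):

from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer

@lru_cache(maxsize=None)
def get_causal_lm(model_name):
    # First call loads tokenizer and weights; later calls return the
    # cached pair, matching the None-check-and-assign pattern above.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", low_cpu_mem_usage=True
    )
    return tokenizer, model

With this helper, get_distilgpt3_model() and get_bloom_model() reduce to get_causal_lm('distilgpt2') and get_causal_lm(bloom_model_name).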
@@ -209,22 +222,12 @@ def get_bloom_model():
         _bloom_lm_model = AutoModelForCausalLM.from_pretrained(bloom_model_name, device_map="auto", low_cpu_mem_usage=True)
     return _bloom_tokenizer, _bloom_lm_model
 
-_gpt_tokenizer = None
-_gpt_lm_model = None
-def get_gpt_model():
-    global _gpt_tokenizer, _gpt_lm_model
-    if _gpt_tokenizer is None or _gpt_lm_model is None:
-        gpt_model_name = 'gpt2-medium'
-        _gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
-        _gpt_lm_model = AutoModelForCausalLM.from_pretrained(gpt_model_name, device_map="auto", low_cpu_mem_usage=True)
-    return _gpt_tokenizer, _gpt_lm_model
-
-def generate_text(prompt, max_length=100, model_type='bloom'):
-    if model_type == 'bloom':
-        bloom_tokenizer, bloom_lm_model = get_bloom_model()
-        input_ids = bloom_tokenizer.encode(prompt, return_tensors='pt').to(bloom_lm_model.device)
+def generate_text(prompt, max_length=100, model_type='distilgpt3'):
+    if model_type == 'distilgpt3':
+        distilgpt3_tokenizer, distilgpt3_lm_model = get_distilgpt3_model()
+        input_ids = distilgpt3_tokenizer.encode(prompt, return_tensors='pt').to(distilgpt3_lm_model.device)
         with torch.no_grad():
-            output = bloom_lm_model.generate(
+            output = distilgpt3_lm_model.generate(
                 input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
@@ -234,12 +237,12 @@ def generate_text(prompt, max_length=100, model_type='bloom'):
                 top_p=0.95,
                 temperature=0.7
             )
-        generated_text = bloom_tokenizer.decode(output[0], skip_special_tokens=True)
-    elif model_type == 'gpt':
-        gpt_tokenizer, gpt_lm_model = get_gpt_model()
-        input_ids = gpt_tokenizer.encode(prompt, return_tensors='pt').to(gpt_lm_model.device)
+        generated_text = distilgpt3_tokenizer.decode(output[0], skip_special_tokens=True)
+    elif model_type == 'bloom':
+        bloom_tokenizer, bloom_lm_model = get_bloom_model()
+        input_ids = bloom_tokenizer.encode(prompt, return_tensors='pt').to(bloom_lm_model.device)
         with torch.no_grad():
-            output = gpt_lm_model.generate(
+            output = bloom_lm_model.generate(
                 input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
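Neither generate() call visible in the diff passes do_sample=True; unless it is set on the unchanged lines elided between hunks, transformers decodes greedily and silently ignores top_p and temperature. A runnable sketch of the sampling form, using the same distilgpt2 checkpoint the commit loads:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
lm_model = AutoModelForCausalLM.from_pretrained('distilgpt2')
input_ids = tokenizer.encode("Hello, I feel", return_tensors='pt')
with torch.no_grad():
    output = lm_model.generate(
        input_ids,
        max_length=50,
        num_return_sequences=1,
        do_sample=True,  # required for top_p/temperature to take effect
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 family has no pad token
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))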
@@ -249,9 +252,9 @@ def generate_text(prompt, max_length=100, model_type='bloom'):
                 top_p=0.95,
                 temperature=0.7
             )
-        generated_text = gpt_tokenizer.decode(output[0], skip_special_tokens=True)
+        generated_text = bloom_tokenizer.decode(output[0], skip_special_tokens=True)
     else:
-        raise ValueError("Invalid model type. Choose 'bloom' or 'gpt'.")
+        raise ValueError("Invalid model type. Choose 'distilgpt3' or 'bloom'.")
 
     return generated_text
 
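For reference, a call against the updated dispatcher (output varies run to run):

reply = generate_text("How are you feeling today?", max_length=60, model_type='distilgpt3')
print(reply)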
@@ -270,25 +273,25 @@ def process_input(text):
 
         rf_prediction = rf_model.predict(encoded_text)[0]
         isolation_score = isolation_forest.decision_function(encoded_text)[0]
-        nn_output = model(torch.LongTensor(encoded_text.toarray()).to(device))
+        nn_output = model(torch.LongTensor(encoded_text.toarray()).to(device, non_blocking=True))
         nn_prediction = nn_output.argmax(dim=1).item()
 
         predicted_emotion = emotion_classes[rf_prediction]
         sentiment_score = isolation_score
+        distilgpt3_generated_text = generate_text(normalized_text, model_type='distilgpt3')
         bloom_generated_text = generate_text(normalized_text, model_type='bloom')
-        gpt_generated_text = generate_text(normalized_text, model_type='gpt')
 
         historical_data = load_historical_data()
         historical_data.append({
             'context': text,
             'predicted_emotion': predicted_emotion,
             'sentiment_score': sentiment_score,
-            'bloom_generated_text': bloom_generated_text,
-            'gpt_generated_text': gpt_generated_text
+            'distilgpt3_generated_text': distilgpt3_generated_text,
+            'bloom_generated_text': bloom_generated_text
         })
         save_historical_data(historical_data)
 
-        return predicted_emotion, sentiment_score, bloom_generated_text, gpt_generated_text
+        return predicted_emotion, sentiment_score, distilgpt3_generated_text, bloom_generated_text
 
     except Exception as e:
         error_message = f"An error occurred: {str(e)}"
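process_input now returns four values to match the four output textboxes declared below, but the except branch's return statement falls outside the hunk; if it returns a single string, Gradio cannot unpack it into four outputs. A hypothetical guard (process_input_safe is not in the commit):

def process_input_safe(text):
    # Wrap process_input so the error path also yields four values.
    try:
        return process_input(text)
    except Exception as e:
        msg = f"An error occurred: {e}"
        return msg, msg, msg, msg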
@@ -301,8 +304,8 @@ iface = gr.Interface(
     outputs=[
         gr.Textbox(label="Emotional Response"),
         gr.Textbox(label="Sentiment Response"),
-        gr.Textbox(label="BLOOM Generated Text"),
-        gr.Textbox(label="GPT Generated Text")
+        gr.Textbox(label="DistilGPT-3 Generated Text"),
+        gr.Textbox(label="BLOOM Generated Text")
     ],
     live=True
 )