Sephfox committed
Commit 1a498d3 · verified · 1 Parent(s): faa570b

Update app.py

Files changed (1):
  1. app.py +36 -33
app.py CHANGED
@@ -88,12 +88,12 @@ criterion = nn.CrossEntropyLoss()
 optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 train_dataset = MemoryEfficientDataset(X_train, y_train, batch_size=32)
-train_loader = DataLoader(train_dataset, batch_size=None)
+train_loader = DataLoader(train_dataset, batch_size=None, num_workers=4, pin_memory=True)
 
 num_epochs = 100
 for epoch in range(num_epochs):
     for batch_X, batch_y in train_loader:
-        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
+        batch_X, batch_y = batch_X.to(device, non_blocking=True), batch_y.to(device, non_blocking=True)
         outputs = model(batch_X)
         loss = criterion(outputs, batch_y)
         optimizer.zero_grad()
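Pairing pin_memory=True in the DataLoader with non_blocking=True in the .to(device) calls is what lets the host-to-GPU copies overlap with compute. A minimal, self-contained sketch of the same pattern (the dataset, shapes, and sizes here are placeholders, not the app's real ones):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Pinned (page-locked) host memory enables asynchronous copies to the GPU.
    dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 4, (1024,)))
    loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

    for batch_X, batch_y in loader:
        # non_blocking=True only has an effect when the source tensor is pinned.
        batch_X = batch_X.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

Note that the diff keeps batch_size=None, which turns off automatic batching so MemoryEfficientDataset can yield ready-made batches itself; the sketch uses the stock batching path instead.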
@@ -130,8 +130,11 @@ emotions = {
     'optimism': {'percentage': 10, 'motivation': 'hopeful', 'intensity': 0},
     'pessimism': {'percentage': 10, 'motivation': 'doubtful', 'intensity': 0},
     'boredom': {'percentage': 10, 'motivation': 'indifferent', 'intensity': 0},
-    'envy': {'percentage': 10, 'motivation': 'jealous', 'intensity': 0},
-emotion_history_file = 'emotion_history.json'
+    'envy': {'percentage': 10, 'motivation': 'jealous', 'intensity': 0}
+}
+total_percentage = 200
+default_percentage = total_percentage / len(emotions)
+for emotion in emotions:
+    emotions[emotion]['percentage'] = default_percentage
+emotion_history_file = 'emotion_history.json'
 
 def load_historical_data(file_path=emotion_history_file):
     if os.path.exists(file_path):
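The added loop resets every emotion to an equal share of total_percentage, i.e. total_percentage / len(emotions) each; with 20 entries, that is 200 / 20 = 10. A quick invariant check, assuming the emotions dict and total_percentage from the hunk above:

    # Hypothetical sanity check; 'emotions' and 'total_percentage' come from the diff.
    assert sum(e['percentage'] for e in emotions.values()) == total_percentage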
@@ -199,6 +202,16 @@ def evolve_emotions():
     emotions['ideal_state']['percentage'] = ideal_state
 
 # Lazy loading for the language models
+_distilgpt3_tokenizer = None
+_distilgpt3_lm_model = None
+def get_distilgpt3_model():
+    global _distilgpt3_tokenizer, _distilgpt3_lm_model
+    if _distilgpt3_tokenizer is None or _distilgpt3_lm_model is None:
+        distilgpt3_model_name = 'distilgpt2'  # Replace with the fine-tuned DistilGPT-3 model name
+        _distilgpt3_tokenizer = AutoTokenizer.from_pretrained(distilgpt3_model_name)
+        _distilgpt3_lm_model = AutoModelForCausalLM.from_pretrained(distilgpt3_model_name, device_map="auto", low_cpu_mem_usage=True)
+    return _distilgpt3_tokenizer, _distilgpt3_lm_model
+
 _bloom_tokenizer = None
 _bloom_lm_model = None
 def get_bloom_model():
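Despite the distilgpt3 naming, the checkpoint actually loaded is distilgpt2, as the inline comment concedes. Both getters follow the same module-level lazy-singleton pattern: nothing is downloaded at import time, the first call loads the weights, and every later call returns the cached pair. The pattern reduced to its essentials (names here are illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    _tokenizer = None
    _lm_model = None

    def get_model(model_name='distilgpt2'):
        """Load the tokenizer/model once, then serve the cached instances."""
        global _tokenizer, _lm_model
        if _tokenizer is None or _lm_model is None:
            _tokenizer = AutoTokenizer.from_pretrained(model_name)
            _lm_model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto", low_cpu_mem_usage=True
            )
        return _tokenizer, _lm_model

device_map="auto" requires the accelerate package to be installed; without it, from_pretrained raises.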
@@ -209,22 +222,12 @@ def get_bloom_model():
         _bloom_lm_model = AutoModelForCausalLM.from_pretrained(bloom_model_name, device_map="auto", low_cpu_mem_usage=True)
     return _bloom_tokenizer, _bloom_lm_model
 
-_gpt_tokenizer = None
-_gpt_lm_model = None
-def get_gpt_model():
-    global _gpt_tokenizer, _gpt_lm_model
-    if _gpt_tokenizer is None or _gpt_lm_model is None:
-        gpt_model_name = 'gpt2-medium'
-        _gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
-        _gpt_lm_model = AutoModelForCausalLM.from_pretrained(gpt_model_name, device_map="auto", low_cpu_mem_usage=True)
-    return _gpt_tokenizer, _gpt_lm_model
-
-def generate_text(prompt, max_length=100, model_type='bloom'):
-    if model_type == 'bloom':
-        bloom_tokenizer, bloom_lm_model = get_bloom_model()
-        input_ids = bloom_tokenizer.encode(prompt, return_tensors='pt').to(bloom_lm_model.device)
+def generate_text(prompt, max_length=100, model_type='distilgpt3'):
+    if model_type == 'distilgpt3':
+        distilgpt3_tokenizer, distilgpt3_lm_model = get_distilgpt3_model()
+        input_ids = distilgpt3_tokenizer.encode(prompt, return_tensors='pt').to(distilgpt3_lm_model.device)
         with torch.no_grad():
-            output = bloom_lm_model.generate(
+            output = distilgpt3_lm_model.generate(
                 input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
@@ -234,12 +237,12 @@ def generate_text(prompt, max_length=100, model_type='bloom'):
                 top_p=0.95,
                 temperature=0.7
             )
-        generated_text = bloom_tokenizer.decode(output[0], skip_special_tokens=True)
-    elif model_type == 'gpt':
-        gpt_tokenizer, gpt_lm_model = get_gpt_model()
-        input_ids = gpt_tokenizer.encode(prompt, return_tensors='pt').to(gpt_lm_model.device)
+        generated_text = distilgpt3_tokenizer.decode(output[0], skip_special_tokens=True)
+    elif model_type == 'bloom':
+        bloom_tokenizer, bloom_lm_model = get_bloom_model()
+        input_ids = bloom_tokenizer.encode(prompt, return_tensors='pt').to(bloom_lm_model.device)
         with torch.no_grad():
-            output = gpt_lm_model.generate(
+            output = bloom_lm_model.generate(
                 input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
@@ -249,9 +252,9 @@ def generate_text(prompt, max_length=100, model_type='bloom'):
                 top_p=0.95,
                 temperature=0.7
             )
-        generated_text = gpt_tokenizer.decode(output[0], skip_special_tokens=True)
+        generated_text = bloom_tokenizer.decode(output[0], skip_special_tokens=True)
     else:
-        raise ValueError("Invalid model type. Choose 'bloom' or 'gpt'.")
+        raise ValueError("Invalid model type. Choose 'distilgpt3' or 'bloom'.")
 
     return generated_text
 
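Both branches run generate() under torch.no_grad() with the same decoding settings: top_p=0.95 nucleus sampling, temperature 0.7, one returned sequence. A standalone sketch of that decoding step, usable with either getter above (do_sample=True is an assumption; the diff does not show that flag, but top_p and temperature are ignored by greedy decoding without it):

    import torch

    def sample_text(tokenizer, lm_model, prompt, max_length=100):
        """Nucleus-sampled decoding matching the settings in the diff."""
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(lm_model.device)
        with torch.no_grad():
            output = lm_model.generate(
                input_ids,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,  # assumed; not visible in the diff hunks
                top_p=0.95,
                temperature=0.7,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

For example: tokenizer, lm = get_bloom_model(); print(sample_text(tokenizer, lm, 'Hello')).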
@@ -270,25 +273,25 @@ def process_input(text):
 
         rf_prediction = rf_model.predict(encoded_text)[0]
         isolation_score = isolation_forest.decision_function(encoded_text)[0]
-        nn_output = model(torch.LongTensor(encoded_text.toarray()).to(device))
+        nn_output = model(torch.LongTensor(encoded_text.toarray()).to(device, non_blocking=True))
         nn_prediction = nn_output.argmax(dim=1).item()
 
         predicted_emotion = emotion_classes[rf_prediction]
         sentiment_score = isolation_score
+        distilgpt3_generated_text = generate_text(normalized_text, model_type='distilgpt3')
         bloom_generated_text = generate_text(normalized_text, model_type='bloom')
-        gpt_generated_text = generate_text(normalized_text, model_type='gpt')
 
         historical_data = load_historical_data()
        historical_data.append({
             'context': text,
             'predicted_emotion': predicted_emotion,
             'sentiment_score': sentiment_score,
-            'bloom_generated_text': bloom_generated_text,
-            'gpt_generated_text': gpt_generated_text
+            'distilgpt3_generated_text': distilgpt3_generated_text,
+            'bloom_generated_text': bloom_generated_text
         })
         save_historical_data(historical_data)
 
-        return predicted_emotion, sentiment_score, bloom_generated_text, gpt_generated_text
+        return predicted_emotion, sentiment_score, distilgpt3_generated_text, bloom_generated_text
 
     except Exception as e:
         error_message = f"An error occurred: {str(e)}"
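load_historical_data and save_historical_data themselves sit outside this hunk; a plausible minimal pair consistent with the default file_path and the append-then-save flow would be (an inferred sketch, not the file's actual code):

    import json
    import os

    emotion_history_file = 'emotion_history.json'

    def load_historical_data(file_path=emotion_history_file):
        # Return the stored interaction list, or an empty one on first run.
        if os.path.exists(file_path):
            with open(file_path, 'r') as f:
                return json.load(f)
        return []

    def save_historical_data(data, file_path=emotion_history_file):
        with open(file_path, 'w') as f:
            json.dump(data, f)

One caveat: decision_function returns NumPy floats, which json.dump rejects; casting with float(sentiment_score) before appending avoids a TypeError.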
@@ -301,8 +304,8 @@ iface = gr.Interface(
     outputs=[
         gr.Textbox(label="Emotional Response"),
         gr.Textbox(label="Sentiment Response"),
-        gr.Textbox(label="BLOOM Generated Text"),
-        gr.Textbox(label="GPT Generated Text")
+        gr.Textbox(label="DistilGPT-3 Generated Text"),
+        gr.Textbox(label="BLOOM Generated Text")
     ],
     live=True
 )
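The four output Textboxes must line up positionally with the 4-tuple that process_input returns (emotion, sentiment, DistilGPT-3 text, BLOOM text). A reduced sketch of the wiring; the input component is assumed, since the diff does not show it:

    import gradio as gr

    iface = gr.Interface(
        fn=process_input,                       # returns a 4-tuple on success
        inputs=gr.Textbox(label="Input Text"),  # assumed; not shown in the diff
        outputs=[
            gr.Textbox(label="Emotional Response"),
            gr.Textbox(label="Sentiment Response"),
            gr.Textbox(label="DistilGPT-3 Generated Text"),
            gr.Textbox(label="BLOOM Generated Text"),
        ],
        live=True,
    )
    iface.launch()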