Sephfox commited on
Commit
b6382cd
·
verified ·
1 Parent(s): 6148fb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -8
app.py CHANGED
@@ -135,8 +135,7 @@ emotions = {
135
  total_percentage = 200
136
  default_percentage = total_percentage / len(emotions)
137
  for emotion in emotions:
138
- emotions[emotion]['percentage'] = default_percentage
139
- emotion_history_file = 'emotion_history.json'
140
 
141
  def load_historical_data(file_path=emotion_history_file):
142
  if os.path.exists(file_path):
@@ -203,6 +202,71 @@ def evolve_emotions():
203
 
204
  emotions['ideal_state']['percentage'] = ideal_state
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def process_input(text):
207
  try:
208
  normalized_text = normalize_context(text)
@@ -214,23 +278,25 @@ def process_input(text):
214
 
215
  predicted_emotion = emotion_classes[rf_prediction]
216
  sentiment_score = isolation_score
217
- generated_text = emotion_classes[nn_prediction]
 
218
 
219
  historical_data = load_historical_data()
220
  historical_data.append({
221
  'context': text,
222
  'predicted_emotion': predicted_emotion,
223
  'sentiment_score': sentiment_score,
224
- 'generated_text': generated_text
 
225
  })
226
  save_historical_data(historical_data)
227
 
228
- return predicted_emotion, sentiment_score, generated_text
229
 
230
  except Exception as e:
231
  error_message = f"An error occurred: {str(e)}"
232
  print(error_message) # Logging the error
233
- return error_message, error_message, error_message
234
 
235
  iface = gr.Interface(
236
  fn=process_input,
@@ -238,9 +304,10 @@ iface = gr.Interface(
238
  outputs=[
239
  gr.Textbox(label="Emotional Response"),
240
  gr.Textbox(label="Sentiment Response"),
241
- gr.Textbox(label="Generated Text")
 
242
  ],
243
  live=True
244
  )
245
 
246
- iface.launch(share=True)
 
135
  total_percentage = 200
136
  default_percentage = total_percentage / len(emotions)
137
  for emotion in emotions:
138
+ emotions[emotion]['emotion_history_file = 'emotion_history.json'
 
139
 
140
  def load_historical_data(file_path=emotion_history_file):
141
  if os.path.exists(file_path):
 
202
 
203
  emotions['ideal_state']['percentage'] = ideal_state
204
 
205
+ # Lazy loading for the language models
206
+ _bloom_tokenizer = None
207
+ _bloom_lm_model = None
208
+ def get_bloom_model():
209
+ global _bloom_tokenizer, _bloom_lm_model
210
+ if _bloom_tokenizer is None or _bloom_lm_model is None:
211
+ bloom_model_name = 'bigscience/bloom-1b7'
212
+ _bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
213
+ _bloom_lm_model = AutoModelForCausalLM.from_pretrained(bloom_model_name, device_map="auto", low_cpu_mem_usage=True)
214
+ return _bloom_tokenizer, _bloom_lm_model
215
+
216
+ _gpt_tokenizer = None
217
+ _gpt_lm_model = None
218
+ def get_gpt_model():
219
+ global _gpt_tokenizer, _gpt_lm_model
220
+ if _gpt_tokenizer is None or _gpt_lm_model is None:
221
+ gpt_model_name = 'gpt2-medium'
222
+ _gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
223
+ _gpt_lm_model = AutoModelForCausalLM.from_pretrained(gpt_model_name, device_map="auto", low_cpu_mem_usage=True)
224
+ return _gpt_tokenizer, _gpt_lm_model
225
+
226
+ def generate_text(prompt, max_length=100, model_type='bloom'):
227
+ if model_type == 'bloom':
228
+ bloom_tokenizer, bloom_lm_model = get_bloom_model()
229
+ input_ids = bloom_tokenizer.encode(prompt, return_tensors='pt').to(bloom_lm_model.device)
230
+ with torch.no_grad():
231
+ output = bloom_lm_model.generate(
232
+ input_ids,
233
+ max_length=max_length,
234
+ num_return_sequences=1,
235
+ no_repeat_ngram_size=2,
236
+ do_sample=True,
237
+ top_k=50,
238
+ top_p=0.95,
239
+ temperature=0.7
240
+ )
241
+ generated_text = bloom_tokenizer.decode(output[0], skip_special_tokens=True)
242
+ elif model_type == 'gpt':
243
+ gpt_tokenizer, gpt_lm_model = get_gpt_model()
244
+ input_ids = gpt_tokenizer.encode(prompt, return_tensors='pt').to(gpt_lm_model.device)
245
+ with torch.no_grad():
246
+ output = gpt_lm_model.generate(
247
+ input_ids,
248
+ max_length=max_length,
249
+ num_return_sequences=1,
250
+ no_repeat_ngram_size=2,
251
+ do_sample=True,
252
+ top_k=50,
253
+ top_p=0.95,
254
+ temperature=0.7
255
+ )
256
+ generated_text = gpt_tokenizer.decode(output[0], skip_special_tokens=True)
257
+ else:
258
+ raise ValueError("Invalid model type. Choose 'bloom' or 'gpt'.")
259
+
260
+ return generated_text
261
+
262
+ model_name = "distilbert-base-uncased-finetuned-sst-2-english"
263
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
264
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
265
+ sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
266
+ def get_sentiment(text):
267
+ result = sentiment_pipeline(text)[0]
268
+ return f"Sentiment: {result['label']}, Score: {result['score']:.4f}"
269
+
270
  def process_input(text):
271
  try:
272
  normalized_text = normalize_context(text)
 
278
 
279
  predicted_emotion = emotion_classes[rf_prediction]
280
  sentiment_score = isolation_score
281
+ bloom_generated_text = generate_text(normalized_text, model_type='bloom')
282
+ gpt_generated_text = generate_text(normalized_text, model_type='gpt')
283
 
284
  historical_data = load_historical_data()
285
  historical_data.append({
286
  'context': text,
287
  'predicted_emotion': predicted_emotion,
288
  'sentiment_score': sentiment_score,
289
+ 'bloom_generated_text': bloom_generated_text,
290
+ 'gpt_generated_text': gpt_generated_text
291
  })
292
  save_historical_data(historical_data)
293
 
294
+ return predicted_emotion, sentiment_score, bloom_generated_text, gpt_generated_text
295
 
296
  except Exception as e:
297
  error_message = f"An error occurred: {str(e)}"
298
  print(error_message) # Logging the error
299
+ return error_message, error_message, error_message, error_message
300
 
301
  iface = gr.Interface(
302
  fn=process_input,
 
304
  outputs=[
305
  gr.Textbox(label="Emotional Response"),
306
  gr.Textbox(label="Sentiment Response"),
307
+ gr.Textbox(label="BLOOM Generated Text"),
308
+ gr.Textbox(label="GPT Generated Text")
309
  ],
310
  live=True
311
  )
312
 
313
+ iface.launch(share=True)