lobrien001 commited on
Commit
8323531
·
verified ·
1 Parent(s): 1187ab7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -12
app.py CHANGED
@@ -30,7 +30,7 @@ logging.basicConfig(filename="chat_log.txt", level=logging.DEBUG, format='%(asct
30
  chat_queue = Queue() # Define chat_queue globally
31
 
32
  # --- Chat Function with Monitoring ---
33
- def chat_function(message, ground_truth):
34
  logging.debug("Starting chat_function")
35
  with REQUEST_LATENCY.time():
36
  REQUEST_COUNT.inc()
@@ -42,7 +42,8 @@ def chat_function(message, ground_truth):
42
  logging.debug(f"NER results: {ner_results}")
43
 
44
  detailed_response = []
45
- predicted_labels = []
 
46
  for result in ner_results:
47
  token = result['word']
48
  score = result['score']
@@ -50,7 +51,7 @@ def chat_function(message, ground_truth):
50
  start = result['start']
51
  end = result['end']
52
  label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
53
- predicted_labels.append(label_id)
54
  detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
55
 
56
  response = "\n".join(detailed_response)
@@ -61,15 +62,21 @@ def chat_function(message, ground_truth):
61
 
62
  time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time
63
 
64
- # Compute metrics
65
  try:
66
- ground_truth_labels = json.loads(ground_truth) # Assuming ground_truth is input as a JSON string
 
 
 
 
 
 
67
  except json.JSONDecodeError:
68
- return "Invalid JSON format for ground truth labels. Please provide a valid JSON array."
69
 
70
- precision = precision_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
71
- recall = recall_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
72
- f1 = f1_score(ground_truth_labels, predicted_labels, average='weighted', zero_division=0)
73
 
74
  metrics_response = (f"Precision: {precision:.4f}\n"
75
  f"Recall: {recall:.4f}\n"
@@ -115,9 +122,9 @@ body {
115
  with gr.Tab("Chat"):
116
  gr.Markdown("## Chat with the Bot")
117
  message_input = gr.Textbox(label="Enter your sentence:", lines=2)
118
- ground_truth_input = gr.Textbox(label="Enter ground truth labels (JSON format):", lines=2)
119
  output = gr.Textbox(label="Response", lines=10)
120
- chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, ground_truth_input], outputs=output)
121
  chat_interface.render()
122
 
123
  with gr.Tab("Model Parameters"):
@@ -199,4 +206,4 @@ body {
199
  threading.Thread(target=update_queue_length, daemon=True).start()
200
 
201
  # Launch the app
202
- demo.launch()
 
30
  chat_queue = Queue() # Define chat_queue globally
31
 
32
  # --- Chat Function with Monitoring ---
33
+ def chat_function(message, user_ner_tags, ground_truth):
34
  logging.debug("Starting chat_function")
35
  with REQUEST_LATENCY.time():
36
  REQUEST_COUNT.inc()
 
42
  logging.debug(f"NER results: {ner_results}")
43
 
44
  detailed_response = []
45
+ model_predicted_labels = []
46
+ user_predicted_labels = []
47
  for result in ner_results:
48
  token = result['word']
49
  score = result['score']
 
51
  start = result['start']
52
  end = result['end']
53
  label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
54
+ model_predicted_labels.append(label_id)
55
  detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
56
 
57
  response = "\n".join(detailed_response)
 
62
 
63
  time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time
64
 
65
+ # Compare user's input tags to the model's output
66
  try:
67
+ user_ner_results = json.loads(user_ner_tags)
68
+ if not isinstance(user_ner_results, list):
69
+ raise ValueError("Invalid format for user NER tags. Please provide a JSON list of dictionaries.")
70
+ for result in user_ner_results:
71
+ entity = result['entity']
72
+ label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
73
+ user_predicted_labels.append(label_id)
74
  except json.JSONDecodeError:
75
+ return "Invalid JSON format for user NER tags. Please provide a valid JSON array."
76
 
77
+ precision = precision_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
78
+ recall = recall_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
79
+ f1 = f1_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
80
 
81
  metrics_response = (f"Precision: {precision:.4f}\n"
82
  f"Recall: {recall:.4f}\n"
 
122
  with gr.Tab("Chat"):
123
  gr.Markdown("## Chat with the Bot")
124
  message_input = gr.Textbox(label="Enter your sentence:", lines=2)
125
+ user_ner_tags_input = gr.Textbox(label="Enter your NER tags (JSON format):", lines=5)
126
  output = gr.Textbox(label="Response", lines=10)
127
+ chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, user_ner_tags_input, gr.Textbox(lines=5)], outputs=output)
128
  chat_interface.render()
129
 
130
  with gr.Tab("Model Parameters"):
 
206
  threading.Thread(target=update_queue_length, daemon=True).start()
207
 
208
  # Launch the app
209
+ demo.launch(share=True)