lobrien001 committed on
Commit
aa2b41a
·
verified ·
1 Parent(s): 8323531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -35
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import logging
2
  import gradio as gr
3
- from queue import Queue
4
  import time
5
  from prometheus_client import start_http_server, Counter, Histogram, Gauge
6
  import threading
@@ -8,12 +7,15 @@ import psutil
8
  import random
9
  from transformers import pipeline
10
  from sklearn.metrics import precision_score, recall_score, f1_score
11
- import json
12
  import requests
 
13
 
14
  # Load the model
15
  ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
16
 
 
 
 
17
  # --- Prometheus Metrics Setup ---
18
  REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
19
  REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
@@ -30,29 +32,35 @@ logging.basicConfig(filename="chat_log.txt", level=logging.DEBUG, format='%(asct
30
  chat_queue = Queue() # Define chat_queue globally
31
 
32
  # --- Chat Function with Monitoring ---
33
- def chat_function(message, user_ner_tags, ground_truth):
34
  logging.debug("Starting chat_function")
35
  with REQUEST_LATENCY.time():
36
  REQUEST_COUNT.inc()
37
  try:
38
- chat_queue.put(message)
39
- logging.info(f"Received message from user: {message}")
 
 
 
 
 
 
 
 
40
 
41
- ner_results = ner_pipeline(message)
 
42
  logging.debug(f"NER results: {ner_results}")
43
 
44
  detailed_response = []
45
  model_predicted_labels = []
46
- user_predicted_labels = []
47
  for result in ner_results:
48
  token = result['word']
49
  score = result['score']
50
  entity = result['entity']
51
- start = result['start']
52
- end = result['end']
53
  label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
54
  model_predicted_labels.append(label_id)
55
- detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
56
 
57
  response = "\n".join(detailed_response)
58
  logging.info(f"Generated response: {response}")
@@ -62,21 +70,12 @@ def chat_function(message, user_ner_tags, ground_truth):
62
 
63
  time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time
64
 
65
- # Compare user's input tags to the model's output
66
- try:
67
- user_ner_results = json.loads(user_ner_tags)
68
- if not isinstance(user_ner_results, list):
69
- raise ValueError("Invalid format for user NER tags. Please provide a JSON list of dictionaries.")
70
- for result in user_ner_results:
71
- entity = result['entity']
72
- label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
73
- user_predicted_labels.append(label_id)
74
- except json.JSONDecodeError:
75
- return "Invalid JSON format for user NER tags. Please provide a valid JSON array."
76
-
77
- precision = precision_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
78
- recall = recall_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
79
- f1 = f1_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
80
 
81
  metrics_response = (f"Precision: {precision:.4f}\n"
82
  f"Recall: {recall:.4f}\n"
@@ -93,10 +92,10 @@ def chat_function(message, user_ner_tags, ground_truth):
93
  return f"An error occurred. Please try again. Error: {e}"
94
 
95
  # Function to simulate stress test
96
- def stress_test(num_requests, message, delay):
97
  def send_chat_message():
98
  response = requests.post("http://127.0.0.1:7860/api/predict/", json={
99
- "data": [message],
100
  "fn_index": 0 # This might need to be updated based on your Gradio app's function index
101
  })
102
  logging.debug(response.json())
@@ -121,10 +120,9 @@ body {
121
  """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image
122
  with gr.Tab("Chat"):
123
  gr.Markdown("## Chat with the Bot")
124
- message_input = gr.Textbox(label="Enter your sentence:", lines=2)
125
- user_ner_tags_input = gr.Textbox(label="Enter your NER tags (JSON format):", lines=5)
126
- output = gr.Textbox(label="Response", lines=10)
127
- chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, user_ner_tags_input, gr.Textbox(lines=5)], outputs=output)
128
  chat_interface.render()
129
 
130
  with gr.Tab("Model Parameters"):
@@ -143,20 +141,20 @@ body {
143
 
144
  with gr.Tab("Stress Testing"):
145
  num_requests_input = gr.Number(label="Number of Requests", value=10)
146
- message_input_stress = gr.Textbox(label="Message", value="Hello bot!")
147
  delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
148
  stress_test_button = gr.Button("Start Stress Test")
149
  stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
150
 
151
- def run_stress_test(num_requests, message, delay):
152
  stress_test_status.value = "Stress test started..."
153
  try:
154
- stress_test(num_requests, message, delay)
155
  stress_test_status.value = "Stress test completed."
156
  except Exception as e:
157
  stress_test_status.value = f"Stress test failed: {e}"
158
 
159
- stress_test_button.click(run_stress_test, [num_requests_input, message_input_stress, delay_input], stress_test_status)
160
 
161
  # --- Update Functions ---
162
  def update_metrics(request_count_display, avg_latency_display):
 
1
  import logging
2
  import gradio as gr
 
3
  import time
4
  from prometheus_client import start_http_server, Counter, Histogram, Gauge
5
  import threading
 
7
  import random
8
  from transformers import pipeline
9
  from sklearn.metrics import precision_score, recall_score, f1_score
 
10
  import requests
11
+ from datasets import load_dataset
12
 
13
  # Load the model
14
  ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
15
 
16
+ # Load the dataset
17
+ dataset = load_dataset("surrey-nlp/PLOD-filtered")
18
+
19
  # --- Prometheus Metrics Setup ---
20
  REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
21
  REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
 
32
  chat_queue = Queue() # Define chat_queue globally
33
 
34
  # --- Chat Function with Monitoring ---
35
+ def chat_function(index):
36
  logging.debug("Starting chat_function")
37
  with REQUEST_LATENCY.time():
38
  REQUEST_COUNT.inc()
39
  try:
40
+ chat_queue.put(index)
41
+ logging.info(f"Received index from user: {index}")
42
+
43
+ # Get the example from the dataset
44
+ example = dataset['train'][int(index)]
45
+ tokens = example['tokens']
46
+ ground_truth_labels = example['ner_tags']
47
+
48
+ logging.info(f"Tokens: {tokens}")
49
+ logging.info(f"Ground Truth Labels: {ground_truth_labels}")
50
 
51
+ # Predict using the model
52
+ ner_results = ner_pipeline(" ".join(tokens))
53
  logging.debug(f"NER results: {ner_results}")
54
 
55
  detailed_response = []
56
  model_predicted_labels = []
 
57
  for result in ner_results:
58
  token = result['word']
59
  score = result['score']
60
  entity = result['entity']
 
 
61
  label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
62
  model_predicted_labels.append(label_id)
63
+ detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}")
64
 
65
  response = "\n".join(detailed_response)
66
  logging.info(f"Generated response: {response}")
 
70
 
71
  time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time
72
 
73
+ # Ensure the model and ground truth labels are the same length for comparison
74
+ model_predicted_labels = model_predicted_labels[:len(ground_truth_labels)]
75
+
76
+ precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
77
+ recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
78
+ f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
 
 
 
 
 
 
 
 
 
79
 
80
  metrics_response = (f"Precision: {precision:.4f}\n"
81
  f"Recall: {recall:.4f}\n"
 
92
  return f"An error occurred. Please try again. Error: {e}"
93
 
94
  # Function to simulate stress test
95
+ def stress_test(num_requests, index, delay):
96
  def send_chat_message():
97
  response = requests.post("http://127.0.0.1:7860/api/predict/", json={
98
+ "data": [index],
99
  "fn_index": 0 # This might need to be updated based on your Gradio app's function index
100
  })
101
  logging.debug(response.json())
 
120
  """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image
121
  with gr.Tab("Chat"):
122
  gr.Markdown("## Chat with the Bot")
123
+ index_input = gr.Textbox(label="Enter dataset index:", lines=1)
124
+ output = gr.Textbox(label="Response", lines=20)
125
+ chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output)
 
126
  chat_interface.render()
127
 
128
  with gr.Tab("Model Parameters"):
 
141
 
142
  with gr.Tab("Stress Testing"):
143
  num_requests_input = gr.Number(label="Number of Requests", value=10)
144
+ index_input_stress = gr.Textbox(label="Dataset Index", value="2")
145
  delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
146
  stress_test_button = gr.Button("Start Stress Test")
147
  stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
148
 
149
+ def run_stress_test(num_requests, index, delay):
150
  stress_test_status.value = "Stress test started..."
151
  try:
152
+ stress_test(num_requests, index, delay)
153
  stress_test_status.value = "Stress test completed."
154
  except Exception as e:
155
  stress_test_status.value = f"Stress test failed: {e}"
156
 
157
+ stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status)
158
 
159
  # --- Update Functions ---
160
  def update_metrics(request_count_display, avg_latency_display):