Update app.py
app.py
CHANGED
@@ -1,6 +1,6 @@
 import logging
 import gradio as gr
 from queue import Queue
 import time
 from prometheus_client import start_http_server, Counter, Histogram, Gauge
 import threading
@@ -8,12 +8,15 @@ import psutil
 import random
 from transformers import pipeline
 from sklearn.metrics import precision_score, recall_score, f1_score
-import json
 import requests
+from datasets import load_dataset
 
 # Load the model
 ner_pipeline = pipeline("ner", model="Sevixdd/roberta-base-finetuned-ner")
 
+# Load the dataset
+dataset = load_dataset("surrey-nlp/PLOD-filtered")
+
 # --- Prometheus Metrics Setup ---
 REQUEST_COUNT = Counter('gradio_request_count', 'Total number of requests')
 REQUEST_LATENCY = Histogram('gradio_request_latency_seconds', 'Request latency in seconds')
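The new dataset dependency can be sanity-checked in isolation. A minimal sketch, assuming only what the change itself relies on: that the train split exposes word-level tokens and integer ner_tags columns (index 2 mirrors the default used in the stress-testing tab):

    from datasets import load_dataset

    dataset = load_dataset("surrey-nlp/PLOD-filtered")
    example = dataset["train"][2]
    print(example["tokens"][:10])    # word-level tokens
    print(example["ner_tags"][:10])  # integer labels, one per token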
@@ -30,29 +33,35 @@ logging.basicConfig(filename="chat_log.txt", level=logging.DEBUG, format='%(asct
 chat_queue = Queue() # Define chat_queue globally
 
 # --- Chat Function with Monitoring ---
-def chat_function(message, user_ner_tags, ground_truth):
+def chat_function(index):
     logging.debug("Starting chat_function")
     with REQUEST_LATENCY.time():
         REQUEST_COUNT.inc()
         try:
-            chat_queue.put(message)
-            logging.info(f"Received message: {message}")
+            chat_queue.put(index)
+            logging.info(f"Received index from user: {index}")
+
+            # Get the example from the dataset
+            example = dataset['train'][int(index)]
+            tokens = example['tokens']
+            ground_truth_labels = example['ner_tags']
+
+            logging.info(f"Tokens: {tokens}")
+            logging.info(f"Ground Truth Labels: {ground_truth_labels}")
 
-            ner_results = ner_pipeline(message)
+            # Predict using the model
+            ner_results = ner_pipeline(" ".join(tokens))
             logging.debug(f"NER results: {ner_results}")
 
             detailed_response = []
             model_predicted_labels = []
-            user_predicted_labels = []
             for result in ner_results:
                 token = result['word']
                 score = result['score']
                 entity = result['entity']
-                start = result['start']
-                end = result['end']
                 label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
                 model_predicted_labels.append(label_id)
-                detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}, Start: {start}, End: {end}")
+                detailed_response.append(f"Token: {token}, Entity: {entity}, Score: {score:.4f}")
 
             response = "\n".join(detailed_response)
             logging.info(f"Generated response: {response}")
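One note on the prediction loop kept by this change: int(entity.split('_')[-1]) only works while the pipeline emits generic names such as LABEL_3; if the checkpoint's config carries real label names (B-AC, B-LF, ...), the cast raises ValueError. A sketch of a parsing-free lookup, assuming only the standard transformers config attributes:

    # Map the pipeline's entity string back to its integer id via the model
    # config instead of parsing the string; -1 flags an unknown label.
    label2id = ner_pipeline.model.config.label2id
    model_predicted_labels = [label2id.get(r["entity"], -1) for r in ner_results]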
@@ -62,21 +71,12 @@ def chat_function(message, user_ner_tags, ground_truth):
 
             time.sleep(random.uniform(0.5, 2.5)) # Simulate processing time
 
-            # Parse the user-provided NER tags for comparison
-            try:
-                user_ner_results = json.loads(user_ner_tags)
-                for result in user_ner_results:
-                    entity = result['entity']
-                    label_id = int(entity.split('_')[-1]) # Extract numeric label from entity
-                    user_predicted_labels.append(label_id)
-            except json.JSONDecodeError:
-                return "Invalid JSON format for user NER tags. Please provide a valid JSON array."
-
-            precision = precision_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
-            recall = recall_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
-            f1 = f1_score(user_predicted_labels, model_predicted_labels, average='weighted', zero_division=0)
+            # Ensure the model and ground truth labels are the same length for comparison
+            model_predicted_labels = model_predicted_labels[:len(ground_truth_labels)]
+
+            precision = precision_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
+            recall = recall_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
+            f1 = f1_score(ground_truth_labels, model_predicted_labels, average='weighted', zero_division=0)
 
             metrics_response = (f"Precision: {precision:.4f}\n"
                                 f"Recall: {recall:.4f}\n"
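A caveat on the new length fix: slicing handles the case where the model emits more predictions than there are ground-truth tags (subword tokenization makes the two sequences different lengths), but a shorter prediction list still reaches sklearn with mismatched lengths and raises ValueError. A sketch that pads as well as truncates, under the assumption that label 0 is the outside/"O" tag in this scheme:

    length = len(ground_truth_labels)
    # Pad with the assumed "O" label, then truncate to the reference length.
    aligned = (model_predicted_labels + [0] * length)[:length]
    precision = precision_score(ground_truth_labels, aligned, average='weighted', zero_division=0)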
@@ -93,10 +93,10 @@ def chat_function(message, user_ner_tags, ground_truth):
             return f"An error occurred. Please try again. Error: {e}"
 
 # Function to simulate stress test
-def stress_test(num_requests, message, delay):
+def stress_test(num_requests, index, delay):
     def send_chat_message():
         response = requests.post("http://127.0.0.1:7860/api/predict/", json={
-            "data": [message],
+            "data": [index],
             "fn_index": 0 # This might need to be updated based on your Gradio app's function index
         })
         logging.debug(response.json())
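Only the inner send_chat_message helper appears in this hunk; the driver loop around it sits outside the diff. For orientation, a hypothetical shape for the full function, using the threading, time, and requests imports already in app.py (the /api/predict/ route and fn_index=0 are carried over from the code above and, as its own comment warns, may need adjusting):

    def stress_test(num_requests, index, delay):
        def send_chat_message():
            response = requests.post("http://127.0.0.1:7860/api/predict/", json={
                "data": [index],
                "fn_index": 0
            })
            logging.debug(response.json())

        # Fire one request per thread, spaced by the requested delay.
        threads = [threading.Thread(target=send_chat_message) for _ in range(num_requests)]
        for t in threads:
            t.start()
            time.sleep(delay)
        for t in threads:
            t.join()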
@@ -121,10 +121,9 @@ body {
 """, title="PLOD Filtered with Monitoring") as demo: # Load CSS for background image
     with gr.Tab("Chat"):
         gr.Markdown("## Chat with the Bot")
-        message_input = gr.Textbox(label="Message")
-        user_ner_tags_input = gr.Textbox(label="User NER Tags (JSON)", lines=5)
-        output = gr.Textbox(label="Response", lines=20)
-        chat_interface = gr.Interface(fn=chat_function, inputs=[message_input, user_ner_tags_input, gr.Textbox(lines=5)], outputs=output)
+        index_input = gr.Textbox(label="Enter dataset index:", lines=1)
+        output = gr.Textbox(label="Response", lines=20)
+        chat_interface = gr.Interface(fn=chat_function, inputs=[index_input], outputs=output)
         chat_interface.render()
 
     with gr.Tab("Model Parameters"):
@@ -143,20 +142,20 @@ body {
 
     with gr.Tab("Stress Testing"):
         num_requests_input = gr.Number(label="Number of Requests", value=10)
-        message_input_stress = gr.Textbox(label="Message", value="Test message")
+        index_input_stress = gr.Textbox(label="Dataset Index", value="2")
         delay_input = gr.Number(label="Delay Between Requests (seconds)", value=0.1)
         stress_test_button = gr.Button("Start Stress Test")
         stress_test_status = gr.Textbox(label="Stress Test Status", lines=5, interactive=False)
 
-        def run_stress_test(num_requests, message, delay):
+        def run_stress_test(num_requests, index, delay):
             stress_test_status.value = "Stress test started..."
             try:
-                stress_test(num_requests, message, delay)
+                stress_test(num_requests, index, delay)
                 stress_test_status.value = "Stress test completed."
             except Exception as e:
                 stress_test_status.value = f"Stress test failed: {e}"
 
-        stress_test_button.click(run_stress_test, [num_requests_input, message_input_stress, delay_input], stress_test_status)
+        stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status)
 
     # --- Update Functions ---
     def update_metrics(request_count_display, avg_latency_display):
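Visible in this hunk though untouched by the change: assigning to stress_test_status.value inside run_stress_test does not refresh the rendered textbox once the Blocks app is running. Because the click() call already wires stress_test_status as the output component, returning the status string is the idiomatic fix, sketched here:

    def run_stress_test(num_requests, index, delay):
        try:
            stress_test(num_requests, index, delay)
            return "Stress test completed."
        except Exception as e:
            return f"Stress test failed: {e}"

    stress_test_button.click(run_stress_test, [num_requests_input, index_input_stress, delay_input], stress_test_status)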