Update VLLM_evaluation.py
VLLM_evaluation.py (CHANGED, +52 -106)
@@ -4,63 +4,50 @@ import nltk
 import os
 from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 import time
-import asyncio
 import logging
 import subprocess
 import requests
 import sys
-import …
-import threading
+import json
 
 # Set the GLOO_SOCKET_IFNAME environment variable
 os.environ["GLOO_SOCKET_IFNAME"] = "lo"
 
-# …
-logging.basicConfig(
-    …
-)
+# Simplified logging
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+
+# Load pre-trained models for evaluation
+semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-# …
-…
+# Download necessary NLTK resources
+nltk.download('punkt', quiet=True)
 
 def load_input_data():
     """Load input data from command line arguments."""
     try:
-        …
+        # Check if input is provided via command-line argument
+        if len(sys.argv) > 1:
+            return json.loads(sys.argv[1])
+        else:
+            logging.error("No input data provided")
+            sys.exit(1)
     except json.JSONDecodeError as e:
         logging.error(f"Failed to decode JSON input: {e}")
         sys.exit(1)
 
-…
-semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-# Download necessary NLTK resources
-nltk.download('punkt')
-
-# Load your dataset
-with open('output_json.json', 'r') as f:
-    data = json.load(f)
-
-def wait_for_server(max_attempts=60):
+def wait_for_server(max_attempts=30):
     """Wait for the vLLM server to become available."""
    url = "http://localhost:8000/health"
     for attempt in range(max_attempts):
         try:
-            response = requests.get(url)
+            response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 logging.info("vLLM server is ready!")
                 return True
-        except requests.exceptions.RequestException…
-            …
-
-def log_output(pipe, log_func):
-    """Helper function to log output from a subprocess pipe."""
-    for line in iter(pipe.readline, ''):
-        log_func(line.strip())
+        except requests.exceptions.RequestException:
+            time.sleep(2)
+    logging.error("vLLM server failed to start")
+    return False
 
 def start_vllm_server(model_name):
     cmd = [
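Two notes on this hunk. The rewritten load_input_data now expects the payload as a single JSON string in sys.argv[1] rather than from a file or stdin, and wait_for_server now gives up after 30 attempts at 2-second intervals (about a minute) and returns False instead of looping indefinitely. A minimal sketch of a caller under the argv convention (the wrapper below is illustrative, not part of this Space):

    import json
    import subprocess
    import sys

    # Hypothetical parent process: pass the input as one JSON argv string.
    payload = {"model_name": "my-org/my-model"}  # placeholder model id
    proc = subprocess.run(
        [sys.executable, "VLLM_evaluation.py", json.dumps(payload)],
        capture_output=True, text=True,
    )
    # The script prints evaluation_results as a single JSON line on stdout.
    print(proc.stdout)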
@@ -73,29 +60,15 @@ def start_vllm_server(model_name):
         "--num_scheduler_steps=2"
     ]
 
-    logging.info(f"Starting vLLM server …
+    logging.info(f"Starting vLLM server: {' '.join(cmd)}")
+    server_process = subprocess.Popen(cmd)
 
-    # Start the server subprocess
-    server_process = subprocess.Popen(
-        cmd,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        text=True,
-        bufsize=1
-    )
-
-    # # Use threads to handle stdout and stderr in real-time
-    # threading.Thread(target=log_output, args=(server_process.stdout, logging.info), daemon=True).start()
-    # threading.Thread(target=log_output, args=(server_process.stderr, logging.error), daemon=True).start()
-
-    # Wait for the server to become ready
     if not wait_for_server():
         server_process.terminate()
-        raise Exception("Server failed to start …
+        raise Exception("Server failed to start")
 
     return server_process
 
-
 def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
     """Evaluate semantic similarity using Sentence-BERT."""
     expected_embedding = semantic_model.encode(expected_response, convert_to_tensor=True)
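Dropping stdout=subprocess.PIPE / stderr=subprocess.PIPE here is more than cleanup: with the reader threads commented out, nothing drained those pipes, so the server could stall once an OS pipe buffer filled. Letting the child inherit the parent's streams removes that failure mode. If the server's output must be kept off this script's stdout (which now carries the JSON results line), redirecting it to a file is one option; a sketch, with the command and file name as stand-ins of our choosing:

    import subprocess

    cmd = ["vllm", "serve", "my-org/my-model"]  # stand-in for the command built above
    log_file = open("vllm_server.log", "a")
    server_process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
    # The child keeps its own copy of the descriptor, so the parent may close it.
    log_file.close()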
@@ -107,12 +80,12 @@ def evaluate_bleu(expected_response, model_response):
     """Evaluate BLEU score using NLTK's sentence_bleu."""
     expected_tokens = nltk.word_tokenize(expected_response.lower())
     model_tokens = nltk.word_tokenize(model_response.lower())
-    smoothing_function = …
-    bleu_score = …
+    smoothing_function = SmoothingFunction().method1
+    bleu_score = sentence_bleu([expected_tokens], model_tokens, smoothing_function=smoothing_function)
     return bleu_score
 
-async def query_vllm_server(prompt, model_name, max_retries=3):
-    """Query the vLLM server …
+def query_vllm_server(prompt, model_name):
+    """Query the vLLM server."""
     url = "http://localhost:8000/v1/chat/completions"
     headers = {"Content-Type": "application/json"}
     data = {
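The restored lines pin down the BLEU call, and the smoothing matters: unsmoothed sentence-level BLEU returns 0.0 whenever any n-gram order has zero matches, which is common for short free-form answers; SmoothingFunction().method1 adds a small count to empty buckets so near-misses still produce a usable score. A self-contained check:

    import nltk
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    nltk.download('punkt', quiet=True)

    reference = nltk.word_tokenize("paris is the capital of france.")
    hypothesis = nltk.word_tokenize("france's capital city is paris.")

    # No bigram overlap: unsmoothed BLEU collapses to 0.0 (NLTK warns about it).
    print(sentence_bleu([reference], hypothesis))
    # method1 keeps a small positive score for the same pair.
    print(sentence_bleu([reference], hypothesis, smoothing_function=SmoothingFunction().method1))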
@@ -123,20 +96,15 @@ async def query_vllm_server(prompt, model_name, max_retries=3):
         ]
     }
 
-    …
-            else:
-                logging.error(f"Failed to query vLLM server after {max_retries} attempts: {e}")
-                raise
-
-async def evaluate_model(data, model_name, semantic_model):
+    try:
+        response = requests.post(url, headers=headers, json=data, timeout=300)
+        response.raise_for_status()
+        return response.json()
+    except Exception as e:
+        logging.error(f"Server query failed: {e}")
+        raise
+
+def evaluate_model(data, model_name, semantic_model):
     """Evaluate the model using the provided data."""
     semantic_scores = []
     bleu_scores = []
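The interior of data falls in an unchanged region, so the diff shows only its closing brackets. For vLLM's OpenAI-compatible /v1/chat/completions route the request body generally takes this shape (the sampling field is illustrative; the actual values used here are not visible in the diff):

    data = {
        "model": model_name,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.0,  # illustrative; not visible in the diff
    }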
@@ -147,14 +115,13 @@ async def evaluate_model(data, model_name, semantic_model):
 
         try:
             # Query the vLLM server
-            response = …
+            response = query_vllm_server(prompt, model_name)
 
-            # Extract model's response
+            # Extract model's response
             if 'choices' not in response or not response['choices']:
-                logging.error(f"No choices returned for prompt: {prompt} …
+                logging.error(f"No choices returned for prompt: {prompt}")
                 continue
 
-            # Extract the content of the assistant's response
             model_response = response['choices'][0]['message']['content']
 
             # Evaluate scores
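The 'choices' guard and the ['message']['content'] access follow the standard chat-completion response layout, roughly:

    # Illustrative response body (values invented):
    response = {
        "choices": [
            {"index": 0,
             "message": {"role": "assistant", "content": "..."},
             "finish_reason": "stop"}
        ]
    }
    model_response = response["choices"][0]["message"]["content"]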
@@ -163,13 +130,6 @@ async def evaluate_model(data, model_name, semantic_model):
 
             bleu_score = evaluate_bleu(expected_response, model_response)
             bleu_scores.append(bleu_score)
-            # Print the individual evaluation results
-            print(f"Prompt: {prompt}")
-            print(f"Expected Response: {expected_response}")
-            print(f"Model Response: {model_response}")
-            print(f"Semantic Similarity: {semantic_score:.4f}")
-            print(f"BLEU Score: {bleu_score:.4f}")
-
 
         except Exception as e:
             logging.error(f"Error processing entry: {e}")
@@ -179,60 +139,46 @@ async def evaluate_model(data, model_name, semantic_model):
     avg_semantic_score = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
     avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
 
-    # Create …
+    # Create results dictionary
     evaluation_results = {
         'average_semantic_score': avg_semantic_score,
         'average_bleu_score': avg_bleu_score
     }
 
-    # Print …
+    # Print JSON directly to stdout for capture
     print(json.dumps(evaluation_results))
 
-    logging.info("\nOverall Average Scores:")
-    logging.info(f"Average Semantic Similarity: {avg_semantic_score:.4f}")
-    logging.info(f"Average BLEU Score: {avg_bleu_score:.4f}")
-
     return evaluation_results
 
-
+def main():
     # Load input data
     input_data = load_input_data()
     model_name = input_data["model_name"]
     server_process = None
 
     try:
-        # …
-        …
-        sys.exit(1)
-
-        # # Start vLLM server
+        # Load dataset
+        with open('output_json.json', 'r') as f:
+            data = json.load(f)
+
+        # Start vLLM server
         server_process = start_vllm_server(model_name)
 
-        # Run …
-        …
+        # Run evaluation
+        evaluate_model(data, model_name, semantic_model)
 
     except Exception as e:
-        logging.error(f" …
+        logging.error(f"Evaluation failed: {e}")
         sys.exit(1)
 
     finally:
         # Cleanup: terminate the server process
         if server_process:
-            logging.info("Shutting down vLLM server...")
             server_process.terminate()
             try:
                 server_process.wait(timeout=5)
             except subprocess.TimeoutExpired:
-                logging.warning("Server didn't terminate gracefully, forcing kill...")
                 server_process.kill()
-                server_process.wait()
-        logging.info("Server shutdown complete")
 
 if __name__ == "__main__":
-    …
-    asyncio.run(main())
+    main()
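evaluate_semantic_similarity is untouched apart from the context lines shown, so its scoring step never appears in this diff. With Sentence-BERT embeddings the conventional completion of that function is a cosine similarity; a sketch of that pattern (not the file's verbatim body):

    from sentence_transformers import SentenceTransformer, util

    def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
        """Evaluate semantic similarity using Sentence-BERT."""
        expected_embedding = semantic_model.encode(expected_response, convert_to_tensor=True)
        model_embedding = semantic_model.encode(model_response, convert_to_tensor=True)
        # Cosine similarity of the two sentence embeddings as a float in [-1, 1].
        return util.cos_sim(expected_embedding, model_embedding).item()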
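One behavioural change worth flagging: the old async query_vllm_server retried up to max_retries=3, while the new synchronous version raises on the first failure. If transient connection errors resurface, retries can come back without asyncio; a minimal sketch (retry count and backoff are assumptions):

    import logging
    import time

    import requests

    def query_with_retries(prompt, model_name, max_retries=3):
        """Retry wrapper around the synchronous query_vllm_server above."""
        for attempt in range(1, max_retries + 1):
            try:
                return query_vllm_server(prompt, model_name)
            except requests.exceptions.RequestException as e:
                if attempt == max_retries:
                    logging.error(f"Failed to query vLLM server after {max_retries} attempts: {e}")
                    raise
                time.sleep(2 ** attempt)  # backoff: 2 s, 4 s, ...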