# NOTE: removed Hugging Face Spaces file-viewer residue (pause/status lines,
# commit-hash gutter, and line-number gutter) that was captured along with
# this file; it was not part of the program.
import json
from sentence_transformers import SentenceTransformer, util
import nltk
import os
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import time
import logging
import subprocess
import requests
import sys
import json
# Set the GLOO_SOCKET_IFNAME environment variable.
# Left disabled — presumably only needed for torch.distributed/GLOO networking
# on some hosts; confirm before re-enabling.
# os.environ["GLOO_SOCKET_IFNAME"] = "lo"
# Simplified logging: INFO level, message-only format (no timestamps/levels),
# because stdout is also used to emit the final JSON results.
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Load pre-trained models for evaluation.
# Sentence-BERT model used for semantic-similarity scoring; loaded once at
# import time so every evaluation call reuses the same instance.
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
# Download necessary NLTK resources (tokenizer models for word_tokenize).
nltk.download('punkt', quiet=True)
def load_input_data():
    """Parse and return the JSON payload passed as the first CLI argument.

    Exits the process with status 1 when no argument is present or the
    argument is not valid JSON.
    """
    cli_args = sys.argv[1:]
    # Guard clause: bail out early when the caller supplied nothing.
    if not cli_args:
        logging.error("No input data provided")
        sys.exit(1)
    try:
        return json.loads(cli_args[0])
    except json.JSONDecodeError as e:
        logging.error(f"Failed to decode JSON input: {e}")
        sys.exit(1)
def wait_for_server(max_attempts=30):
    """Poll the vLLM health endpoint until it answers 200 OK.

    Args:
        max_attempts: number of polls before giving up; attempts are
            spaced ~2 seconds apart.

    Returns:
        True as soon as the server responds with HTTP 200, False if it
        never becomes healthy within ``max_attempts`` polls.
    """
    url = "http://localhost:8000/health"
    for attempt in range(max_attempts):
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                logging.info("vLLM server is ready!")
                return True
        except requests.exceptions.RequestException:
            pass  # server not accepting connections yet
        # Bug fix: the original slept only on connection errors, so a
        # reachable server returning a non-200 status turned this into a
        # tight busy-poll loop. Sleep between every failed attempt.
        time.sleep(2)
    logging.error("vLLM server failed to start")
    return False
def start_vllm_server(model_name):
    """Launch a vLLM server for ``PharynxAI/<model_name>`` and wait for it.

    Returns the ``subprocess.Popen`` handle once the health check passes;
    terminates the child and raises if the server never comes up.
    """
    command = [
        "vllm",
        "serve",
        f"PharynxAI/{model_name}",
        "--gpu_memory_utilization=0.98",
        "--max_model_len=4096",
        "--enable-chunked-prefill=False",
        "--num_scheduler_steps=2",
    ]
    logging.info(f"Starting vLLM server: {' '.join(command)}")
    process = subprocess.Popen(command)
    # Only hand the process back to the caller once it is actually serving.
    if wait_for_server():
        return process
    process.terminate()
    raise Exception("Server failed to start")
def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
    """Return the cosine similarity of the Sentence-BERT embeddings of two texts."""
    # Embed both texts with the same model so the vectors are comparable.
    expected_vec, model_vec = (
        semantic_model.encode(text, convert_to_tensor=True)
        for text in (expected_response, model_response)
    )
    return util.pytorch_cos_sim(expected_vec, model_vec).item()
def evaluate_bleu(expected_response, model_response):
    """Return the smoothed sentence-level BLEU of the model response vs. the reference."""
    # Lowercase before tokenizing so scoring is case-insensitive.
    reference_tokens = nltk.word_tokenize(expected_response.lower())
    hypothesis_tokens = nltk.word_tokenize(model_response.lower())
    # Smoothing avoids zero scores on short outputs with no higher-order n-gram overlap.
    smoother = SmoothingFunction().method1
    return sentence_bleu([reference_tokens], hypothesis_tokens,
                         smoothing_function=smoother)
def query_vllm_server(prompt, model_name):
    """POST a chat-completion request to the local vLLM server.

    Returns the parsed JSON response body; logs and re-raises on any
    HTTP or connection failure.
    """
    url = "http://localhost:8000/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    payload = {
        "model": f"PharynxAI/{model_name}",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    }
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=300)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logging.error(f"Server query failed: {e}")
        raise
def evaluate_model(data, model_name, semantic_model):
    """Score every prompt/response pair in ``data`` against the served model.

    For each entry the served model is queried, then semantic-similarity and
    BLEU scores are computed against the expected response. Entries that fail
    are logged and skipped. Prints the average scores as JSON on stdout (for
    capture by the caller) and also returns them.
    """
    semantic_scores = []
    bleu_scores = []
    for entry in data:
        prompt = entry['prompt']
        expected_response = entry['response']
        try:
            server_reply = query_vllm_server(prompt, model_name)
            # Skip entries for which the server produced no completion.
            choices = server_reply.get('choices')
            if not choices:
                logging.error(f"No choices returned for prompt: {prompt}")
                continue
            model_response = choices[0]['message']['content']
            semantic_scores.append(
                evaluate_semantic_similarity(expected_response, model_response,
                                             semantic_model))
            bleu_scores.append(evaluate_bleu(expected_response, model_response))
        except Exception as e:
            # Best-effort: a single bad entry must not abort the whole run.
            logging.error(f"Error processing entry: {e}")
            continue
    # Averages default to 0 when every entry failed.
    avg_semantic = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    evaluation_results = {
        'average_semantic_score': avg_semantic,
        'average_bleu_score': avg_bleu
    }
    # Emit JSON directly on stdout so the invoking process can capture it.
    print(json.dumps(evaluation_results))
    return evaluation_results
def main():
    """Entry point: load inputs, start the vLLM server, evaluate, clean up."""
    input_data = load_input_data()
    model_name = input_data["model_name"]
    server_process = None
    try:
        # Load the evaluation dataset of prompt/response pairs.
        with open('output_json.json', 'r') as f:
            data = json.load(f)
        server_process = start_vllm_server(model_name)
        evaluate_model(data, model_name, semantic_model)
    except Exception as e:
        logging.error(f"Evaluation failed: {e}")
        sys.exit(1)
    finally:
        # Always tear the server down, even when evaluation fails.
        if server_process is not None:
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Escalate if it ignored SIGTERM.
                server_process.kill()


if __name__ == "__main__":
    main()