# LLM_FinetuneR / VLLM_evaluation.py
# Last commit: 026b316 (diksha) - "Initial Commit to Gradio app" (8.4 kB).
# (Header copied from the repository's web file viewer.)
import json
from sentence_transformers import SentenceTransformer, util
import nltk
import os
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import time
import asyncio
import logging
import subprocess
import requests
import sys
import os
import threading
# Pin the Gloo socket interface to loopback. Presumably this is read by the
# torch.distributed Gloo backend that the locally launched vLLM server uses;
# the subprocess inherits this environment.
os.environ["GLOO_SOCKET_IFNAME"] = "lo"

# Configure logging exactly once. logging.basicConfig() does nothing when the
# root logger already has handlers, so any further call would be a silent
# no-op; keep the single call that carries the intended format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
def load_input_data(argv=None):
    """Decode the JSON payload passed as the first command-line argument.

    Args:
        argv: Argument vector to read; defaults to ``sys.argv``. ``argv[1]``
            must be a JSON document (main() reads its "model_name" key).

    Returns:
        The decoded JSON value.

    Exits:
        Logs an error and calls ``sys.exit(1)`` when the argument is missing
        or is not valid JSON.
    """
    if argv is None:
        argv = sys.argv
    # Without this guard a missing argument surfaced as an uncaught IndexError.
    if len(argv) < 2:
        logging.error("Missing JSON input: expected it as the first command-line argument.")
        sys.exit(1)
    try:
        return json.loads(argv[1])
    except json.JSONDecodeError as e:
        logging.error(f"Failed to decode JSON input: {e}")
        sys.exit(1)
# Sentence-embedding model used to score semantic similarity between expected
# and generated responses; loaded once at import time.
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

# Tokenizer data for nltk.word_tokenize (used by the BLEU scoring).
# NLTK >= 3.9 loads the "punkt_tab" resource rather than the pickled "punkt"
# one, so fetch both to work across versions. With the default
# raise_on_error=False, nltk.download reports a failed download instead of
# raising, so the extra request is harmless on older setups.
nltk.download('punkt')
nltk.download('punkt_tab')

# Evaluation dataset: presumably a JSON list of {"prompt": ..., "response": ...}
# records (that is how evaluate_model reads it) - confirm against the producer.
# An explicit encoding avoids depending on the platform's locale default.
with open('output_json.json', 'r', encoding='utf-8') as f:
    data = json.load(f)
def wait_for_server(max_attempts=60, url="http://localhost:8000/health",
                    max_delay=30.0, request_timeout=5.0):
    """Poll the vLLM health endpoint until it answers HTTP 200.

    Failed checks back off exponentially (1 s, 2 s, 4 s, ...), capped at
    ``max_delay`` seconds so the total wait stays bounded.

    Args:
        max_attempts: Number of health checks before giving up.
        url: Health-check endpoint of the local vLLM server.
        max_delay: Upper bound, in seconds, for a single backoff sleep.
        request_timeout: Per-request timeout in seconds, so a hung
            connection cannot stall the polling loop indefinitely.

    Returns:
        True once the server reports healthy, False if it never does.
    """
    for attempt in range(max_attempts):
        try:
            response = requests.get(url, timeout=request_timeout)
            if response.status_code == 200:
                logging.info("vLLM server is ready!")
                return True
            reason = f"health check returned HTTP {response.status_code}"
        except requests.exceptions.RequestException as e:
            reason = str(e)
        # Back off after every failed check, including non-200 answers, so a
        # server that is up but still loading is not polled in a tight loop.
        # No sleep after the final attempt: nothing would follow it.
        if attempt < max_attempts - 1:
            delay = min(2 ** attempt, max_delay)
            logging.info(f"Server not ready yet: {reason}. Retrying in {delay} seconds...")
            time.sleep(delay)
    logging.error(f"vLLM server did not become healthy after {max_attempts} attempts.")
    return False
def log_output(pipe, log_func):
    """Forward each line read from *pipe* to *log_func*, whitespace-stripped.

    Reading stops when ``readline()`` yields the empty string, i.e. EOF on a
    text-mode stream. Suitable as the target of a background thread that
    drains a subprocess pipe.
    """
    while True:
        chunk = pipe.readline()
        if chunk == '':
            break
        log_func(chunk.strip())
def start_vllm_server(model_name):
    """Launch ``vllm serve PharynxAI/<model_name>`` and wait until it is healthy.

    The server's stdout and stderr are drained on daemon threads into the
    logging module. They must be consumed: a PIPE that nobody reads fills
    the OS pipe buffer, and the server then blocks on its next write.

    Args:
        model_name: Model name under the ``PharynxAI/`` prefix.

    Returns:
        The running ``subprocess.Popen`` handle; the caller must stop it.

    Raises:
        RuntimeError: if the server does not become healthy in time. The
            process is stopped and reaped before raising.
    """
    cmd = [
        "vllm",
        "serve",
        f"PharynxAI/{model_name}",
        "--gpu_memory_utilization=0.80",
        "--max_model_len=4096",
        "--enable-chunked-prefill=False",
        "--num_scheduler_steps=2"
    ]
    logging.info(f"Starting vLLM server with command: {' '.join(cmd)}")
    # Start the server subprocess
    server_process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1
    )
    # Drain both pipes in real time so the server never blocks on a full
    # buffer. Routine server logs may arrive on stderr as well, so both
    # streams are logged at INFO instead of flagging stderr as errors.
    for stream in (server_process.stdout, server_process.stderr):
        threading.Thread(target=log_output, args=(stream, logging.info), daemon=True).start()
    # Wait for the server to become ready
    if not wait_for_server():
        server_process.terminate()
        # Reap the child so it does not linger; escalate if SIGTERM is ignored.
        try:
            server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            server_process.kill()
            server_process.wait()
        raise RuntimeError("Server failed to start in time.")
    return server_process
def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
    """Return the cosine similarity of the two texts' sentence embeddings.

    The expected text is embedded first, then the model's text, each with
    ``semantic_model.encode(..., convert_to_tensor=True)``. The single
    similarity value is returned as a plain Python number.
    """
    reference_vec, candidate_vec = (
        semantic_model.encode(text, convert_to_tensor=True)
        for text in (expected_response, model_response)
    )
    return util.pytorch_cos_sim(reference_vec, candidate_vec).item()
def evaluate_bleu(expected_response, model_response):
    """Return the sentence-level BLEU of *model_response* against *expected_response*.

    Both texts are lower-cased and tokenized with ``nltk.word_tokenize``. The
    expected response is the single reference. NLTK smoothing method1 is
    applied to mitigate zero scores when some higher-order n-gram precision
    is zero.
    """
    reference, hypothesis = (
        nltk.word_tokenize(text.lower())
        for text in (expected_response, model_response)
    )
    smoother = SmoothingFunction().method1
    return sentence_bleu([reference], hypothesis, smoothing_function=smoother)
async def query_vllm_server(prompt, model_name, max_retries=3,
                            url="http://localhost:8000/v1/chat/completions",
                            retry_delay=5, request_timeout=300):
    """Send *prompt* to the vLLM chat-completions endpoint, retrying on failure.

    The blocking ``requests.post`` call runs in the event loop's default
    executor, so the loop is not frozen while the server generates (up to
    ``request_timeout`` seconds per attempt).

    Args:
        prompt: User message content.
        model_name: Served model name, sent as ``PharynxAI/<model_name>``.
        max_retries: Total number of attempts; at least one is always made.
        url: Chat-completions endpoint.
        retry_delay: Seconds to wait between failed attempts.
        request_timeout: Per-request timeout in seconds.

    Returns:
        The decoded JSON response body.

    Raises:
        The exception from the last attempt once every attempt has failed.
    """
    headers = {"Content-Type": "application/json"}
    # Named `payload` so it does not shadow the module-level dataset `data`.
    payload = {
        "model": f"PharynxAI/{model_name}",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
    }
    # With max_retries <= 0 the loop never ran and the function silently
    # returned None; always make at least one attempt.
    attempts = max(1, max_retries)
    loop = asyncio.get_running_loop()
    for attempt in range(attempts):
        try:
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(url, headers=headers, json=payload, timeout=request_timeout),
            )
            response.raise_for_status()
            return response.json()  # returns the complete response object
        except Exception as e:
            if attempt < attempts - 1:
                # A retryable failure is not yet an error.
                logging.warning(f"Attempt {attempt + 1}/{attempts} failed: {e}. Retrying...")
                await asyncio.sleep(retry_delay)
            else:
                logging.error(f"Failed to query vLLM server after {attempts} attempts: {e}")
                raise
async def evaluate_model(data, model_name, semantic_model):
    """Score the served model against every prompt/response pair in *data*.

    Each entry is expected to be a mapping with 'prompt' and 'response'
    keys. The following entries are logged and skipped:

    * entries that are malformed;
    * entries the server cannot answer;
    * entries whose scoring fails.

    A skipped entry neither aborts the run nor enters the averages.

    Per-entry details are printed, followed by one JSON line with the
    averages on stdout (for capture by the calling handler).

    Args:
        data: Iterable of evaluation entries.
        model_name: Served model name, forwarded to ``query_vllm_server``.
        semantic_model: Sentence-embedding model for semantic similarity.

    Returns:
        dict with 'average_semantic_score' and 'average_bleu_score'. Both
        values are 0 when no entry could be scored.
    """
    semantic_scores = []
    bleu_scores = []
    skipped = 0
    for index, entry in enumerate(data):
        # Read the fields under error handling: one record missing a key (or
        # not a mapping at all) must not abort the whole evaluation.
        try:
            prompt = entry['prompt']
            expected_response = entry['response']
        except (KeyError, TypeError) as e:
            logging.error(f"Entry {index} lacks a 'prompt'/'response' field ({e!r}). Skipping this entry.")
            skipped += 1
            continue
        try:
            # Query the vLLM server
            response = await query_vllm_server(prompt, model_name)
            # Extract model's response from the 'choices' field
            if 'choices' not in response or not response['choices']:
                logging.error(f"No choices returned for prompt: {prompt}. Skipping this entry.")
                skipped += 1
                continue
            model_response = response['choices'][0]['message']['content']
            # The message content may be null in OpenAI-compatible responses.
            if model_response is None:
                logging.error(f"No content returned for prompt: {prompt}. Skipping this entry.")
                skipped += 1
                continue
            # Compute both scores before recording either, so the two lists
            # always cover the same entries and the averages stay comparable.
            semantic_score = evaluate_semantic_similarity(expected_response, model_response, semantic_model)
            bleu_score = evaluate_bleu(expected_response, model_response)
            semantic_scores.append(semantic_score)
            bleu_scores.append(bleu_score)
            # Print the individual evaluation results
            print(f"Prompt: {prompt}")
            print(f"Expected Response: {expected_response}")
            print(f"Model Response: {model_response}")
            print(f"Semantic Similarity: {semantic_score:.4f}")
            print(f"BLEU Score: {bleu_score:.4f}")
        except Exception as e:
            logging.error(f"Error processing entry: {e}")
            skipped += 1
    # Calculate average scores
    avg_semantic_score = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
    avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    evaluation_results = {
        'average_semantic_score': avg_semantic_score,
        'average_bleu_score': avg_bleu_score
    }
    # Print results to stdout for capturing in handler
    print(json.dumps(evaluation_results))
    logging.info(f"Scored {len(semantic_scores)} entries; skipped {skipped}.")
    logging.info("\nOverall Average Scores:")
    logging.info(f"Average Semantic Similarity: {avg_semantic_score:.4f}")
    logging.info(f"Average BLEU Score: {avg_bleu_score:.4f}")
    return evaluation_results
def _shutdown_server(server_process, timeout=5):
    """Terminate *server_process*; kill it if it has not exited within *timeout* seconds."""
    logging.info("Shutting down vLLM server...")
    server_process.terminate()
    try:
        server_process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        logging.warning("Server didn't terminate gracefully, forcing kill...")
        server_process.kill()
        server_process.wait()
    logging.info("Server shutdown complete")


async def main():
    """Start vLLM for the requested model, evaluate it, then shut the server down.

    Steps:

    1. Read ``{"model_name": ...}`` from the first command-line argument.
    2. Require the local directory ``PharynxAI/<model_name>`` to exist.
    3. Start the server and evaluate the module-level dataset ``data``.

    Any failure is logged with its traceback and the process exits with
    status 1. The server, if started, is always stopped on the way out.
    """
    input_data = load_input_data()
    # An empty or missing name would otherwise check the path "PharynxAI/"
    # (or raise a bare KeyError) instead of failing with a clear message.
    model_name = input_data.get("model_name") if isinstance(input_data, dict) else None
    if not model_name:
        logging.error('Input JSON must be an object with a non-empty "model_name" field.')
        sys.exit(1)
    server_process = None
    try:
        # Check if the model directory exists
        model_path = f"PharynxAI/{model_name}"
        if not os.path.exists(model_path):
            logging.error(f"Model path does not exist: {model_path}")
            logging.info("Please ensure the model is downloaded and the path is correct")
            # SystemExit is not an Exception subclass, so it skips the
            # handler below but still runs the cleanup in `finally`.
            sys.exit(1)
        server_process = start_vllm_server(model_name)
        await evaluate_model(data, model_name, semantic_model)
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception(f"An error occurred: {e}")
        sys.exit(1)
    finally:
        # Cleanup: terminate the server process if it was started
        if server_process is not None:
            _shutdown_server(server_process)


if __name__ == "__main__":
    # Start the event loop
    asyncio.run(main())