import asyncio
import json
import logging
import os
import subprocess
import sys
import threading
import time

import nltk
import requests
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from sentence_transformers import SentenceTransformer, util

# Force PyTorch's Gloo backend onto the loopback interface (single-node setup)
os.environ["GLOO_SOCKET_IFNAME"] = "lo"

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def load_input_data():
    """Load input data from the first command-line argument (a JSON string)."""
    try:
        return json.loads(sys.argv[1])
    except IndexError:
        logging.error("Missing argument: expected a JSON string as the first argument.")
        sys.exit(1)
    except json.JSONDecodeError as e:
        logging.error(f"Failed to decode JSON input: {e}")
        sys.exit(1)
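
# Example invocation (the script filename here is illustrative):
#   python evaluate_model.py '{"model_name": "my-model"}'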

# Load the pre-trained sentence-embedding model used for semantic scoring
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

# Download necessary NLTK resources
nltk.download('punkt')
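# ('punkt' provides the tokenizer models behind nltk.word_tokenize, used in evaluate_bleu)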

# Load your dataset
with open('output_json.json', 'r') as f:
    data = json.load(f)
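# Each entry is expected to provide "prompt" and "response" keys;
# evaluate_model below reads exactly those fields.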

def wait_for_server(max_attempts=60):
    """Poll the vLLM health endpoint until the server responds, or give up."""
    url = "http://localhost:8000/health"
    for attempt in range(max_attempts):
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                logging.info("vLLM server is ready!")
                return True
        except requests.exceptions.RequestException as e:
            logging.info(f"Server not ready yet: {e}.")
        # Back off exponentially but cap the delay, so startup never stalls for hours
        delay = min(2 ** attempt, 30)
        logging.info(f"Retrying in {delay} seconds...")
        time.sleep(delay)
    return False


def log_output(pipe, log_func):
    """Helper function to log output from a subprocess pipe."""
    for line in iter(pipe.readline, ''):
        log_func(line.strip())

def start_vllm_server(model_name):
    """Launch a vLLM server for the given model and block until it reports healthy."""
    cmd = [
        "vllm",
        "serve",
        f"PharynxAI/{model_name}",
        "--gpu_memory_utilization=0.80",
        "--max_model_len=4096",
        "--enable-chunked-prefill=False",
        "--num_scheduler_steps=2"
    ]

    logging.info(f"Starting vLLM server with command: {' '.join(cmd)}")

    # Start the server subprocess
    server_process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        bufsize=1
    )

    # Drain stdout/stderr in background threads: with PIPE and no reader,
    # a chatty server can fill the pipe buffer and block.
    threading.Thread(target=log_output, args=(server_process.stdout, logging.info), daemon=True).start()
    threading.Thread(target=log_output, args=(server_process.stderr, logging.error), daemon=True).start()

    # Wait for the server to become ready
    if not wait_for_server():
        server_process.terminate()
        raise RuntimeError("Server failed to start in time.")
    
    return server_process


def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
    """Evaluate semantic similarity using Sentence-BERT."""
    expected_embedding = semantic_model.encode(expected_response, convert_to_tensor=True)
    model_embedding = semantic_model.encode(model_response, convert_to_tensor=True)
    similarity_score = util.pytorch_cos_sim(expected_embedding, model_embedding)
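    # Cosine similarity lies in [-1, 1]; values near 1 indicate close semantic agreement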
    return similarity_score.item()

def evaluate_bleu(expected_response, model_response):
    """Evaluate BLEU score using NLTK's sentence_bleu."""
    expected_tokens = nltk.word_tokenize(expected_response.lower())
    model_tokens = nltk.word_tokenize(model_response.lower())
    # method1 smoothing avoids zero scores when higher-order n-grams have no matches
    smoothing_function = SmoothingFunction().method1
    return sentence_bleu([expected_tokens], model_tokens, smoothing_function=smoothing_function)

async def query_vllm_server(prompt, model_name, max_retries=3):
    """Query the vLLM server with retries."""
    url = "http://localhost:8000/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": f"PharynxAI/{model_name}",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
    }
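    # No sampling parameters (e.g. temperature, max_tokens) are set here, so the
    # server-side defaults apply.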

    for attempt in range(max_retries):
        try:
            # requests is blocking; run it in a worker thread so the event loop stays responsive
            response = await asyncio.to_thread(
                requests.post, url, headers=headers, json=data, timeout=300
            )
            response.raise_for_status()
            return response.json()  # the complete response object
        except Exception as e:
            if attempt < max_retries - 1:
                logging.error(f"Attempt {attempt + 1}/{max_retries} failed: {e}. Retrying...")
                await asyncio.sleep(5)
            else:
                logging.error(f"Failed to query vLLM server after {max_retries} attempts: {e}")
                raise

async def evaluate_model(data, model_name, semantic_model):
    """Evaluate the model using the provided data."""
    semantic_scores = []
    bleu_scores = []

    for entry in data:
        prompt = entry['prompt']
        expected_response = entry['response']

        try:
            # Query the vLLM server
            response = await query_vllm_server(prompt, model_name)
            
            # Extract model's response from the 'choices' field
            if 'choices' not in response or not response['choices']:
                logging.error(f"No choices returned for prompt: {prompt}. Skipping this entry.")
                continue
                
            # Extract the content of the assistant's response
            model_response = response['choices'][0]['message']['content']

            # Evaluate scores
            semantic_score = evaluate_semantic_similarity(expected_response, model_response, semantic_model)
            semantic_scores.append(semantic_score)

            bleu_score = evaluate_bleu(expected_response, model_response)
            bleu_scores.append(bleu_score)
            # Print the individual evaluation results
            print(f"Prompt: {prompt}")
            print(f"Expected Response: {expected_response}")
            print(f"Model Response: {model_response}")
            print(f"Semantic Similarity: {semantic_score:.4f}")
            print(f"BLEU Score: {bleu_score:.4f}")


        except Exception as e:
            logging.error(f"Error processing entry: {e}")
            continue

    # Calculate average scores
    avg_semantic_score = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
    avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0

    # Collect the aggregate results
    evaluation_results = {
        'average_semantic_score': avg_semantic_score,
        'average_bleu_score': avg_bleu_score
    }

    # Print results to stdout for capturing in handler
    print(json.dumps(evaluation_results))
    
    logging.info("\nOverall Average Scores:")
    logging.info(f"Average Semantic Similarity: {avg_semantic_score:.4f}")
    logging.info(f"Average BLEU Score: {avg_bleu_score:.4f}")

    return evaluation_results

async def main():
    # Load input data
    input_data = load_input_data()
    model_name = input_data["model_name"]
    server_process = None
    
    try:
        # Check that the model directory exists locally before launching the server
        model_path = f"PharynxAI/{model_name}"
        if not os.path.exists(model_path):
            logging.error(f"Model path does not exist: {model_path}")
            logging.info("Please ensure the model is downloaded and the path is correct")
            sys.exit(1)
            
        # Start the vLLM server
        server_process = start_vllm_server(model_name)
        
        # Run the evaluation asynchronously
        await evaluate_model(data, model_name, semantic_model)
        
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        sys.exit(1)
        
    finally:
        # Cleanup: terminate the server process if it exists
        if server_process:
            logging.info("Shutting down vLLM server...")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                logging.warning("Server didn't terminate gracefully, forcing kill...")
                server_process.kill()
                server_process.wait()
            logging.info("Server shutdown complete")

if __name__ == "__main__":
    # Start the event loop
    asyncio.run(main())