File size: 6,030 Bytes
026b316
 
 
 
 
 
 
 
 
 
d289937
026b316
 
a047bdf
026b316
d289937
 
 
 
 
026b316
d289937
 
026b316
 
 
 
d289937
 
 
 
 
 
026b316
 
 
 
d289937
026b316
 
 
 
d289937
026b316
 
 
d289937
 
 
 
026b316
 
 
 
 
 
82235bb
026b316
 
 
 
 
d289937
 
026b316
 
 
d289937
026b316
 
 
 
 
 
 
 
 
 
 
 
 
 
d289937
 
026b316
 
d289937
 
026b316
 
 
 
 
 
 
 
 
 
d289937
 
 
 
 
 
 
 
 
026b316
 
 
 
 
 
 
 
 
 
d289937
026b316
d289937
026b316
d289937
026b316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d289937
026b316
 
 
 
 
d289937
026b316
 
 
 
d289937
026b316
 
 
 
 
 
d289937
 
 
 
 
026b316
 
d289937
 
026b316
 
d289937
026b316
 
 
d289937
026b316
 
 
 
 
 
 
 
d289937
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import json
from sentence_transformers import SentenceTransformer, util
import nltk
import os
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import time
import logging
import subprocess
import requests
import sys
import json

# Set the GLOO_SOCKET_IFNAME environment variable
# os.environ["GLOO_SOCKET_IFNAME"] = "lo"

# Simplified logging: emit only the message text (no level/timestamp prefix)
logging.basicConfig(level=logging.INFO, format='%(message)s')

# Load the pre-trained sentence-embedding model used for semantic scoring
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

# Download the tokenizer data required by nltk.word_tokenize.
# Recent NLTK releases (3.9+) look up the 'punkt_tab' resource instead of the
# pickled 'punkt' models, so fetch both to work across NLTK versions.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

def load_input_data():
    """Parse the first command-line argument as JSON and return the result.

    Logs an error and exits with status 1 when no argument was supplied or
    when the argument is not valid JSON.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        logging.error("No input data provided")
        sys.exit(1)

    try:
        payload = json.loads(cli_args[0])
    except json.JSONDecodeError as decode_error:
        logging.error(f"Failed to decode JSON input: {decode_error}")
        sys.exit(1)
    return payload

def wait_for_server(max_attempts=30, delay=2):
    """Poll the local vLLM health endpoint until it answers with HTTP 200.

    Args:
        max_attempts: Maximum number of health-check requests to issue.
        delay: Seconds to sleep after each unsuccessful attempt.

    Returns:
        True as soon as the endpoint returns status 200, False if every
        attempt failed.
    """
    url = "http://localhost:8000/health"
    for _ in range(max_attempts):
        try:
            response = requests.get(url, timeout=5)
        except requests.exceptions.RequestException:
            # Server is not accepting connections yet; fall through to the sleep.
            pass
        else:
            if response.status_code == 200:
                logging.info("vLLM server is ready!")
                return True
        # Back off after *any* failed attempt, including non-200 replies;
        # otherwise a server that is up but not yet healthy would exhaust
        # all attempts almost instantly.
        time.sleep(delay)
    logging.error("vLLM server failed to start")
    return False

def start_vllm_server(model_name):
    """Launch ``vllm serve`` for a model and block until it reports healthy.

    Args:
        model_name: Model name; it is served as ``PharynxAI/<model_name>``.

    Returns:
        The ``subprocess.Popen`` handle of the running server. The caller is
        responsible for terminating it.

    Raises:
        RuntimeError: If the health check never succeeds. The child process
            is stopped and reaped before the exception is raised.
    """
    cmd = [
        "vllm",
        "serve",
        f"PharynxAI/{model_name}",
        "--gpu_memory_utilization=0.98",
        "--max_model_len=4096",
        "--enable-chunked-prefill=False",
        "--num_scheduler_steps=2"
    ]

    logging.info(f"Starting vLLM server: {' '.join(cmd)}")
    server_process = subprocess.Popen(cmd)

    if not wait_for_server():
        # Stop the child and wait for it so it does not linger as a zombie
        # or keep resources allocated; escalate to kill if it ignores SIGTERM.
        server_process.terminate()
        try:
            server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            server_process.kill()
            server_process.wait()
        raise RuntimeError("Server failed to start")

    return server_process

def evaluate_semantic_similarity(expected_response, model_response, semantic_model):
    """Score how close two texts are using embeddings from ``semantic_model``.

    Both texts are encoded to tensors (expected first, then model output) and
    compared with ``util.pytorch_cos_sim``; the single resulting value is
    returned via ``.item()``.
    """
    reference_vec, candidate_vec = (
        semantic_model.encode(text, convert_to_tensor=True)
        for text in (expected_response, model_response)
    )
    return util.pytorch_cos_sim(reference_vec, candidate_vec).item()

def evaluate_bleu(expected_response, model_response):
    """Compute a smoothed sentence-level BLEU score with NLTK.

    Both texts are lower-cased and word-tokenized; the expected response is
    the single reference and the model response is the hypothesis.
    """
    reference, hypothesis = (
        nltk.word_tokenize(text.lower())
        for text in (expected_response, model_response)
    )
    return sentence_bleu(
        [reference],
        hypothesis,
        smoothing_function=SmoothingFunction().method1,
    )

def query_vllm_server(prompt, model_name):
    """Send one chat-completion request to the local vLLM server.

    Returns the decoded JSON body of the reply. Any failure (transport error,
    non-success HTTP status, invalid JSON) is logged and re-raised.
    """
    endpoint = "http://localhost:8000/v1/chat/completions"
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    payload = {"model": f"PharynxAI/{model_name}", "messages": conversation}

    try:
        reply = requests.post(
            endpoint,
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=300,
        )
        reply.raise_for_status()
        return reply.json()
    except Exception as error:
        logging.error(f"Server query failed: {error}")
        raise

def evaluate_model(data, model_name, semantic_model):
    """Evaluate the served model against a list of prompt/response pairs.

    Each entry is processed on a best-effort basis: an entry that is
    malformed, gets no choices back, or fails during querying/scoring is
    logged and skipped.

    Args:
        data: Iterable of entries, each expected to provide ``'prompt'`` and
            ``'response'`` keys.
        model_name: Model name forwarded to ``query_vllm_server``.
        semantic_model: Embedding model forwarded to
            ``evaluate_semantic_similarity``.

    Returns:
        Dict with ``'average_semantic_score'`` and ``'average_bleu_score'``
        (both 0 when no entry could be scored). The same dict is also printed
        to stdout as JSON.
    """
    semantic_scores = []
    bleu_scores = []

    for entry in data:
        try:
            # Read the fields inside the try so a malformed entry is skipped
            # like any other per-entry failure instead of aborting the run.
            prompt = entry['prompt']
            expected_response = entry['response']

            # Query the vLLM server
            response = query_vllm_server(prompt, model_name)

            # Extract model's response
            if 'choices' not in response or not response['choices']:
                logging.error(f"No choices returned for prompt: {prompt}")
                continue

            model_response = response['choices'][0]['message']['content']

            # Evaluate scores
            semantic_score = evaluate_semantic_similarity(expected_response, model_response, semantic_model)
            bleu_score = evaluate_bleu(expected_response, model_response)

        except Exception as e:
            logging.error(f"Error processing entry: {e}")
            continue

        # Record the scores only once BOTH metrics succeeded, so the two
        # averages are always computed over the same set of entries.
        semantic_scores.append(semantic_score)
        bleu_scores.append(bleu_score)

    # Calculate average scores
    avg_semantic_score = sum(semantic_scores) / len(semantic_scores) if semantic_scores else 0
    avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0

    # Create results dictionary
    evaluation_results = {
        'average_semantic_score': avg_semantic_score,
        'average_bleu_score': avg_bleu_score
    }

    # Print JSON directly to stdout for capture
    print(json.dumps(evaluation_results))

    return evaluation_results

def main():
    """Run the full evaluation: read input, start the server, score, clean up.

    Reads the model name from the JSON command-line input, loads the dataset
    from ``output_json.json``, starts the vLLM server, and runs the
    evaluation. Exits with status 1 on invalid input or on any failure; the
    server process, if it was started, is always stopped.
    """
    # Load input data
    input_data = load_input_data()
    try:
        model_name = input_data["model_name"]
    except (KeyError, TypeError):
        # Missing key, or the JSON input was not an object at all.
        logging.error('Input JSON must be an object with a "model_name" key')
        sys.exit(1)
    server_process = None

    try:
        # Load dataset (JSON is UTF-8; do not depend on the platform default)
        with open('output_json.json', 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Start vLLM server
        server_process = start_vllm_server(model_name)

        # Run evaluation
        evaluate_model(data, model_name, semantic_model)

    except Exception as e:
        logging.error(f"Evaluation failed: {e}")
        sys.exit(1)

    finally:
        # Cleanup: terminate the server process, escalating to kill if it
        # does not exit within the grace period, and reap it either way.
        if server_process:
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                server_process.kill()
                server_process.wait()

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()