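"""Batched prompt evaluation against a JSONL sentence-unscrambling dataset.

Loads `train.jsonl` or `test.jsonl` (selected via the SPLIT environment
variable), instantiates the user-supplied prompt template on each shuffled
sentence, queries gpt-4o-mini in parallel batches, and yields per-item
results (original sentence, instantiated prompt, model response, score)
for display in the Gradio app.
"""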
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import gradio as gr
import jsonlines
from openai import OpenAI
from dotenv import load_dotenv
from evaluation_utils import evaluate_response


def get_split():
    load_dotenv()
    split = os.getenv("SPLIT")
    if split == "train":
        return "evaluation on development set"
    elif split == "test":
        return "evaluation on test set"
    else:
        raise ValueError(f"Unexpected SPLIT value: {split!r} (expected 'train' or 'test')")


# Utility function to chunk a list into batches
def chunk_list(data, chunk_size):
    for i in range(0, len(data), chunk_size):
        yield data[i:i + chunk_size]

# Function to send an individual request to the OpenAI API
def send_request(client, prompt, index):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        seed=42,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=1024,
    )
    return index, response.choices[0].message.content

def evaluate_prompt(prompt: str, num_samples: int = None, split: str = None, batch_size: int = 5, progress=gr.Progress()):
    progress(0, desc="Starting...")
    load_dotenv()

    if num_samples is None:
        num_samples = int(os.getenv("NUM_SAMPLES"))

    if split is None:
        split = os.getenv("SPLIT")
    assert split in ["train", "test"]

    # Define the path to the test.jsonl file
    test_file_path = Path(__file__).parent / f"{split}.jsonl"

    # Load the data from the jsonl file
    test_data = []
    with jsonlines.open(test_file_path) as reader:
        for item in reader:
            test_data.append(item)

    # Drop items whose shuffled form contains apostrophes (straight or curly)
    test_data = [
        item for item in test_data
        if "'" not in item["shuffled_tokenized"] and "’" not in item["shuffled_tokenized"]
    ]

    # Limit to first num_samples items for faster evaluation
    test_data = test_data[:num_samples]

    client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

    responses = [None] * len(test_data)  # Pre-allocate a list to store responses in order
    instantiated_prompts = []

    # Create and process batches
    for batch_data in chunk_list(test_data, batch_size):
        # Prepare the prompts for this batch
        batch_prompts = [
            prompt.replace("{% shuffled_sentence %}", test_item["shuffled_tokenized"])
            for test_item in batch_data
        ]
        instantiated_prompts.extend(batch_prompts)

        # Send requests in parallel using ThreadPoolExecutor
        batch_start = len(instantiated_prompts) - len(batch_prompts)
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(send_request, client, item_prompt, batch_start + i): batch_start + i
                for i, item_prompt in enumerate(batch_prompts)
            }

            for future in as_completed(futures):
                index = futures[future]  # Global index of this request in the run
                try:
                    _, response = future.result()
                    responses[index] = response  # Store the response at the correct index
                except Exception as e:
                    print(f"Request failed: {e}")
                    responses[index] = "Error: Request failed"

        # Update progress after each batch
        progress(len(instantiated_prompts) / len(test_data), desc="Processing batches...")

    # Evaluate responses
    scores = []
    for test_item, instantiated_prompt, response in zip(test_data, instantiated_prompts, responses):
        score = evaluate_response(test_item["original_tokenized"], response)
        scores.append(score)
        yield (test_item["original_sentence"], instantiated_prompt, response, score)
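

# Minimal local usage sketch (hypothetical, not part of the original app):
# evaluate_prompt is a generator, so we iterate over its
# (original_sentence, instantiated_prompt, response, score) tuples.
# Assumes OPENAI_API_KEY, NUM_SAMPLES, and SPLIT are set in the environment
# or .env, and that the default gr.Progress() is safe to call outside a
# Gradio event. The example prompt text below is illustrative; only the
# "{% shuffled_sentence %}" placeholder is required by evaluate_prompt.
if __name__ == "__main__":
    example_prompt = (
        "Reorder the shuffled words into the original sentence:\n"
        "{% shuffled_sentence %}"
    )
    for original, instantiated, response, score in evaluate_prompt(example_prompt, num_samples=3):
        print(f"score:    {score}")
        print(f"original: {original}")
        print(f"response: {response}")
        print("-" * 40)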