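"""Tag chat prompts with evaluation-quality criteria using an LLM judge.

Each prompt is sent to the judge model with SYSTEM_PROMPT; the judge replies
with a Python-style array of satisfied criteria numbers, which is parsed into
boolean criteria_tag fields and appended to a JSONL output file.
"""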
import argparse
import json
import pandas as pd
import os
import re
import ast
import time
import concurrent.futures
import tqdm
import random
import threading

LOCK = threading.RLock()

## Configs
SYSTEM_PROMPT = "Your task is to evaluate how well the following input prompts can assess the capabilities of advanced AI assistants.\n\nFor the input prompt, please analyze it based on the following 7 criteria.\n1. Specificity: Does the prompt ask for a specific output, such as code, a mathematical solution, a logical simplification, a problem-solving strategy, or a hardware setup recommendation? This specificity allows the AI to demonstrate its ability to understand and generate precise responses.\n2. Domain Knowledge: Does the prompt cover a specific domain, such as programming, mathematics, logic, problem-solving, or hardware setup? Prompts spanning a range of topics test the AI's breadth of knowledge and its ability to apply that knowledge to different domains.\n3. Complexity: Does the prompt vary in complexity, from straightforward tasks to more complex, multi-step problems? This allows evaluators to assess the AI's capability to handle problems of varying difficulty.\n4. Problem-Solving Skills: Does the prompt directly require the AI to demonstrate active problem-solving skills, such as systematically coming up with a solution for a specific setup instead of regurgitating an existing fact? This tests the AI's ability to apply logical reasoning and provide practical solutions.\n5. Creativity: Does the prompt involve a level of creativity in approaching the problem? This criterion tests the AI's ability to provide tailored solutions that take into account the user's specific needs and limitations.\n6. Technical Accuracy: Does the prompt require technical accuracy in the response? This allows evaluators to assess the AI's precision and correctness in technical fields.\n7. Real-world Application: Does the prompt relate to real-world applications, such as setting up a functional system or writing code for a practical use case? This tests the AI's ability to provide practical and actionable information that could be implemented in real-life scenarios.\n\nYou must list the criteria numbers that the prompt satisfies in the format of a Python array. For example, \"[...]\". Do not explain your choice."

ENDPOINT_INFO = {
    "model_name": "META-LLAMA/LLAMA-3-70B-CHAT-HF",
    "name": "llama-3-70b-instruct",
    "endpoints": [{"api_base": "-", "api_key": "-"}],
    "parallel": 8,
    "temperature": 0.0,
    "max_token": 512,
}  # Modify this: point api_base/api_key at your OpenAI-compatible endpoint

TAGS = {
    1: "specificity",
    2: "domain_knowledge",
    3: "complexity",
    4: "problem_solving",
    5: "creativity",
    6: "technical_accuracy",
    7: "real_world",
}
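
# A judge reply of "[1, 3]" would mark "specificity" and "complexity" True and
# the other five tags False (see get_answer below).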

# API setting constants
API_MAX_RETRY = 3
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"


def get_endpoint(endpoint_list):
    if endpoint_list is None:
        return None
    # randomly pick one endpoint for simple load balancing
    return random.choice(endpoint_list)


# Matches a Python-style array of single digits, e.g. "[1]" or "[1, 3, 6]"
pattern = re.compile(r"(\[\d(?:,\s\d)*\])")


def get_score(judgment):
    # Expect exactly one unique array in the judgment; missing or conflicting
    # arrays are treated as "no criteria satisfied".
    matches = pattern.findall(judgment)
    matches = [m for m in matches if m != ""]
    if len(set(matches)) == 0:
        return []
    elif len(set(matches)) == 1:
        try:
            return ast.literal_eval(matches[0])
        except (SyntaxError, ValueError):
            print(matches[0])
            return []
    else:
        return []
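
# Illustrative behavior (assuming the judge follows the system prompt):
# >>> get_score("The prompt satisfies [1, 3, 6].")
# [1, 3, 6]
# >>> get_score("[1, 2] ... although [3] also applies")  # conflicting arrays
# []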


def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
    import openai

    if api_dict:
        client = openai.OpenAI(
            base_url=api_dict["api_base"],
            api_key=api_dict["api_key"],
        )
    else:
        client = openai.OpenAI()

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None,
            )
            output = completion.choices[0].message.content
            break
        except openai.RateLimitError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.BadRequestError as e:
            print(messages)
            print(type(e), e)
            break
        except openai.APIConnectionError as e:
            print(messages)
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except openai.InternalServerError as e:
            print(messages)
            print(type(e), e)
            time.sleep(1)
        except KeyError as e:
            print(type(e), e)
            break

    return output
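
# Minimal usage sketch (hedged: the api_base/api_key below are placeholders
# for whatever OpenAI-compatible endpoint is actually deployed, not real values):
#
#   reply = chat_completion_openai(
#       model=ENDPOINT_INFO["model_name"],
#       messages=[{"role": "user", "content": "Hello"}],
#       temperature=0.0,
#       max_tokens=64,
#       api_dict={"api_base": "http://localhost:8000/v1", "api_key": "EMPTY"},
#   )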


def get_answer(
    question: dict,
    max_tokens: int,
    temperature: float,
    answer_file: str,
    api_dict: dict,
):
    conv = []
    conv.append({"role": "system", "content": SYSTEM_PROMPT})

    conv.append({"role": "user", "content": question["prompt"]})
    output = chat_completion_openai(
        model=ENDPOINT_INFO["model_name"],
        messages=conv,
        temperature=temperature,
        max_tokens=max_tokens,
        api_dict=api_dict,
    )

    criteria = get_score(output)

    # Dump answers
    question["criteria_tag"] = {name: bool(i in criteria) for i, name in TAGS.items()}
    # Series.drop is not in-place: reassign so the raw prompt is actually dropped
    question = question.drop("prompt")

    with LOCK:
        with open(answer_file, "a") as fout:
            fout.write(json.dumps(question.to_dict()) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-file", type=str, required=True)
    parser.add_argument("--cache-file", type=str, default=None)
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--convert-to-json", action="store_true")
    args = parser.parse_args()

    print("loading input data (might take min)")
    input_data = pd.read_json(args.input_file)
    print(f"{len(input_data)}# of input data just loaded")
    if args.cache_file:
        print("loading cache data")
        cache_data = pd.read_json(args.cache_file)
        print(f"{len(cache_data)}# of cache data just loaded")

        assert "criteria_tag" in cache_data.columns and len(
            cache_data["criteria_tag"].dropna()
        ) == len(cache_data)

        not_labeled = input_data[
            ~input_data["question_id"].isin(cache_data["question_id"])
        ].copy()
    else:
        not_labeled = input_data.copy()

    if os.path.isfile(args.output_file):
        print("loading existing output")
        output_data = pd.read_json(args.output_file, lines=True)
        print(f"{len(output_data)}# of existing output just loaded")

        assert "criteria_tag" in output_data.columns and len(
            output_data["criteria_tag"].dropna()
        ) == len(output_data)

        not_labeled = not_labeled[
            ~not_labeled["question_id"].isin(output_data["question_id"])
        ]

    print(f"{len(not_labeled)} needs to be labeled")

    not_labeled["prompt"] = not_labeled.conversation_a.map(
        lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
    )
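    # The lambda above assumes conversation_a alternates user/assistant turns,
    # so the even indices are the user messages. Illustrative:
    #   [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"},
    #    {"role": "user", "content": "C"}]  ->  prompt "A\nC"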

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=ENDPOINT_INFO["parallel"]
    ) as executor:
        futures = []
        for index, row in tqdm.tqdm(not_labeled.iterrows(), total=len(not_labeled)):
            future = executor.submit(
                get_answer,
                row,
                ENDPOINT_INFO["max_token"],
                ENDPOINT_INFO["temperature"],
                args.output_file,
                get_endpoint(ENDPOINT_INFO["endpoints"]),
            )
            futures.append(future)
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()

    if args.convert_to_json:
        # Assumes the output path ends in ".jsonl"; dropping the last
        # character gives the matching ".json" path.
        temp = pd.read_json(args.output_file, lines=True)
        temp.to_json(
            args.output_file[:-1], orient="records", indent=4, force_ascii=False
        )