kaikaidai committed (verified)
Commit 782dcd8 · Parent: 941d431

Delete gen_api_answer.py

Files changed (1):
  gen_api_answer.py  +0 -484
gen_api_answer.py DELETED
@@ -1,484 +0,0 @@
from openai import OpenAI
import anthropic
from together import Together
import cohere
import json
import re
import os
import requests
from prompts import (
    JUDGE_SYSTEM_PROMPT,
    PROMETHEUS_PROMPT,
    PROMETHEUS_PROMPT_WITH_REFERENCE,
    ATLA_PROMPT,
    ATLA_PROMPT_WITH_REFERENCE,
    FLOW_JUDGE_PROMPT,
)
from transformers import AutoTokenizer

# Initialize clients
anthropic_client = anthropic.Anthropic()
openai_client = OpenAI()
together_client = Together()
hf_api_key = os.getenv("HF_API_KEY")
flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
salesforce_api_key = os.getenv("SALESFORCE_API_KEY")

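# Note: the OpenAI, Anthropic and Together clients above read their API keys
# from the standard OPENAI_API_KEY, ANTHROPIC_API_KEY and TOGETHER_API_KEY
# environment variables, so every credential this module uses is supplied via
# the environment.
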
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Get response from OpenAI API"""
    try:
        response = openai_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_completion_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error with OpenAI model {model_name}: {str(e)}"

def get_anthropic_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Get response from Anthropic API"""
    try:
        response = anthropic_client.messages.create(
            model=model_name,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt,
            messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        )
        return response.content[0].text
    except Exception as e:
        return f"Error with Anthropic model {model_name}: {str(e)}"

def get_together_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Get response from Together API"""
    try:
        response = together_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=False,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error with Together model {model_name}: {str(e)}"

def get_prometheus_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
    """Get response from Hugging Face model"""
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {hf_api_key}",
            "Content-Type": "application/json"
        }

        # Create messages list for chat template
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Apply chat template
        model_id = "prometheus-eval/prometheus-7b-v2.0"
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": max_tokens,
                "return_full_text": False,
                "temperature": temperature
            }
        }

        response = requests.post(
            "https://otb7jglxy6r37af6.us-east-1.aws.endpoints.huggingface.cloud",
            headers=headers,
            json=payload
        )
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with Hugging Face model {model_name}: {str(e)}"

def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
    """Get response from HF endpoint for Atla model"""
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {hf_api_key}",
            "Content-Type": "application/json"
        }

        # Create messages list for chat template
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Apply chat template
        model_id = "AtlaAI/Selene-1-Mini-Llama-3.1-8B"
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        payload = {
            "inputs": formatted_prompt,
            "parameters": {
                "max_new_tokens": max_tokens,
                "return_full_text": False,
                "temperature": temperature,
                "seed": 42,
                "add_generation_prompt": True
            }
        }

        response = requests.post(
            "https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
            headers=headers,
            json=payload
        )
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with Atla model {model_name}: {str(e)}"

def get_flow_judge_response(model_name, prompt, max_tokens=2048, temperature=0.1, top_p=0.95) -> str:
    """Get response from Flow Judge"""
    try:
        response = requests.post(
            "https://arena.flow-ai.io/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {flow_judge_api_key}"
            },
            json={
                "model": model_name,
                "messages": [
                    {"role": "user", "content": prompt}
                ],
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stop": None
            }
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        return f"Error with Flow Judge completions model {model_name}: {str(e)}"

def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
    """Get response from Cohere API"""
    try:
        response = cohere_client.chat(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ],
            max_tokens=max_tokens,
            temperature=temperature
        )
        # Extract the text from the content items
        content_items = response.message.content
        if isinstance(content_items, list):
            # Get the text from the first content item
            return content_items[0].text
        return str(content_items)  # Fallback if it's not a list
    except Exception as e:
        return f"Error with Cohere model {model_name}: {str(e)}"

def get_salesforce_response(model_name, prompt, system_prompt=None, max_tokens=2048, temperature=0):
    """Get response from Salesforce Research API"""
    try:
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "X-Api-Key": salesforce_api_key,
        }

        # Create messages list
        messages = [{"role": "user", "content": prompt}]

        json_data = {
            "prompts": messages,
            "temperature": temperature,
            "top_p": 1,
            "max_tokens": max_tokens,
        }

        response = requests.post(
            "https://gateway.salesforceresearch.ai/sfr-judge/process",
            headers=headers,
            json=json_data
        )
        response.raise_for_status()
        return response.json()["result"][0]
    except Exception as e:
        return f"Error with Salesforce model {model_name}: {str(e)}"

def get_model_response(
    model_name,
    model_info,
    prompt_data,
    use_reference=False,
    max_tokens=500,
    temperature=0
):
    """Get response from appropriate API based on model organization"""
    if not model_info:
        return "Model not found or unsupported."

    api_model = model_info["api_model"]
    organization = model_info["organization"]

    # Determine if model is Prometheus, Atla, Flow Judge, or Salesforce
    is_prometheus = (organization == "Prometheus")
    is_atla = (organization == "Atla")
    is_flow_judge = (organization == "Flow AI")
    is_salesforce = (organization == "Salesforce")

    # For models other than Prometheus/Atla/Flow Judge/Salesforce, use the Judge system prompt
    system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT

    # Select the appropriate base prompt
    if is_atla or is_salesforce:  # Use the same prompt for Atla and Salesforce
        base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
    elif is_flow_judge:
        base_prompt = FLOW_JUDGE_PROMPT
    else:
        base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT

    # For models other than Prometheus/Atla/Flow Judge/Salesforce, keep the
    # Prometheus prompt but replace its output-format instruction with JSON
    if not (is_prometheus or is_atla or is_flow_judge or is_salesforce):
        base_prompt = base_prompt.replace(
            '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
            '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
        )

    try:
        if not is_flow_judge:
            # Format the prompt with the provided data
            final_prompt = base_prompt.format(
                human_input=prompt_data['human_input'],
                ai_response=prompt_data['ai_response'],
                ground_truth_input=prompt_data.get('ground_truth_input', ''),
                eval_criteria=prompt_data['eval_criteria']
            )
        else:
            human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
            ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
            ground_truth = prompt_data.get('ground_truth_input', '')
            if ground_truth:
                response_reference = f"<response_reference>\n{ground_truth}\n</response_reference>"
            else:
                response_reference = ""

            # For Flow Judge, parse the scoring rubric from eval_criteria
            eval_criteria_lines = prompt_data['eval_criteria'].split('\n')
            rubric_lines = [line for line in eval_criteria_lines if line.startswith('Score ')]
            rubric = '\n'.join(f"- {line}" for line in rubric_lines)

            if response_reference:
                inputs = human_input + "\n" + response_reference
            else:
                inputs = human_input

            final_prompt = base_prompt.format(
                INPUTS=inputs,
                OUTPUT=ai_response,
                EVALUATION_CRITERIA=prompt_data['eval_criteria'],
                RUBRIC=rubric
            )

    except KeyError as e:
        return f"Error formatting prompt: Missing required field {str(e)}"

    try:
        if organization == "OpenAI":
            return get_openai_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature
            )
        elif organization == "Anthropic":
            return get_anthropic_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature
            )
        elif organization == "Prometheus":
            return get_prometheus_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature=0.01
            )
        elif organization == "Atla":
            return get_atla_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature=0.01
            )
        elif organization == "Cohere":
            return get_cohere_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature
            )
        elif organization == "Flow AI":
            return get_flow_judge_response(
                api_model, final_prompt
            )
        elif organization == "Salesforce":
            return get_salesforce_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature
            )
        else:
            # All other organizations use the Together API
            return get_together_response(
                api_model, final_prompt, system_prompt, max_tokens, temperature
            )
    except Exception as e:
        return f"Error with {organization} model {model_name}: {str(e)}"

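# Illustrative usage sketch: the model_info and prompt_data values below are
# hypothetical examples, inferred from how get_model_response consumes them,
# not fixtures from this repository.
#
#   model_info = {"api_model": "gpt-4o", "organization": "OpenAI"}
#   prompt_data = {
#       "human_input": "Summarise the article in one sentence.",
#       "ai_response": "The article argues that...",
#       "eval_criteria": "Score 1: ...\nScore 5: ...",
#   }
#   raw = get_model_response("GPT-4o", model_info, prompt_data)
#   score, feedback = parse_model_response(raw)
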
def parse_model_response(response):
    try:
        # Debug print
        print(f"Raw model response: {response}")

        # If response is already a dictionary, use it directly
        if isinstance(response, dict):
            return str(response.get("result", "N/A")), response.get("feedback", "N/A")

        # First try to parse the entire response as JSON
        try:
            data = json.loads(response)
            return str(data.get("result", "N/A")), data.get("feedback", "N/A")
        except json.JSONDecodeError:
            # If that fails, check if this is a Salesforce response (which uses the ATLA format)
            if "**Reasoning:**" in response or "**Result:**" in response:
                # Use the ATLA parser for Salesforce responses
                return atla_parse_model_response(response)

            # Otherwise try to find JSON within the response
            json_match = re.search(r"{.*}", response, re.DOTALL)
            if json_match:
                data = json.loads(json_match.group(0))
                return str(data.get("result", "N/A")), data.get("feedback", "N/A")
            else:
                return "Error", f"Invalid response format returned - here is the raw model response: {response}"

    except Exception as e:
        # Debug print for error case
        print(f"Failed to parse response: {str(e)}")

        # If the error message itself contains valid JSON, try to parse that
        try:
            error_json_match = re.search(r"{.*}", str(e), re.DOTALL)
            if error_json_match:
                data = json.loads(error_json_match.group(0))
                return str(data.get("result", "N/A")), data.get("feedback", "N/A")
        except Exception:
            pass

        return "Error", f"Failed to parse response: {response}"

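# Illustrative usage (hypothetical JSON-formatted judge output):
#   parse_model_response('{"feedback": "Clear and accurate.", "result": 5}')
#   -> ("5", "Clear and accurate.")
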
def prometheus_parse_model_response(output):
    try:
        print(f"Raw model response: {output}")
        output = output.strip()

        # Remove "Feedback:" prefix if present (case insensitive)
        output = re.sub(r'^feedback:\s*', '', output, flags=re.IGNORECASE)

        # Pattern to match [RESULT] X at the beginning
        begin_result_pattern = r'^\[RESULT\]\s*(\d+)\s*\n*(.*?)$'
        begin_match = re.search(begin_result_pattern, output, re.DOTALL | re.IGNORECASE)
        if begin_match:
            score = int(begin_match.group(1))
            feedback = begin_match.group(2).strip()
            return str(score), feedback

        # Pattern for end-of-string results: "... [RESULT] X"
        pattern = r"(.*?)\s*\[RESULT\]\s*[\(\[]?(\d+)[\)\]]?"
        match = re.search(pattern, output, re.DOTALL | re.IGNORECASE)
        if match:
            feedback = match.group(1).strip()
            score = int(match.group(2))
            return str(score), feedback

        # If no match, try to match "... Score: X"
        pattern = r"(.*?)\s*(?:Score|Result)\s*:\s*[\(\[]?(\d+)[\)\]]?"
        match = re.search(pattern, output, re.DOTALL | re.IGNORECASE)
        if match:
            feedback = match.group(1).strip()
            score = int(match.group(2))
            return str(score), feedback

        # Pattern to handle [Score X] at the end
        pattern = r"(.*?)\s*\[(?:Score|Result)\s*[\(\[]?(\d+)[\)\]]?\]$"
        match = re.search(pattern, output, re.DOTALL)
        if match:
            feedback = match.group(1).strip()
            score = int(match.group(2))
            return str(score), feedback

        # Final fallback attempt
        pattern = r"[\(\[]?(\d+)[\)\]]?\s*\]?$"
        match = re.search(pattern, output)
        if match:
            score = int(match.group(1))
            feedback = output[:match.start()].rstrip()
            # Remove any trailing brackets from feedback
            feedback = re.sub(r'\s*\[[^\]]*$', '', feedback).strip()
            return str(score), feedback

        return "Error", f"Failed to parse response: {output}"

    except Exception as e:
        print(f"Failed to parse response: {str(e)}")
        return "Error", f"Exception during parsing: {str(e)}"

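# Illustrative usage (hypothetical Prometheus-style output):
#   prometheus_parse_model_response("The response is accurate and complete. [RESULT] 5")
#   -> ("5", "The response is accurate and complete.")
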
def atla_parse_model_response(output):
    """Parse response from ATLA model"""
    try:
        print(f"Raw Atla model response: {output}")
        output = output.strip()

        # Look for the Reasoning and Result sections
        reasoning_match = re.search(r'\*\*Reasoning:\*\*(.*?)(?=\*\*Result:|$)', output, re.DOTALL)
        result_match = re.search(r'\*\*Result:\*\*\s*(\d+)', output)

        if reasoning_match and result_match:
            feedback = reasoning_match.group(1).strip()
            score = result_match.group(1)
            return str(score), feedback

        return "Error", f"Failed to parse ATLA response format: {output}"

    except Exception as e:
        print(f"Failed to parse ATLA response: {str(e)}")
        return "Error", f"Exception during parsing: {str(e)}"

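# Illustrative usage (hypothetical Selene-style output):
#   atla_parse_model_response("**Reasoning:** The answer is correct.\n**Result:** 4")
#   -> ("4", "The answer is correct.")
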
def flow_judge_parse_model_response(output):
    try:
        print(f"Raw model response: {output}")
        # Convert multiple line breaks to single ones and strip whitespace
        output = re.sub(r'\n{2,}', '\n', output.strip())

        # Compile regex patterns
        feedback_pattern = re.compile(r"<feedback>\s*(.*?)\s*</feedback>", re.DOTALL)
        score_pattern = re.compile(r"<score>\s*(\d+)\s*</score>", re.DOTALL)

        feedback_match = feedback_pattern.search(output)
        score_match = score_pattern.search(output)

        # Both the feedback and the score must be present for a valid parse
        if feedback_match and score_match:
            feedback = feedback_match.group(1).strip()
            score = int(score_match.group(1).strip())
            return str(score), feedback

        return "Error", f"Failed to parse response: {output}"

    except Exception as e:
        print(f"Failed to parse response: {str(e)}")
        return "Error", f"Exception during parsing: {str(e)}"