kaikaidai committed (verified)
Commit f2bc5f6 · 1 Parent(s): 3261484

Synced repo using 'sync_with_huggingface' Github Action

random_sample/arena_interface.py CHANGED
@@ -6,14 +6,17 @@ from dotenv import load_dotenv
 load_dotenv()
 
 from .gen_api_answer import (
-    get_atla_response
+    get_atla_response,
+    get_selene_mini_response,
+    parse_selene_mini_response
 )
 
 from .prompts import (
     DEFAULT_EVAL_CRITERIA,
     DEFAULT_EVAL_PROMPT,
     DEFAULT_EVAL_PROMPT_EDITABLE,
-    FIXED_EVAL_SUFFIX
+    ATLA_PROMPT,
+    ATLA_PROMPT_WITH_REFERENCE
 )
 
 from .random_sample_generation import (
@@ -67,6 +70,15 @@ def create_arena_interface():
             value=DEFAULT_EVAL_PROMPT,
             visible=False
         )
+        with gr.Row():
+            # Add model selector dropdown at the top
+            model_selector = gr.Dropdown(
+                choices=["Selene", "Selene Mini"],
+                value="Selene",
+                label="Choose your Atla Model",
+                interactive=True
+            )
+
         with gr.Row():
             # Left side - Input section
             with gr.Column(scale=1):
@@ -234,36 +246,62 @@ def create_arena_interface():
         # Add a new state variable to track first game
         first_game_state = gr.State(True)  # Initialize as True
 
-        # Update the submit function to parse the evaluation criteria
+        # Update the submit function to handle both models
         def submit_and_store(
+            model_choice,
             use_reference,
             eval_criteria_text,
             human_input,
             ai_response,
-            ground_truth_input,
+            ground_truth,
         ):
-            # Build prompt data dictionary
-            prompt_data = {
-                'human_input': human_input,
-                'ai_response': ai_response,
-                'ground_truth_input': ground_truth_input if use_reference else None,
-                'eval_criteria': eval_criteria_text,
-            }
-
-            # Get response from Atla
-            response = get_atla_response(
-                model_name="AtlaAI/Selene-1-Mini-Llama-3.1-8B",
-                prompt=prompt_data,
-                max_tokens=500,
-                temperature=0.01
-            )
+            if model_choice == "Selene Mini":
+                # Prepare prompt based on reference mode
+                prompt_template = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
+                prompt = prompt_template.format(
+                    human_input=human_input,
+                    ai_response=ai_response,
+                    eval_criteria=eval_criteria_text,
+                    ground_truth=ground_truth if use_reference else ""
+                )
+
+                print("\n=== Debug: Prompt being sent to Selene Mini ===")
+                print(prompt)
+                print("============================================\n")
+
+                # Get and parse response
+                raw_response = get_selene_mini_response(
+                    model_name="AtlaAI/Selene-1-Mini-Llama-3.1-8B",
+                    prompt=prompt,
+                    max_tokens=500,
+                    temperature=0.01
+                )
+                response = parse_selene_mini_response(raw_response)
+            else:
+                # Selene API logic
+                prompt_data = {
+                    'human_input': human_input,
+                    'ai_response': ai_response,
+                    'ground_truth': ground_truth if use_reference else None,
+                    'eval_criteria': eval_criteria_text,
+                }
+
+                print("\n=== Debug: Prompt data being sent to Selene API ===")
+                print(json.dumps(prompt_data, indent=2))
+                print("============================================\n")
+
+                response = get_atla_response(
+                    model_name="AtlaAI/Selene-1-Mini-Llama-3.1-8B",
+                    prompt=prompt_data,
+                    max_tokens=500,
+                    temperature=0.01
+                )
 
             # Response now contains score and critique directly
             if isinstance(response, dict) and 'score' in response and 'critique' in response:
                 score = str(response['score'])
                 critique = response['critique']
             else:
-                # Handle error case
                 score = "Error"
                 critique = str(response)
 
@@ -274,22 +312,11 @@ def create_arena_interface():
                 gr.update(value="🎲"),
             ]
 
-        # Update the click handler to use False for is_first_game after first submission
-        def create_submit_handler():
-            first_game = True
-
-            def handler(*args):
-                nonlocal first_game
-                result = submit_and_store(*args)
-                first_game = False  # Set to False after first submission
-                return result
-
-            return handler
-
-        # Update the send_btn click handler
+        # Update the send_btn click handler with new input
         send_btn.click(
            fn=submit_and_store,
            inputs=[
+                model_selector,
                use_reference_toggle,
                eval_criteria_text,
                human_input,
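
For context, the wiring pattern this change introduces can be seen in isolation below. This is a minimal, self-contained sketch rather than code from the repo: the textbox components and the simplified handler body are stand-ins, while the dropdown arguments and the ordering of inputs mirror the diff above. Because model_selector is listed first in send_btn.click(inputs=[...]), its current value arrives as the new model_choice parameter of submit_and_store.

import gradio as gr

def submit_and_store(model_choice, human_input, ai_response):
    # Stand-in for the real handler: branch on the dropdown value,
    # just as the updated submit_and_store does for "Selene Mini" vs "Selene".
    if model_choice == "Selene Mini":
        return f"[Selene Mini endpoint would evaluate] {ai_response!r}"
    return f"[Selene API would evaluate] {ai_response!r}"

with gr.Blocks() as demo:
    model_selector = gr.Dropdown(
        choices=["Selene", "Selene Mini"],
        value="Selene",
        label="Choose your Atla Model",
        interactive=True
    )
    human_input = gr.Textbox(label="Human input")
    ai_response = gr.Textbox(label="AI response")
    evaluation = gr.Textbox(label="Evaluation")
    send_btn = gr.Button("Send")
    # The dropdown comes first in inputs, matching the new model_choice
    # parameter added to submit_and_store in this commit.
    send_btn.click(
        fn=submit_and_store,
        inputs=[model_selector, human_input, ai_response],
        outputs=[evaluation]
    )

if __name__ == "__main__":
    demo.launch()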
random_sample/gen_api_answer.py CHANGED
@@ -1,14 +1,16 @@
 from openai import OpenAI
 import anthropic
 from together import Together
-import os
+import os
 from atla import Atla
 from dotenv import load_dotenv
 from .prompts import (
-    JUDGE_SYSTEM_PROMPT,
-    ATLA_PROMPT,
-    ATLA_PROMPT_WITH_REFERENCE
+    JUDGE_SYSTEM_PROMPT
 )
+from transformers import AutoTokenizer
+import requests
+import json
+import re
 
 load_dotenv()
 
@@ -57,7 +59,7 @@ def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, te
         # Extract components from the prompt data
         model_input = prompt.get('human_input', '')
         model_output = prompt.get('ai_response', '')
-        expected_output = prompt.get('ground_truth_input')
+        expected_output = prompt.get('ground_truth')
         evaluation_criteria = prompt.get('eval_criteria', '')
 
         response = atla_client.evaluation.create(
@@ -74,4 +76,73 @@ def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, te
             "critique": response.result.evaluation.critique
         }
     except Exception as e:
-        return f"Error with Atla model {model_name}: {str(e)}"
+        return f"Error with Atla model {model_name}: {str(e)}"
+
+def get_selene_mini_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
+    """Get response from HF endpoint for Atla model"""
+    try:
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {hf_api_key}",
+            "Content-Type": "application/json"
+        }
+
+        # Create messages list for chat template
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        # Apply chat template
+        model_id = "AtlaAI/Selene-1-Mini-Llama-3.1-8B"
+        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
+        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+        payload = {
+            "inputs": formatted_prompt,
+            "parameters": {
+                "max_new_tokens": max_tokens,
+                "return_full_text": False,
+                "temperature": temperature,
+                "seed": 42,
+                "add_generation_prompt": True
+            }
+        }
+
+        response = requests.post(
+            "https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
+            headers=headers,
+            json=payload
+        )
+        return response.json()[0]["generated_text"]
+    except Exception as e:
+        return f"Error with Atla model {model_name}: {str(e)}"
+
+def parse_selene_mini_response(response_text):
+    """Parse the response from Selene Mini to extract score and critique"""
+    try:
+        # Clean up the response text
+        response_text = response_text.strip()
+
+        # More flexible regex patterns
+        reasoning_pattern = r'\*\*Reasoning:?\*\*\s*(.*?)(?=\*\*Result|$)'
+        result_pattern = r'\*\*Result:?\*\*\s*(\d+)'
+
+        reasoning_match = re.search(reasoning_pattern, response_text, re.DOTALL | re.IGNORECASE)
+        result_match = re.search(result_pattern, response_text, re.IGNORECASE)
+
+        if reasoning_match and result_match:
+            critique = reasoning_match.group(1).strip()
+            score = result_match.group(1)
+            return {"score": score, "critique": critique}
+        else:
+            # If we can't parse it properly, let's return the raw response as critique
+            return {
+                "score": "Error",
+                "critique": f"Failed to parse response. Raw response:\n{response_text}"
+            }
+    except Exception as e:
+        return {
+            "score": "Error",
+            "critique": f"Error parsing response: {str(e)}\nRaw response:\n{response_text}"
+        }
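
To make the parsing contract concrete, here is a small standalone check of parse_selene_mini_response. The extraction logic is copied from the new function above; the sample response text is invented purely to illustrate the "**Reasoning:** ... **Result:** N" shape that the regexes target.

import re

def parse_selene_mini_response(response_text):
    """Same extraction logic as the new function above, reproduced for a quick local test."""
    response_text = response_text.strip()
    reasoning_pattern = r'\*\*Reasoning:?\*\*\s*(.*?)(?=\*\*Result|$)'
    result_pattern = r'\*\*Result:?\*\*\s*(\d+)'
    reasoning_match = re.search(reasoning_pattern, response_text, re.DOTALL | re.IGNORECASE)
    result_match = re.search(result_pattern, response_text, re.IGNORECASE)
    if reasoning_match and result_match:
        return {"score": result_match.group(1), "critique": reasoning_match.group(1).strip()}
    return {"score": "Error", "critique": f"Failed to parse response. Raw response:\n{response_text}"}

# Invented example of the output format Selene Mini is prompted to produce.
sample = """**Reasoning:** The response follows the instruction and matches the reference answer,
but it misses one edge case named in the evaluation criteria.

**Result:** 4"""

print(parse_selene_mini_response(sample))
# {'score': '4', 'critique': 'The response follows the instruction and matches the reference answer, ...'}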
random_sample/prompts.py CHANGED
@@ -88,7 +88,7 @@ ATLA_PROMPT_WITH_REFERENCE = """You are tasked with evaluating a response based
 {eval_criteria}
 
 Reference answer:
-{ground_truth_input}"""
+{ground_truth}"""
 
 # Judge system prompt for non-Prometheus models
 JUDGE_SYSTEM_PROMPT = """Please act as an impartial judge and evaluate based on the user's instruction. Your output format should strictly adhere to JSON as follows: {"feedback": "<write feedback>", "result": <numerical score>}. Ensure the output is valid JSON, without additional formatting or explanations."""