p-ferreira committed
Commit a34ad94 · 2 Parent(s): 94c3b3d 48f02b6

Merge remote-tracking branch 'origin/stream' into features/mock-validator-integration

Files changed (5):
  1. api.py +107 -0
  2. forward.py +4 -4
  3. server.py +34 -126
  4. test.py +16 -0
  5. utils.py +109 -0
api.py ADDED
@@ -0,0 +1,107 @@
+
+import json
+import asyncio
+
+import traceback
+import bittensor as bt
+
+import utils
+
+from typing import List
+from neurons.validator import Validator
+from prompting.forward import handle_response
+from prompting.dendrite import DendriteResponseEvent
+from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
+from prompting.utils.uids import get_random_uids
+
+from aiohttp import web
+
+from aiohttp.web_response import Response
+
+
+async def single_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> Response:
+
+    try:
+        # Guess the task name of current request
+        task_name = utils.guess_task_name(messages[-1])
+
+        # Get the list of uids to query for this step.
+        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
+        axons = [validator.metagraph.axons[uid] for uid in uids]
+
+        # Make calls to the network with the prompt.
+        bt.logging.info(f'Calling dendrite')
+        responses = await validator.dendrite(
+            axons=axons,
+            synapse=PromptingSynapse(roles=roles, messages=messages),
+            timeout=timeout,
+        )
+
+        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
+        # Encapsulate the responses in a response event (dataclass)
+        response_event = DendriteResponseEvent(responses, uids)
+
+        # convert dict to json
+        response = response_event.__state_dict__()
+
+        response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
+        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
+
+        response['task_name'] = task_name
+        response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
+
+        bt.logging.info(f"Response:\n {response}")
+        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
+
+    except Exception:
+        bt.logging.error(f'Encountered in {single_response.__name__}:\n{traceback.format_exc()}')
+        return Response(status=500, reason="Internal error")
+
+
+async def stream_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> web.StreamResponse:
+
+    try:
+        # Guess the task name of current request
+        task_name = utils.guess_task_name(messages[-1])
+
+        # Get the list of uids to query for this step.
+        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
+        axons = [validator.metagraph.axons[uid] for uid in uids]
+
+        # Make calls to the network with the prompt.
+        bt.logging.info(f'Calling dendrite')
+        streams_responses = await validator.dendrite(
+            axons=axons,
+            synapse=StreamPromptingSynapse(roles=roles, messages=messages),
+            timeout=timeout,
+            deserialize=False,
+            streaming=True,
+        )
+
+        # Prepare the task for handling stream responses
+        handle_stream_responses_task = asyncio.create_task(
+            handle_response(responses=dict(zip(uids, streams_responses)))
+        )
+
+        stream_results = await handle_stream_responses_task
+
+        responses = [stream_result.synapse for stream_result in stream_results]
+        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
+        # Encapsulate the responses in a response event (dataclass)
+        response_event = DendriteResponseEvent(responses, uids)
+
+        # convert dict to json
+        response = response_event.__state_dict__()
+
+        response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
+        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
+
+        response['task_name'] = task_name
+        response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
+
+        bt.logging.info(f"Response:\n {response}")
+        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
+
+    except Exception:
+        bt.logging.error(f'Encountered in {stream_response.__name__}:\n{traceback.format_exc()}')
+        return Response(status=500, reason="Internal error")
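
The two coroutines above need a live Validator and the request fields; they are not yet awaited anywhere in this commit. A minimal sketch of how they could be mounted behind an aiohttp route follows; the `make_app` helper and the `/chat` path are illustrative assumptions, not part of the commit:

```
from aiohttp import web

import api  # the module added in this commit
from neurons.validator import Validator  # same import path api.py uses


def make_app(validator: Validator) -> web.Application:
    """Forward chat requests to api.single_response / api.stream_response."""

    async def chat_handler(request: web.Request) -> web.Response:
        request_data = await request.json()
        kwargs = dict(
            messages=request_data['messages'],
            roles=request_data['roles'],
            k=request_data.get('k', 5),
            timeout=request_data.get('timeout', 3.0),
            exclude=request_data.get('exclude'),
            prefer=request_data.get('prefer', 'longest'),
        )
        # Dispatch on the 'stream' flag, awaiting the coroutine and passing the validator.
        if request_data.get('stream', False):
            return await api.stream_response(validator, **kwargs)
        return await api.single_response(validator, **kwargs)

    app = web.Application()
    app.add_routes([web.post('/chat', chat_handler)])  # route path is an assumption
    return app
```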
forward.py CHANGED
@@ -16,10 +16,10 @@ from prompting.utils.misc import async_log, serialize_exception_to_string
 from dataclasses import dataclass
 
 @async_log
-async def generate_reference(agent):
+async def generate_reference(agent):
     loop = asyncio.get_running_loop()
     result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
-    return result
+    return result
 
 @async_log
 async def execute_dendrite_call(dendrite_call):
@@ -199,8 +199,8 @@ async def run_step(
 
     log_stream_results(stream_results)
 
-    all_synapses_results = [stream_result.synapse for stream_result in stream_results]
-
+    all_synapses_results = [stream_result.synapse for stream_result in stream_results]
+
     # Encapsulate the responses in a response event (dataclass)
     response_event = DendriteResponseEvent(
         responses=all_synapses_results, uids=uids, timeout=timeout
server.py CHANGED
@@ -2,11 +2,9 @@
 
 
 import os
-import re
 import time
 import asyncio
 import json
-import traceback
 import bittensor as bt
 from collections import Counter
 from validator_wrapper import QueryValidatorParams, S1ValidatorWrapper
@@ -38,7 +36,7 @@ EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name a
 
 basic testing
 ```
-EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
+EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
 ```
 add --mock to test the echo stream
 """
@@ -46,102 +44,7 @@ add --mock to test the echo stream
 EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
 
 validator = None
-reward_models = {
-    'date_qa': DateRewardModel(),
-    'math': FloatDiffModel(),
-}
 
-def completion_is_valid(completion: str):
-    """
-    Get the completion statuses from the completions.
-    """
-    patt = re.compile(r'I\'m sorry|unable to|I cannot|I can\'t|I am unable|I am sorry|I can not|don\'t know|not sure|don\'t understand')
-    if not len(re.findall(r'\w+',completion)) or patt.search(completion):
-        return False
-    return True
-
-
-def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
-    """
-    Ensemble completions from multiple models.
-    # TODO: Measure agreement
-    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
-    # TODO: Reward pipeline
-    """
-    if not completions:
-        return None
-
-
-    answer = None
-    if task_name in ('qa', 'summarization'):
-        # No special handling for QA or summarization
-        supporting_completions = completions
-
-    elif task_name == 'date_qa':
-        # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
-        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
-        bt.logging.info(f"Unprocessed dates: {dates}")
-        valid_date_indices = [i for i, d in enumerate(dates) if d]
-        valid_completions = [completions[i] for i in valid_date_indices]
-        valid_dates = [dates[i] for i in valid_date_indices]
-        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
-        if not dates:
-            return None
-
-        counter = Counter(dates)
-        most_common, count = counter.most_common()[0]
-        answer = most_common
-        if count == 1:
-            supporting_completions = valid_completions
-        else:
-            supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i]==most_common]
-
-    elif task_name == 'math':
-        # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
-        # TODO: use the median instead of the most common value
-        vals = list(map(reward_models[task_name].extract_number, completions))
-        vals = [val for val in vals if val]
-        if not vals:
-            return None
-
-        most_common, count = Counter(dates).most_common()[0]
-        bt.logging.info(f"Most common value: {most_common}, count: {count}")
-        answer = most_common
-        if count == 1:
-            supporting_completions = completions
-        else:
-            supporting_completions = [c for i, c in enumerate(completions) if vals[i]==most_common]
-
-
-    bt.logging.info(f"Supporting completions: {supporting_completions}")
-    if prefer == 'longest':
-        preferred_completion = sorted(supporting_completions, key=len)[-1]
-    elif prefer == 'shortest':
-        preferred_completion = sorted(supporting_completions, key=len)[0]
-    elif prefer == 'most_common':
-        preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
-    else:
-        raise ValueError(f"Unknown ensemble preference: {prefer}")
-
-    return {
-        'completion': preferred_completion,
-        'accepted_answer': answer,
-        'support': len(supporting_completions),
-        'support_indices': [completions.index(c) for c in supporting_completions],
-        'method': f'Selected the {prefer.replace("_", " ")} completion'
-    }
-
-def guess_task_name(challenge: str):
-    categories = {
-        'summarization': re.compile('summar|quick rundown|overview'),
-        'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
-        'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
-    }
-    for task_name, patt in categories.items():
-        if patt.search(challenge):
-            return task_name
-
-    return 'qa'
 
 async def chat(request: web.Request) -> Response:
     """
@@ -171,37 +74,43 @@ async def chat(request: web.Request) -> Response:
     except ValueError:
         bt.logging.error(f'Invalid request data: {request_data}')
         return Response(status=400)
-
-    bt.logging.info(f'Request data: {request_data}')
-
-    try:
-        # Guess the task name of current request
-        task_name = guess_task_name(request_data['messages'][-1])
+
+    # try:
+    #     # Guess the task name of current request
+    #     task_name = guess_task_name(request_data['messages'][-1])
 
-        # Get the list of uids to query for this step.
-        params = QueryValidatorParams.from_dict(request_data)
-        response_event = await validator.query_validator(params)
+    #     # Get the list of uids to query for this step.
+    #     params = QueryValidatorParams.from_dict(request_data)
+    #     response_event = await validator.query_validator(params)
 
-        # convert dict to json
-        response = response_event.__state_dict__()
+    #     # convert dict to json
+    #     response = response_event.__state_dict__()
 
-        response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
-        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
+    #     response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
+    #     valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
 
-        response['task_name'] = task_name
-        prefer = request_data.get('prefer', 'longest')
-        response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
+    #     response['task_name'] = task_name
+    #     prefer = request_data.get('prefer', 'longest')
+    #     response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
 
-        bt.logging.info(f"Response:\n {response}")
-        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
-    except Exception:
-        bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
-        return Response(status=500, reason="Internal error")
+    #     bt.logging.info(f"Response:\n {response}")
+    #     return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
+    # except Exception:
+    #     bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
+    #     return Response(status=500, reason="Internal error")
+    bt.logging.info(f'Request data: {request_data}')
+
+    stream = request_data.get('stream', False)
+    if stream:
+        return stream_response(**request_data)
+    else:
+        return single_response(**request_data)
+
 
 
 
 async def echo_stream(request):
-
+
     bt.logging.info(f'echo_stream()')
     # Check access key
     access_key = request.headers.get("api_key")
@@ -218,7 +127,7 @@ async def echo_stream(request):
     bt.logging.info(f'Request data: {request_data}')
     k = request_data.get('k', 1)
     exclude = request_data.get('exclude', [])
-    timeout = request_data.get('timeout', 0.2)
+    timeout = request_data.get('timeout', 0.2)
     message = '\n\n'.join(request_data['messages'])
 
     # Create a StreamResponse
@@ -231,7 +140,7 @@ async def echo_stream(request):
     for word in message.split():
         chunk = f'{word} '
         await response.write(chunk.encode('utf-8'))
-        completion += chunk
+        completion += chunk
         time.sleep(timeout)
         bt.logging.info(f"Echoed: {chunk}")
 
@@ -249,7 +158,7 @@ async def echo_stream(request):
         "task_name": 'echo',
         "ensemble_result": {}
     })
-
+
     # Send the final JSON as part of the stream
     await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
 
@@ -257,6 +166,7 @@ async def echo_stream(request):
     await response.write_eof()
     return response
 
+
 class ValidatorApplication(web.Application):
     def __init__(self, *a, **kw):
         super().__init__(*a, **kw)
@@ -275,14 +185,12 @@ bt.logging.info(validator_app)
 
 def main(run_aio_app=True, test=False) -> None:
     loop = asyncio.get_event_loop()
-
-    # port = validator.metagraph.axons[validator.uid].port
    port = 10000
     if run_aio_app:
         try:
             web.run_app(validator_app, port=port, loop=loop)
         except KeyboardInterrupt:
-            bt.logging.info("Keyboard interrupt detected. Exiting validator.")
+            bt.logging.warning("Keyboard interrupt detected. Exiting validator.")
         finally:
             pass
 
test.py ADDED
@@ -0,0 +1,16 @@
+import pytest
+
+
+
+def test_query_network():
+    pass
+
+def test_filter_completions():
+    pass
+
+
+def test_guess_task_name():
+    pass
+
+def test_ensemble_completions():
+    pass
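
The new test module only adds placeholders. One way two of them could be filled in against the helpers in utils.py; the sample prompts and expected labels are illustrative assumptions that follow the regexes in guess_task_name:

```
import utils


def test_guess_task_name():
    # Each prompt is chosen to hit one of the category regexes (case-sensitive).
    assert utils.guess_task_name("Give me a quick rundown of the article") == "summarization"
    assert utils.guess_task_name("What is the exact date of the moon landing?") == "date_qa"
    assert utils.guess_task_name("Please calculate the factorial of 5") == "math"
    assert utils.guess_task_name("Who wrote the book?") == "qa"


def test_filter_completions():
    assert utils.completion_is_valid("The answer is 42.")
    assert not utils.completion_is_valid("I'm sorry, I cannot help with that.")
    assert not utils.completion_is_valid("   ")
```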
utils.py ADDED
@@ -0,0 +1,109 @@
+import re
+import bittensor as bt
+
+from collections import Counter
+
+from prompting.rewards import DateRewardModel, FloatDiffModel
+
+UNSUCCESSFUL_RESPONSE_PATTERNS = ["I'm sorry", "unable to", "I cannot", "I can't", "I am unable", "I am sorry", "I can not", "don't know", "not sure", "don't understand", "not capable"]
+
+reward_models = {
+    'date_qa': DateRewardModel(),
+    'math': FloatDiffModel(),
+}
+
+def completion_is_valid(completion: str):
+    """
+    Get the completion statuses from the completions.
+    """
+    if not completion.strip():
+        return False
+
+    patt = re.compile(r'\b(?:' + '|'.join(UNSUCCESSFUL_RESPONSE_PATTERNS) + r')\b', re.IGNORECASE)
+    if not len(re.findall(r'\w+',completion)) or patt.search(completion):
+        return False
+    return True
+
+
+def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
+    """
+    Ensemble completions from multiple models.
+    # TODO: Measure agreement
+    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
+    # TODO: Reward pipeline
+    """
+    if not completions:
+        return None
+
+    answer = None
+    if task_name in ('qa', 'summarization'):
+        # No special handling for QA or summarization
+        supporting_completions = completions
+
+    elif task_name == 'date_qa':
+        # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
+        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
+        bt.logging.info(f"Unprocessed dates: {dates}")
+        valid_date_indices = [i for i, d in enumerate(dates) if d]
+        valid_completions = [completions[i] for i in valid_date_indices]
+        valid_dates = [dates[i] for i in valid_date_indices]
+        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
+        if not dates:
+            return None
+
+        counter = Counter(dates)
+        most_common, count = counter.most_common()[0]
+        answer = most_common
+        if count == 1:
+            supporting_completions = valid_completions
+        else:
+            supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i]==most_common]
+
+    elif task_name == 'math':
+        # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
+        # TODO: use the median instead of the most common value
+        vals = list(map(reward_models[task_name].extract_number, completions))
+        vals = [val for val in vals if val]
+        if not vals:
+            return None
+
+        most_common, count = Counter(vals).most_common()[0]
+        bt.logging.info(f"Most common value: {most_common}, count: {count}")
+        answer = most_common
+        if count == 1:
+            supporting_completions = completions
+        else:
+            supporting_completions = [c for i, c in enumerate(completions) if vals[i]==most_common]
+
+
+    bt.logging.info(f"Supporting completions: {supporting_completions}")
+    if prefer == 'longest':
+        preferred_completion = sorted(supporting_completions, key=len)[-1]
+    elif prefer == 'shortest':
+        preferred_completion = sorted(supporting_completions, key=len)[0]
+    elif prefer == 'most_common':
+        preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
+    else:
+        raise ValueError(f"Unknown ensemble preference: {prefer}")
+
+    return {
+        'completion': preferred_completion,
+        'accepted_answer': answer,
+        'support': len(supporting_completions),
+        'support_indices': [completions.index(c) for c in supporting_completions],
+        'method': f'Selected the {prefer.replace("_", " ")} completion'
+    }
+
+def guess_task_name(challenge: str):
+
+    # TODO: use a pre-trained classifier to guess the task name
+    categories = {
+        'summarization': re.compile('summar|quick rundown|overview'),
+        'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
+        'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
+    }
+    for task_name, patt in categories.items():
+        if patt.search(challenge):
+            return task_name
+
+    return 'qa'
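
A hedged usage sketch of these helpers for the plain 'qa' path, which needs no reward model; the completions are invented examples:

```
import utils

completions = [
    "Paris is the capital of France.",
    "The capital of France is Paris, a city of about two million people.",
    "I'm sorry, I cannot answer that.",
]

task_name = utils.guess_task_name("What is the capital of France?")  # -> 'qa' (no regex matches)
valid = [c for c in completions if utils.completion_is_valid(c)]     # drops the refusal
result = utils.ensemble_result(valid, task_name=task_name, prefer='longest')

# For 'qa', every valid completion counts as supporting, so prefer='longest' simply
# returns the longest one along with support=2, its indices, and the selection method.
print(result['completion'])
print(result['support'], result['method'])
```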