steffenc commited on
Commit
48f02b6
·
1 Parent(s): c60daaf

Add WIP refactor

Browse files
Files changed (5) hide show
  1. api.py +107 -0
  2. forward.py +4 -4
  3. server.py +14 -150
  4. test.py +16 -0
  5. utils.py +109 -0
api.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import asyncio
4
+
5
+ import traceback
6
+ import bittensor as bt
7
+
8
+ import utils
9
+
10
+ from typing import List
11
+ from neurons.validator import Validator
12
+ from prompting.forward import handle_response
13
+ from prompting.dendrite import DendriteResponseEvent
14
+ from prompting.protocol import PromptingSynapse, StreamPromptingSynapse
15
+ from prompting.utils.uids import get_random_uids
16
+
17
+ from aiohttp import web
18
+
19
+ from aiohttp.web_response import Response
20
+
21
+
22
async def single_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> Response:
    """Query ``k`` random miners with a non-streaming PromptingSynapse and
    return one aggregated JSON ``Response``.

    Args:
        validator: running Validator instance (supplies metagraph + dendrite).
        messages: conversation messages; the last one is used to guess the task.
        roles: role labels aligned with ``messages``.
        k: number of miner uids to sample.
        timeout: per-dendrite-call timeout in seconds.
        exclude: uids to exclude from sampling (``None`` means none).
        prefer: ensemble tie-break strategy forwarded to ``utils.ensemble_result``.

    Returns:
        HTTP 200 with the serialized response event (including validity flags,
        task name and ensemble result), or HTTP 500 on any internal error.
    """
    try:
        # Infer the task type from the most recent message.
        task_name = utils.guess_task_name(messages[-1])

        # Sample miner uids and resolve their axons.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Fan the prompt out to the sampled miners.
        bt.logging.info(f'Calling dendrite')
        synapse = PromptingSynapse(roles=roles, messages=messages)
        responses = await validator.dendrite(axons=axons, synapse=synapse, timeout=timeout)

        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Wrap the raw miner responses in the response-event dataclass.
        event = DendriteResponseEvent(responses, uids)

        # Serialize the event to a plain dict for the JSON payload.
        payload = event.__state_dict__()

        # Flag refusals/empty answers and keep only the valid completions.
        validity = [utils.completion_is_valid(c) for c in payload['completions']]
        payload['completion_is_valid'] = validity
        kept = [c for c, ok in zip(payload['completions'], validity) if ok]

        payload['task_name'] = task_name
        payload['ensemble_result'] = utils.ensemble_result(kept, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {payload}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(payload))

    except Exception:
        bt.logging.error(f'Encountered in {single_response.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")
59
+
60
+
61
async def stream_response(validator: Validator, messages: List[str], roles: List[str], k: int = 5, timeout: float = 3.0, exclude: List[int] = None, prefer: str = 'longest') -> web.StreamResponse:
    """Query ``k`` random miners with a streaming synapse, drain the streams,
    and return one aggregated JSON response.

    Args:
        validator: running Validator instance (supplies metagraph + dendrite).
        messages: conversation messages; the last one is used to guess the task.
        roles: role labels aligned with ``messages``.
        k: number of miner uids to sample.
        timeout: per-dendrite-call timeout in seconds.
        exclude: uids to exclude from sampling (``None`` means none).
        prefer: ensemble tie-break strategy forwarded to ``utils.ensemble_result``.

    Returns:
        HTTP 200 with the serialized response event, or HTTP 500 on error.
        NOTE(review): despite the ``web.StreamResponse`` annotation this
        returns a plain ``Response`` — confirm whether true streaming to the
        client is intended here.
    """
    try:
        # Guess the task name of current request
        task_name = utils.guess_task_name(messages[-1])

        # Get the list of uids to query for this step.
        uids = get_random_uids(validator, k=k, exclude=exclude or []).tolist()
        axons = [validator.metagraph.axons[uid] for uid in uids]

        # Make calls to the network with the prompt.
        bt.logging.info(f'Calling dendrite')
        streams_responses = await validator.dendrite(
            axons=axons,
            synapse=StreamPromptingSynapse(roles=roles, messages=messages),
            timeout=timeout,
            deserialize=False,
            streaming=True,
        )

        # Drain each miner's stream and collect the completed synapses.
        handle_stream_responses_task = asyncio.create_task(
            handle_response(responses=dict(zip(uids, streams_responses)))
        )
        stream_results = await handle_stream_responses_task

        responses = [stream_result.synapse for stream_result in stream_results]
        bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
        # Encapsulate the responses in a response event (dataclass)
        response_event = DendriteResponseEvent(responses, uids)

        # convert dict to json
        response = response_event.__state_dict__()

        # Flag refusals/empty answers and keep only valid completions for the ensemble.
        response['completion_is_valid'] = valid = list(map(utils.completion_is_valid, response['completions']))
        valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]

        response['task_name'] = task_name
        response['ensemble_result'] = utils.ensemble_result(valid_completions, task_name=task_name, prefer=prefer)

        bt.logging.info(f"Response:\n {response}")
        return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))

    except Exception:
        # BUG FIX: the original logged single_response.__name__ here (copy-paste),
        # misattributing stream errors to the non-streaming endpoint.
        bt.logging.error(f'Encountered in {stream_response.__name__}:\n{traceback.format_exc()}')
        return Response(status=500, reason="Internal error")
forward.py CHANGED
@@ -16,10 +16,10 @@ from prompting.utils.misc import async_log, serialize_exception_to_string
16
  from dataclasses import dataclass
17
 
18
  @async_log
19
- async def generate_reference(agent):
20
  loop = asyncio.get_running_loop()
21
  result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
22
- return result
23
 
24
  @async_log
25
  async def execute_dendrite_call(dendrite_call):
@@ -199,8 +199,8 @@ async def run_step(
199
 
200
  log_stream_results(stream_results)
201
 
202
- all_synapses_results = [stream_result.synapse for stream_result in stream_results]
203
-
204
  # Encapsulate the responses in a response event (dataclass)
205
  response_event = DendriteResponseEvent(
206
  responses=all_synapses_results, uids=uids, timeout=timeout
 
16
  from dataclasses import dataclass
17
 
18
  @async_log
19
+ async def generate_reference(agent):
20
  loop = asyncio.get_running_loop()
21
  result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline)
22
+ return result
23
 
24
  @async_log
25
  async def execute_dendrite_call(dendrite_call):
 
199
 
200
  log_stream_results(stream_results)
201
 
202
+ all_synapses_results = [stream_result.synapse for stream_result in stream_results]
203
+
204
  # Encapsulate the responses in a response event (dataclass)
205
  response_event = DendriteResponseEvent(
206
  responses=all_synapses_results, uids=uids, timeout=timeout
server.py CHANGED
@@ -2,20 +2,12 @@
2
 
3
 
4
  import os
5
- import re
6
  import time
7
  import asyncio
8
  import json
9
- import traceback
10
  import bittensor as bt
11
 
12
- from collections import Counter
13
-
14
  from neurons.validator import Validator
15
- from prompting.dendrite import DendriteResponseEvent
16
- from prompting.protocol import PromptingSynapse
17
- from prompting.utils.uids import get_random_uids
18
- from prompting.rewards import DateRewardModel, FloatDiffModel
19
  from aiohttp import web
20
  from aiohttp.web_response import Response
21
 
@@ -43,7 +35,7 @@ EXPECTED_ACCESS_KEY="hey-michal" pm2 start app.py --interpreter python3 --name a
43
 
44
  basic testing
45
  ```
46
- EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
47
  ```
48
  add --mock to test the echo stream
49
  """
@@ -51,102 +43,7 @@ add --mock to test the echo stream
51
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
52
 
53
  validator = None
54
- reward_models = {
55
- 'date_qa': DateRewardModel(),
56
- 'math': FloatDiffModel(),
57
- }
58
 
59
- def completion_is_valid(completion: str):
60
- """
61
- Get the completion statuses from the completions.
62
- """
63
- patt = re.compile(r'I\'m sorry|unable to|I cannot|I can\'t|I am unable|I am sorry|I can not|don\'t know|not sure|don\'t understand')
64
- if not len(re.findall(r'\w+',completion)) or patt.search(completion):
65
- return False
66
- return True
67
-
68
-
69
- def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
70
- """
71
- Ensemble completions from multiple models.
72
- # TODO: Measure agreement
73
- # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
74
- # TODO: Reward pipeline
75
- """
76
- if not completions:
77
- return None
78
-
79
-
80
- answer = None
81
- if task_name in ('qa', 'summarization'):
82
- # No special handling for QA or summarization
83
- supporting_completions = completions
84
-
85
- elif task_name == 'date_qa':
86
- # filter the completions to be the ones that contain valid dates and if there are multiple dates, select the most common one (with support > 1)
87
- dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
88
- bt.logging.info(f"Unprocessed dates: {dates}")
89
- valid_date_indices = [i for i, d in enumerate(dates) if d]
90
- valid_completions = [completions[i] for i in valid_date_indices]
91
- valid_dates = [dates[i] for i in valid_date_indices]
92
- dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
93
- if not dates:
94
- return None
95
-
96
- counter = Counter(dates)
97
- most_common, count = counter.most_common()[0]
98
- answer = most_common
99
- if count == 1:
100
- supporting_completions = valid_completions
101
- else:
102
- supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i]==most_common]
103
-
104
- elif task_name == 'math':
105
- # filter the completions to be the ones that contain valid numbers and if there are multiple values, select the most common one (with support > 1)
106
- # TODO: use the median instead of the most common value
107
- vals = list(map(reward_models[task_name].extract_number, completions))
108
- vals = [val for val in vals if val]
109
- if not vals:
110
- return None
111
-
112
- most_common, count = Counter(dates).most_common()[0]
113
- bt.logging.info(f"Most common value: {most_common}, count: {count}")
114
- answer = most_common
115
- if count == 1:
116
- supporting_completions = completions
117
- else:
118
- supporting_completions = [c for i, c in enumerate(completions) if vals[i]==most_common]
119
-
120
-
121
- bt.logging.info(f"Supporting completions: {supporting_completions}")
122
- if prefer == 'longest':
123
- preferred_completion = sorted(supporting_completions, key=len)[-1]
124
- elif prefer == 'shortest':
125
- preferred_completion = sorted(supporting_completions, key=len)[0]
126
- elif prefer == 'most_common':
127
- preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
128
- else:
129
- raise ValueError(f"Unknown ensemble preference: {prefer}")
130
-
131
- return {
132
- 'completion': preferred_completion,
133
- 'accepted_answer': answer,
134
- 'support': len(supporting_completions),
135
- 'support_indices': [completions.index(c) for c in supporting_completions],
136
- 'method': f'Selected the {prefer.replace("_", " ")} completion'
137
- }
138
-
139
- def guess_task_name(challenge: str):
140
- categories = {
141
- 'summarization': re.compile('summar|quick rundown|overview'),
142
- 'date_qa': re.compile('exact date|tell me when|on what date|on what day|was born?|died?'),
143
- 'math': re.compile('math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
144
- }
145
- for task_name, patt in categories.items():
146
- if patt.search(challenge):
147
- return task_name
148
-
149
- return 'qa'
150
 
151
  async def chat(request: web.Request) -> Response:
152
  """
@@ -178,50 +75,18 @@ async def chat(request: web.Request) -> Response:
178
  return Response(status=400)
179
 
180
  bt.logging.info(f'Request data: {request_data}')
181
- k = request_data.get('k', 10)
182
- exclude = request_data.get('exclude', [])
183
- timeout = request_data.get('timeout', 10)
184
- prefer = request_data.get('prefer', 'longest')
185
- try:
186
- # Guess the task name of current request
187
- task_name = guess_task_name(request_data['messages'][-1])
188
-
189
- # Get the list of uids to query for this step.
190
- uids = get_random_uids(validator, k=k, exclude=exclude or []).to(validator.device)
191
- axons = [validator.metagraph.axons[uid] for uid in uids]
192
-
193
- # Make calls to the network with the prompt.
194
- bt.logging.info(f'Calling dendrite')
195
- responses = await validator.dendrite(
196
- axons=axons,
197
- synapse=PromptingSynapse(roles=request_data['roles'], messages=request_data['messages']),
198
- timeout=timeout,
199
- )
200
-
201
- bt.logging.info(f"Creating DendriteResponseEvent:\n {responses}")
202
- # Encapsulate the responses in a response event (dataclass)
203
- response_event = DendriteResponseEvent(responses, uids)
204
-
205
- # convert dict to json
206
- response = response_event.__state_dict__()
207
-
208
- response['completion_is_valid'] = valid = list(map(completion_is_valid, response['completions']))
209
- valid_completions = [response['completions'][i] for i, v in enumerate(valid) if v]
210
-
211
- response['task_name'] = task_name
212
- response['ensemble_result'] = ensemble_result(valid_completions, task_name=task_name, prefer=prefer)
213
-
214
- bt.logging.info(f"Response:\n {response}")
215
- return Response(status=200, reason="I can't believe it's not butter!", text=json.dumps(response))
216
-
217
- except Exception:
218
- bt.logging.error(f'Encountered in {chat.__name__}:\n{traceback.format_exc()}')
219
- return Response(status=500, reason="Internal error")
220
 
221
 
222
 
223
  async def echo_stream(request):
224
-
225
  bt.logging.info(f'echo_stream()')
226
  # Check access key
227
  access_key = request.headers.get("api_key")
@@ -238,7 +103,7 @@ async def echo_stream(request):
238
  bt.logging.info(f'Request data: {request_data}')
239
  k = request_data.get('k', 1)
240
  exclude = request_data.get('exclude', [])
241
- timeout = request_data.get('timeout', 0.2)
242
  message = '\n\n'.join(request_data['messages'])
243
 
244
  # Create a StreamResponse
@@ -251,7 +116,7 @@ async def echo_stream(request):
251
  for word in message.split():
252
  chunk = f'{word} '
253
  await response.write(chunk.encode('utf-8'))
254
- completion += chunk
255
  time.sleep(timeout)
256
  bt.logging.info(f"Echoed: {chunk}")
257
 
@@ -269,7 +134,7 @@ async def echo_stream(request):
269
  "task_name": 'echo',
270
  "ensemble_result": {}
271
  })
272
-
273
  # Send the final JSON as part of the stream
274
  await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
275
 
@@ -277,6 +142,7 @@ async def echo_stream(request):
277
  await response.write_eof()
278
  return response
279
 
 
280
  class ValidatorApplication(web.Application):
281
  def __init__(self, *a, **kw):
282
  super().__init__(*a, **kw)
@@ -296,14 +162,12 @@ bt.logging.info(validator_app)
296
  def main(run_aio_app=True, test=False) -> None:
297
 
298
  loop = asyncio.get_event_loop()
299
-
300
- # port = validator.metagraph.axons[validator.uid].port
301
  port = 10000
302
  if run_aio_app:
303
  try:
304
  web.run_app(validator_app, port=port, loop=loop)
305
  except KeyboardInterrupt:
306
- bt.logging.info("Keyboard interrupt detected. Exiting validator.")
307
  finally:
308
  pass
309
 
 
2
 
3
 
4
  import os
 
5
  import time
6
  import asyncio
7
  import json
 
8
  import bittensor as bt
9
 
 
 
10
  from neurons.validator import Validator
 
 
 
 
11
  from aiohttp import web
12
  from aiohttp.web_response import Response
13
 
 
35
 
36
  basic testing
37
  ```
38
+ EXPECTED_ACCESS_KEY="hey-michal" python app.py --neuron.model_id mock --wallet.name sn1 --wallet.hotkey v1 --netuid 1 --neuron.tasks math --neuron.task_p 1 --neuron.device cpu
39
  ```
40
  add --mock to test the echo stream
41
  """
 
43
  EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY')
44
 
45
  validator = None
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  async def chat(request: web.Request) -> Response:
49
  """
 
75
  return Response(status=400)
76
 
77
  bt.logging.info(f'Request data: {request_data}')
78
+
79
+ stream = request_data.get('stream', False)
80
+ if stream:
81
+ return stream_response(**request_data)
82
+ else:
83
+ return single_response(**request_data)
84
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
 
88
  async def echo_stream(request):
89
+
90
  bt.logging.info(f'echo_stream()')
91
  # Check access key
92
  access_key = request.headers.get("api_key")
 
103
  bt.logging.info(f'Request data: {request_data}')
104
  k = request_data.get('k', 1)
105
  exclude = request_data.get('exclude', [])
106
+ timeout = request_data.get('timeout', 0.2)
107
  message = '\n\n'.join(request_data['messages'])
108
 
109
  # Create a StreamResponse
 
116
  for word in message.split():
117
  chunk = f'{word} '
118
  await response.write(chunk.encode('utf-8'))
119
+ completion += chunk
120
  time.sleep(timeout)
121
  bt.logging.info(f"Echoed: {chunk}")
122
 
 
134
  "task_name": 'echo',
135
  "ensemble_result": {}
136
  })
137
+
138
  # Send the final JSON as part of the stream
139
  await response.write(f"\n\nJSON_RESPONSE_BEGIN:\n{json_chunk}".encode('utf-8'))
140
 
 
142
  await response.write_eof()
143
  return response
144
 
145
+
146
  class ValidatorApplication(web.Application):
147
  def __init__(self, *a, **kw):
148
  super().__init__(*a, **kw)
 
162
  def main(run_aio_app=True, test=False) -> None:
163
 
164
  loop = asyncio.get_event_loop()
 
 
165
  port = 10000
166
  if run_aio_app:
167
  try:
168
  web.run_app(validator_app, port=port, loop=loop)
169
  except KeyboardInterrupt:
170
+ bt.logging.warning("Keyboard interrupt detected. Exiting validator.")
171
  finally:
172
  pass
173
 
test.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pytest


# Placeholder test suite for the WIP API refactor; each case is a stub to be
# filled in as the corresponding functionality stabilises.


def test_query_network():
    # TODO: exercise the dendrite query path (single_response / stream_response).
    pass

def test_filter_completions():
    # TODO: cover utils.completion_is_valid acceptance/rejection cases.
    pass


def test_guess_task_name():
    # TODO: cover utils.guess_task_name keyword routing for each task type.
    pass

def test_ensemble_completions():
    # TODO: cover utils.ensemble_result for qa/date_qa/math ensembles.
    pass
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import bittensor as bt
3
+
4
+ from collections import Counter
5
+
6
+ from prompting.rewards import DateRewardModel, FloatDiffModel
7
+
8
+ UNSUCCESSFUL_RESPONSE_PATTERNS = ["I'm sorry", "unable to", "I cannot", "I can't", "I am unable", "I am sorry", "I can not", "don't know", "not sure", "don't understand", "not capable"]
9
+
10
+ reward_models = {
11
+ 'date_qa': DateRewardModel(),
12
+ 'math': FloatDiffModel(),
13
+ }
14
+
15
def completion_is_valid(completion: str):
    """
    Return True when ``completion`` looks like a substantive answer.

    A completion is rejected when it is empty/whitespace, contains no word
    characters, or matches (case-insensitively) one of the known refusal
    phrases in UNSUCCESSFUL_RESPONSE_PATTERNS.
    """
    if not completion.strip():
        return False

    # re.escape each phrase so literal characters are never misread as regex
    # metacharacters if the pattern list is extended later; current phrases
    # contain none, so matching behaviour is unchanged.
    patt = re.compile(r'\b(?:' + '|'.join(map(re.escape, UNSUCCESSFUL_RESPONSE_PATTERNS)) + r')\b', re.IGNORECASE)
    if not len(re.findall(r'\w+', completion)) or patt.search(completion):
        return False
    return True
26
+
27
+
28
def ensemble_result(completions: list, task_name: str, prefer: str = 'longest'):
    """
    Ensemble completions from multiple models into a single answer.

    Args:
        completions: candidate completion strings (already validity-filtered).
        task_name: task label; 'date_qa' and 'math' get answer-extraction
            handling, anything else (e.g. 'qa', 'summarization') is generic.
        prefer: tie-break among supporting completions — 'longest',
            'shortest' or 'most_common'.

    Returns:
        dict with the preferred completion, accepted answer (if any), support
        count/indices and selection method, or None when nothing usable exists.

    Raises:
        ValueError: for an unknown ``prefer`` strategy.

    # TODO: Measure agreement
    # TODO: Figure out how to mitigate the cabal effect (large groups will appear to be more credible)
    # TODO: Reward pipeline
    """
    if not completions:
        return None

    answer = None
    if task_name == 'date_qa':
        # Keep only completions containing a parseable date; if multiple dates
        # occur, select the most common one (support > 1 wins outright).
        dates = list(map(reward_models[task_name].parse_dates_from_text, completions))
        bt.logging.info(f"Unprocessed dates: {dates}")
        valid_date_indices = [i for i, d in enumerate(dates) if d]
        valid_completions = [completions[i] for i in valid_date_indices]
        valid_dates = [dates[i] for i in valid_date_indices]
        dates = [f"{d[0].strftime('%-d %B')} {d[1]}" for d in valid_dates]
        if not dates:
            return None

        counter = Counter(dates)
        most_common, count = counter.most_common()[0]
        answer = most_common
        if count == 1:
            supporting_completions = valid_completions
        else:
            supporting_completions = [c for i, c in enumerate(valid_completions) if dates[i] == most_common]

    elif task_name == 'math':
        # Keep only completions with an extractable number; pair each value
        # with its completion so filtering cannot misalign them.
        # TODO: use the median instead of the most common value
        # NOTE(review): falsy values (0, 0.0) are discarded, as in the
        # original `if val` filter — confirm that is intended.
        valued = [(c, reward_models[task_name].extract_number(c)) for c in completions]
        valued = [(c, v) for c, v in valued if v]
        if not valued:
            return None

        # BUG FIX: the original called Counter(dates) here — `dates` is
        # undefined in this branch (NameError). Tally the extracted values.
        most_common, count = Counter(v for _, v in valued).most_common()[0]
        bt.logging.info(f"Most common value: {most_common}, count: {count}")
        answer = most_common
        if count == 1:
            supporting_completions = completions
        else:
            # BUG FIX: the original indexed the *filtered* vals list with
            # positions from the unfiltered completions list, misaligning
            # values with completions; select from the paired list instead.
            supporting_completions = [c for c, v in valued if v == most_common]

    else:
        # 'qa', 'summarization' and any unrecognised task: no special handling
        # (also fixes a NameError the original raised for unknown task names).
        supporting_completions = completions

    bt.logging.info(f"Supporting completions: {supporting_completions}")
    if prefer == 'longest':
        preferred_completion = sorted(supporting_completions, key=len)[-1]
    elif prefer == 'shortest':
        preferred_completion = sorted(supporting_completions, key=len)[0]
    elif prefer == 'most_common':
        preferred_completion = max(set(supporting_completions), key=supporting_completions.count)
    else:
        raise ValueError(f"Unknown ensemble preference: {prefer}")

    return {
        'completion': preferred_completion,
        'accepted_answer': answer,
        'support': len(supporting_completions),
        'support_indices': [completions.index(c) for c in supporting_completions],
        'method': f'Selected the {prefer.replace("_", " ")} completion'
    }
96
+
97
def guess_task_name(challenge: str):
    """
    Heuristically classify a challenge string into a task name.

    Checks keyword patterns for 'summarization', 'date_qa' and 'math' in that
    order and returns the first match; anything else falls back to 'qa'.
    """
    # TODO: use a pre-trained classifier to guess the task name
    # NOTE(review): 'was born?' / 'died?' make the final letter optional in
    # regex terms — presumably a literal '?' was intended; confirm.
    keyword_patterns = (
        ('summarization', 'summar|quick rundown|overview'),
        ('date_qa', 'exact date|tell me when|on what date|on what day|was born?|died?'),
        ('math', 'math|solve|solution| sum |problem|geometric|vector|calculate|degrees|decimal|factorial'),
    )
    for name, pattern in keyword_patterns:
        if re.search(pattern, challenge):
            return name
    return 'qa'