Spaces:
Sleeping
Sleeping
File size: 4,026 Bytes
17ff0d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import dataclasses
import json
from typing import Dict, List, Optional, Union
from datasets import load_dataset
import sdlm.data.instruction_evals.ifeval_instruction_registry as instructions_registry
@dataclasses.dataclass()
class InputExample:
    """One IFEval evaluation item: a prompt plus the instructions to check.

    The two list fields are parallel: ``kwargs[i]`` is the keyword-argument
    dict passed to ``build_description`` for the instruction whose id is
    ``instruction_id_list[i]``.
    """

    # Value of the "key" field of the source record.
    key: int
    # Instruction ids, looked up in the instruction registry.
    instruction_id_list: List[str]
    # Prompt text; the evaluation functions use it as the lookup key into
    # their ``prompt_to_response`` mapping.
    prompt: str
    # Per-instruction keyword arguments, parallel to instruction_id_list.
    kwargs: List[Dict[str, Union[str, int, None]]]
@dataclasses.dataclass()
class OutputExample:
    """Result of checking one response against its prompt's instructions.

    ``follow_instruction_list[i]`` is the verdict for
    ``instruction_id_list[i]``; ``follow_all_instructions`` is ``True`` only
    when every entry of that list is ``True``.
    """

    # Instruction ids copied from the evaluated InputExample.
    instruction_id_list: List[str]
    # Prompt text copied from the evaluated InputExample.
    prompt: str
    # The response text that was looked up for the prompt.
    response: str
    # Aggregate verdict: all per-instruction checks passed.
    follow_all_instructions: bool
    # Per-instruction verdicts, parallel to instruction_id_list.
    follow_instruction_list: List[bool]
def load_ifeval_prompts(
    path: str = "sdlm/data/instruction_evals/ifeval_input_data.jsonl",
) -> "List[InputExample]":
    """Read IFEval input examples from a JSON Lines file.

    Args:
        path: Path to a ``.jsonl`` file holding one JSON object per line,
            each with ``key``, ``instruction_id_list``, ``prompt`` and
            ``kwargs`` fields. The default is the previously hard-coded
            relative path, resolved against the current working directory.

    Returns:
        One ``InputExample`` per non-blank line, in file order.

    Raises:
        OSError: If ``path`` cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required fields.
    """
    inputs = []
    # Decode explicitly as UTF-8 so results do not depend on the platform's
    # default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. an extra newline at the end of the file)
            # instead of failing inside json.loads.
            if not line.strip():
                continue
            example = json.loads(line)
            inputs.append(
                InputExample(
                    key=example["key"],
                    instruction_id_list=example["instruction_id_list"],
                    prompt=example["prompt"],
                    kwargs=example["kwargs"],
                )
            )
    return inputs
def test_instruction_following_strict(
    inp,
    prompt_to_response,
):
    """Check the unmodified response against each instruction of ``inp``.

    Args:
        inp: The example to evaluate; ``inp.prompt`` is the key used to look
            up its response, and ``inp.kwargs[i]`` configures the instruction
            named by ``inp.instruction_id_list[i]``.
        prompt_to_response: Mapping from prompt text to response text.

    Returns:
        An ``OutputExample`` with one boolean verdict per instruction and
        their conjunction. A whitespace-only response fails every
        instruction.

    Raises:
        KeyError: If ``inp.prompt`` is missing from ``prompt_to_response`` or
            an instruction id is missing from the registry.
    """
    response = prompt_to_response[inp.prompt]

    verdicts = []
    for position, instruction_id in enumerate(inp.instruction_id_list):
        checker_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        checker = checker_cls(instruction_id)
        checker.build_description(**inp.kwargs[position])

        # When the instruction's argument list contains a "prompt" entry,
        # rebuild it with this example's prompt.
        checker_args = checker.get_instruction_args()
        if checker_args and "prompt" in checker_args:
            checker.build_description(prompt=inp.prompt)

        # check_following is only consulted for non-blank responses.
        verdicts.append(bool(response.strip() and checker.check_following(response)))

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(verdicts),
        follow_instruction_list=verdicts,
    )
def test_instruction_following_loose(
    inp,
    prompt_to_response,
):
    """Tests response for an upper bound for following instructions.

    Each instruction is considered followed if ANY of eight relaxed variants
    of the response passes its check. The variants are:

    * the response as-is;
    * the response with every ``*`` removed;
    * the response with its first line, last line, or both dropped (each
      result is stripped of surrounding whitespace);
    * those three trimmed variants with every ``*`` removed.

    Args:
        inp: The example to evaluate; ``inp.prompt`` is the key used to look
            up its response, and ``inp.kwargs[i]`` configures the instruction
            named by ``inp.instruction_id_list[i]``.
        prompt_to_response: Mapping from prompt text to response text.

    Returns:
        An ``OutputExample`` whose ``response`` is the original, unmodified
        response, with one boolean verdict per instruction and their
        conjunction.

    Raises:
        KeyError: If ``inp.prompt`` is missing from ``prompt_to_response`` or
            an instruction id is missing from the registry.
    """
    response = prompt_to_response[inp.prompt]

    lines = response.split("\n")
    without_first = "\n".join(lines[1:]).strip()
    without_last = "\n".join(lines[:-1]).strip()
    without_both = "\n".join(lines[1:-1]).strip()
    # NOTE(review): removing "*" presumably strips markdown emphasis markers
    # that could otherwise defeat a check; confirm against instruction specs.
    candidates = [
        response,
        response.replace("*", ""),
        without_first,
        without_last,
        without_both,
        without_first.replace("*", ""),
        without_last.replace("*", ""),
        without_both.replace("*", ""),
    ]

    is_following_list = []
    for index, instruction_id in enumerate(inp.instruction_id_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)
        instruction.build_description(**inp.kwargs[index])
        # When the instruction's argument list contains a "prompt" entry,
        # rebuild it with this example's prompt.
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=inp.prompt)

        # any() short-circuits on the first passing variant, matching the
        # original flag-and-break loop; blank variants are never checked.
        is_following_list.append(
            any(
                candidate.strip() and instruction.check_following(candidate)
                for candidate in candidates
            )
        )

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )
|