Spaces:
Build error
Build error
File size: 3,186 Bytes
51ff9e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import argparse
import json
import logging
import unidiff
from evaluation.benchmarks.swe_bench.resource.swt_bench_constants import (
MAP_VERSION_TO_INSTALL,
)
_LOGGER = logging.getLogger(__name__)
def remove_setup_files(
    model_patch: str, instance: dict, delete_setup_changes: bool
) -> str:
    """Drop unwanted file changes from a model-generated patch.

    Two kinds of per-file diffs are removed:

    * changes to setup files (``setup.py``, ``tox.ini``, ``pyproject.toml``)
      that the instance's ``pre_install`` commands rewrite via ``sed`` --
      only when ``delete_setup_changes`` is true;
    * changes whose target path contains exactly one ``/`` (with git-style
      paths such as ``b/repro.py`` this is a file at the repository root,
      i.e. a top-level reproduction script) -- always.

    Args:
        model_patch: The unified diff produced by the model.
        instance: Benchmark instance; ``instance['repo']`` and
            ``instance['version']`` select the ``pre_install`` commands in
            ``MAP_VERSION_TO_INSTALL``.
        delete_setup_changes: Whether changes to sed-modified setup files
            should be discarded as well.

    Returns:
        The filtered patch rendered back to text.

    Raises:
        unidiff.UnidiffParseError: If the patch cannot be parsed even after
            padding it with trailing newlines.
    """
    setup_files = ['setup.py', 'tox.ini', 'pyproject.toml']
    pre_install = (
        MAP_VERSION_TO_INSTALL.get(instance['repo'], {})
        .get(instance['version'], {})
        .get('pre_install', [])
    )
    # A setup file is only relevant if some pre_install command both mentions
    # it and uses sed, i.e. the environment setup itself edits that file.
    relevant_files = (
        [
            file
            for file in setup_files
            if any(file in install and 'sed' in install for install in pre_install)
        ]
        if delete_setup_changes
        else []
    )

    # Apparently outputs.jsonl has .strip() applied, so we try to reconstruct
    # the original patch by appending auxiliary trailing newlines.
    for padding in range(10):
        try:
            patch = unidiff.PatchSet(model_patch + padding * '\n')
            break
        except unidiff.UnidiffParseError as exc:
            last_error = exc
    else:
        # Every attempt failed: surface the real parse error instead of
        # falling through to an UnboundLocalError on `patch`.
        raise last_error

    to_delete = [
        index
        for index, file in enumerate(patch)
        if any(f in file.source_file for f in relevant_files)
        or file.target_file.count('/') == 1
    ]
    # Delete from the back so earlier indices stay valid.
    for index in reversed(to_delete):
        del patch[index]
    return str(patch)
def main(
    prediction_file: str,
):
    """Convert an OpenHands prediction file into SWT-Bench prediction lines.

    Reads ``prediction_file`` as JSON Lines and prints one JSON object per
    usable prediction to stdout, with the keys ``instance_id``,
    ``model_name_or_path``, ``model_patch`` and ``full_output``.

    Predictions without ``test_result.git_patch`` are skipped with a warning.
    If the patch cannot be filtered by ``remove_setup_files`` a warning is
    logged and the unfiltered patch is emitted instead (best effort).

    Args:
        prediction_file: Path to the OpenHands ``outputs.jsonl`` file.
    """
    with open(prediction_file, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            pred = json.loads(line)
            try:
                git_diff = pred['test_result']['git_patch']
            except KeyError:
                _LOGGER.warning(
                    'Warning: No git diff found for instance %s', pred['instance_id']
                )
                continue

            metadata = pred['metadata']
            # `details` may be absent or null; treat both as "no mode set".
            details = metadata.get('details') or {}
            ci_mode = details.get('mode', '') == 'swt-ci'
            try:
                git_diff = remove_setup_files(git_diff, pred['instance'], ci_mode)
            except Exception:
                # Deliberately best effort: keep the raw diff. `Exception`
                # (not a bare except) lets Ctrl-C / SystemExit propagate.
                _LOGGER.warning(
                    'Warning: Invalid git diff found for instance %s',
                    pred['instance_id'],
                )

            llm_config = metadata['llm_config']
            model_name = (
                f'{llm_config["openrouter_app_name"]}'
                f'__{metadata["agent_class"]}'
                f'__{llm_config["model"]}'
            )
            print(
                json.dumps(
                    {
                        'instance_id': pred['instance_id'],
                        'model_name_or_path': model_name,
                        'model_patch': git_diff,
                        'full_output': json.dumps(pred),
                    }
                )
            )
if __name__ == '__main__':
    # Script entry point: take the prediction file path from the command
    # line and hand it to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--prediction_file',
        required=True,
        type=str,
        help='Path to the prediction file (.../outputs.jsonl)',
    )
    main(cli.parse_args().prediction_file)
|