Spaces:
Build error
Build error
import argparse | |
import json | |
import logging | |
import unidiff | |
from evaluation.benchmarks.swe_bench.resource.swt_bench_constants import ( | |
MAP_VERSION_TO_INSTALL, | |
) | |
_LOGGER = logging.getLogger(__name__) | |
def remove_setup_files(
    model_patch: str, instance: dict, delete_setup_changes: bool
) -> str:
    """Strip unwanted file changes from a model-generated patch.

    Two kinds of per-file diffs are discarded:

    * changes to setup files (``setup.py``, ``tox.ini``, ``pyproject.toml``)
      that the instance's ``pre_install`` commands rewrite via ``sed`` --
      only when ``delete_setup_changes`` is true;
    * changes whose target path contains exactly one ``/`` (treated as
      top-level reproduction scripts) -- always.

    Args:
        model_patch: Unified diff text produced by the model.
        instance: Instance record; its ``'repo'`` and ``'version'`` entries
            are used to look up the ``pre_install`` commands.
        delete_setup_changes: Whether changes to sed-modified setup files
            should be dropped as well.

    Returns:
        The filtered patch rendered back to text.

    Raises:
        unidiff.UnidiffParseError: If the patch cannot be parsed even after
            padding it with trailing newlines.
    """
    setup_files = ['setup.py', 'tox.ini', 'pyproject.toml']
    pre_install = (
        MAP_VERSION_TO_INSTALL.get(instance['repo'], {})
        .get(instance['version'], {})
        .get('pre_install', [])
    )
    # A setup file is only relevant when some pre_install command both
    # mentions it and invokes sed.
    relevant_files = (
        [
            file
            for file in setup_files
            if any(file in install and 'sed' in install for install in pre_install)
        ]
        if delete_setup_changes
        else []
    )

    # Apparently outputs.jsonl has .strip() applied, so we try to reconstruct
    # the original patch by appending 0..9 auxiliary newlines until it parses.
    patch = None
    parse_error = None
    for padding in range(10):
        try:
            patch = unidiff.PatchSet(model_patch + padding * '\n')
            break
        except unidiff.UnidiffParseError as err:
            parse_error = err
    if patch is None:
        # Surface the real parse failure rather than an UnboundLocalError on
        # the never-assigned `patch` variable.
        raise parse_error

    to_delete = [
        index
        for index, file in enumerate(patch)
        if any(f in file.source_file for f in relevant_files)
        or file.target_file.count('/') == 1
    ]
    # Delete from the back so the remaining indices stay valid.
    for index in reversed(to_delete):
        del patch[index]
    return str(patch)
def main(
    prediction_file: str,
):
    """Convert an OpenHands prediction file into SWT-Bench prediction lines.

    Reads ``prediction_file`` as JSON Lines and prints one JSON object per
    prediction to stdout with the keys ``instance_id``,
    ``model_name_or_path``, ``model_patch`` and ``full_output``.

    Predictions without ``test_result.git_patch`` are skipped with a warning.
    If the patch cannot be filtered by ``remove_setup_files``, a warning is
    logged and the unfiltered patch is emitted instead.

    Args:
        prediction_file: Path to the OpenHands ``outputs.jsonl`` file.
    """
    with open(prediction_file, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            pred = json.loads(line)
            try:
                git_diff = pred['test_result']['git_patch']
            except KeyError:
                _LOGGER.warning(
                    'Warning: No git diff found for instance %s', pred['instance_id']
                )
                continue
            metadata = pred['metadata']
            # Setup-file changes are only stripped in the 'swt-ci' mode.
            ci_mode = metadata['details'].get('mode', '') == 'swt-ci'
            try:
                git_diff = remove_setup_files(git_diff, pred['instance'], ci_mode)
            except Exception:
                # Best effort: keep the original diff, but do not swallow
                # KeyboardInterrupt / SystemExit as a bare `except:` would.
                _LOGGER.warning(
                    'Warning: Invalid git diff found for instance %s',
                    pred['instance_id'],
                )
            llm_config = metadata['llm_config']
            model_name = (
                f'{llm_config["openrouter_app_name"]}'
                f'__{metadata["agent_class"]}'
                f'__{llm_config["model"]}'
            )
            print(
                json.dumps(
                    {
                        'instance_id': pred['instance_id'],
                        'model_name_or_path': model_name,
                        'model_patch': git_diff,
                        'full_output': json.dumps(pred),
                    }
                )
            )
if __name__ == '__main__':
    # Script entry point: take the prediction file path from the command
    # line and hand it to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--prediction_file',
        type=str,
        required=True,
        help='Path to the prediction file (.../outputs.jsonl)',
    )
    cli_args = cli.parse_args()
    main(cli_args.prediction_file)