|
import json
import logging
import os
from typing import List, Optional

from huggingface_hub import CommitOperationAdd, Discussion, HfApi, HfFileSystem

from .tasks import TASKS_PRETTY_REVERSE
|
|
|
|
|
class AlreadyExists(Exception):
    """Exception type for a submission that is already present.

    NOTE(review): ``SubmissionUploader`` does not raise this; it reports
    duplicates through the status message it returns instead.
    """
|
|
|
|
|
class SubmissionUploader:
    """Class for adding new files to a dataset on a Hub and opening a PR.

    Heavily influenced by these amazing spaces:
    * https://huggingface.co/spaces/safetensors/convert
    * https://huggingface.co/spaces/gaia-benchmark/leaderboard
    """

    def __init__(self, dataset_id: str):
        """Create Hub clients authenticated with the ``HF_TOKEN`` env variable.

        Args:
            dataset_id: ID of the dataset repository submissions are sent to.

        Raises:
            KeyError: If ``HF_TOKEN`` is not set in the environment.
        """
        token = os.environ["HF_TOKEN"]
        self._api = HfApi(token=token)
        self._fs = HfFileSystem(token=token)
        self._dataset_id = dataset_id

    def _get_previous_pr(self, pr_title: str) -> Optional[Discussion]:
        """Search the dataset repo's discussions for an open PR titled ``pr_title``.

        Returns:
            The first matching discussion, or ``None`` if there is no match or
            the discussions could not be listed (this check is best-effort).
        """
        try:
            # ``get_repo_discussions`` returns an iterator that may issue its
            # requests while being iterated, so the loop itself must be guarded,
            # not just the call.
            for discussion in self._api.get_repo_discussions(
                repo_id=self._dataset_id, repo_type="dataset"
            ):
                if (
                    discussion.status == "open"
                    and discussion.is_pull_request
                    and discussion.title == pr_title
                ):
                    return discussion
        except Exception:
            # Best-effort: failing to list discussions must not block a submission.
            logging.getLogger(__name__).warning(
                "Could not list discussions of %s", self._dataset_id, exc_info=True
            )
        return None

    def _is_already_uploaded(
        self, task_id: str, model_folder: str, filenames: List[str]
    ) -> bool:
        """Check whether every submission file already exists in the dataset repo.

        Looks in ``{task_id}/predictions/{model_folder}`` (the folder
        ``_upload_files`` writes to) for the base name of each file in
        ``filenames`` plus ``metadata.json``.
        """
        predictions_dir = f"datasets/{self._dataset_id}/{task_id}/predictions"
        try:
            # ``detail=False`` yields path strings; they are full repo paths,
            # so compare base names only.
            existing_models = {
                os.path.basename(path)
                for path in self._fs.ls(predictions_dir, detail=False)
            }
            if model_folder not in existing_models:
                return False
            existing_files = {
                os.path.basename(path)
                for path in self._fs.ls(
                    f"{predictions_dir}/{model_folder}", detail=False
                )
            }
        except FileNotFoundError:
            # A missing task/model folder simply means nothing was uploaded yet.
            return False
        expected_files = {os.path.basename(name) for name in filenames}
        expected_files.add("metadata.json")
        return expected_files <= existing_files

    def _upload_files(
        self,
        task_id: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: str,
        context_size: str,
        submitted_by: str,
        filenames: Optional[List[str]],
    ) -> List[CommitOperationAdd]:
        """Build the commit operations for a submission (nothing is uploaded here).

        Each local file in ``filenames`` (``None`` is treated as no files) is
        added under ``{task_id}/predictions/{model_folder}/`` by its base name,
        together with a ``metadata.json`` describing the submission.
        """
        target_dir = f"{task_id}/predictions/{model_folder}"
        commit_operations = [
            CommitOperationAdd(
                path_in_repo=f"{target_dir}/{os.path.basename(filename)}",
                path_or_fileobj=filename,
            )
            for filename in filenames or []
        ]

        metadata_dict = {
            "model_name": model_name_pretty,
            "model_availability": model_availability,
            "urls": urls,
            "context_size": context_size,
            "submitted_by": submitted_by,
        }
        # Pass the metadata as in-memory bytes rather than writing a shared
        # ``metadata.json`` in the working directory, which concurrent
        # submissions could overwrite before it is uploaded.
        commit_operations.append(
            CommitOperationAdd(
                path_in_repo=f"{target_dir}/metadata.json",
                path_or_fileobj=json.dumps(metadata_dict).encode("utf-8"),
            )
        )
        return commit_operations

    def upload_files(
        self,
        task_pretty: str,
        model_folder: str,
        model_name_pretty: str,
        model_availability: str,
        urls: str,
        context_size: str,
        submitted_by: str,
        filenames: Optional[List[str]],
        force: bool = False,
    ) -> str:
        """Open a PR to the dataset repo that adds a new submission.

        Args:
            task_pretty: Display name of the task; mapped to its id through
                ``TASKS_PRETTY_REVERSE``.
            model_folder: Folder under ``{task_id}/predictions/`` to upload into.
            model_name_pretty: Display name of the model.
            model_availability: Stored in the metadata and PR description.
            urls: Stored in the metadata and PR description.
            context_size: Stored in the metadata and PR description.
            submitted_by: Stored in the metadata and PR description.
            filenames: Local paths of the files to upload (``None`` = no files).
            force: Skip the "already present in the dataset" check.

        Returns:
            A human-readable status message: already present, an open PR already
            exists, the URL of the created PR, or a generic error message. Errors
            are logged rather than raised.
        """
        filenames = filenames or []
        try:
            pr_title = f"π New submission to {task_pretty} task: {model_name_pretty} with {context_size} context size from {submitted_by}"
            task_id = TASKS_PRETTY_REVERSE[task_pretty]

            if not force and self._is_already_uploaded(
                task_id, model_folder, filenames
            ):
                return f"{model_name_pretty} is already present in {self._dataset_id}."

            prev_pr = self._get_previous_pr(pr_title)
            if prev_pr is not None:
                url = f"https://huggingface.co/datasets/{self._dataset_id}/discussions/{prev_pr.num}"
                return f"{self._dataset_id} already has an open PR for this submission: {url}."

            commit_operations = self._upload_files(
                task_id=task_id,
                model_folder=model_folder,
                model_name_pretty=model_name_pretty,
                model_availability=model_availability,
                urls=urls,
                context_size=context_size,
                submitted_by=submitted_by,
                filenames=filenames,
            )

            commit_description = (
                f"New submission to {task_pretty} task in ποΈ Long Code Arena benchmark!\n"
                "\n"
                f"* Model name: {model_name_pretty}\n"
                f"* Model availability: {model_availability}\n"
                f"* Context Size: {context_size}\n"
                f"* Relevant URLs: {urls}\n"
                f"* Submitted By: {submitted_by}\n"
            )
            new_pr = self._api.create_commit(
                repo_id=self._dataset_id,
                operations=commit_operations,
                commit_message=pr_title,
                commit_description=commit_description,
                create_pr=True,
                repo_type="dataset",
            )
            return f"π PR created at {new_pr.pr_url}."

        except Exception:
            # Top-level boundary for a UI callback: keep the traceback in the
            # logs but only show a generic message to the user.
            logging.getLogger(__name__).exception(
                "Failed to submit %s to %s", model_name_pretty, self._dataset_id
            )
            return "An exception occurred."
|
|