# Spaces: Sleeping
# (Hugging Face Spaces page status header captured during extraction —
#  not part of the script.)
import json
import os
import sqlite3
import sys
import time

import grequests
from tqdm import tqdm

import list_files
import list_repos

# Local SQLite database caching reconstruction terms per repo/file.
SQLITE3_DB = "data/reconstructions.sqlite3"
# Hugging Face Hub endpoint; override with HF_ENDPOINT for mirrors/staging.
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
# Xet content-addressable store serving GET /reconstruction/{xet_hash}.
XET_CAS_ENDPOINT = os.getenv("XET_CAS_ENDPOINT", "https://cas-server.xethub.hf.co")
# Prefix of a repo's resolve route; formatted with the repo id.
RESOLVE_URL_TEMPLATE = HF_ENDPOINT + "/{}/resolve/main"
def exception_handler(req, exc):
    """grequests error hook: report a failed request's exception on stderr."""
    sys.stderr.write("{}\n".format(exc))
def list_reconstructions_from_hub(repo):
    """Fetch xorb reconstruction terms for every LFS file in *repo*.

    For each LFS file, HEAD the Hub's /resolve/ route (without following
    redirects) to read the ``x-xet-hash`` and ``x-xet-access-token``
    headers, then GET the CAS ``/reconstruction/{hash}`` endpoint and
    flatten its term list.

    Args:
        repo: repo identifier; a leading "models/" prefix is stripped for
            URL construction (the Hub's resolve route omits it).

    Returns:
        list of dicts with keys: start, end, file_path, xorb_id,
        unpacked_length.
    """
    print(
        "Listing reconstructions using:\nHF Hub Endpoint: {}\nXet CAS Endpoint: {}".format(
            HF_ENDPOINT, XET_CAS_ENDPOINT
        ),
        file=sys.stderr,
    )
    ret = []
    resolve_reqs = []
    resolve_files = []  # file path per resolve request (parallel to resolve_reqs)
    reconstruct_reqs = []
    reconstruct_files = []  # file path per reconstruct request (parallel list)

    print("Listing files for repo {}".format(repo), file=sys.stderr)
    # List with the original id, but build URLs with the "models/" prefix
    # stripped — the resolve route does not include it. Hoisted out of the
    # loop (loop-invariant; the original mutated `repo` mid-iteration).
    lfs_files = list_files.list_lfs_files(repo)
    if repo.startswith("models/"):
        repo = repo.replace("models/", "", 1)
    headers = {"Authorization": "Bearer {}".format(os.getenv("HF_TOKEN"))}
    for file in tqdm(lfs_files):
        resolve_files.append(file["name"])
        url = file["name"].replace(repo, RESOLVE_URL_TEMPLATE.format(repo), 1)
        resolve_reqs.append(
            grequests.head(url, headers=headers, allow_redirects=False)
        )

    print("", file=sys.stderr)
    print("Calling /resolve/ for repo {}".format(repo), file=sys.stderr)
    for i, resp in tqdm(
        grequests.imap_enumerated(
            resolve_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(resolve_reqs),
    ):
        if resp is None:
            continue
        # TODO: use refresh_route to re-resolve when the access token expires.
        refresh_route = resp.headers.get("x-xet-refresh-route")
        xet_hash = resp.headers.get("x-xet-hash")
        access_token = resp.headers.get("x-xet-access-token")
        if xet_hash:
            url = "{}/reconstruction/{}".format(XET_CAS_ENDPOINT, xet_hash)
            # Record the file path alongside the request so later indexing
            # stays correct even when some files were skipped or failed.
            # (Fixes the old `files[i + err_count]` misassociation bug.)
            reconstruct_files.append(resolve_files[i])
            reconstruct_reqs.append(
                grequests.get(
                    url,
                    headers={"Authorization": "Bearer {}".format(access_token)},
                )
            )

    print("", file=sys.stderr)
    print(
        "Calling /reconstruct/ with grequests for repo {}".format(repo),
        file=sys.stderr,
    )
    for i, resp in tqdm(
        grequests.imap_enumerated(
            reconstruct_reqs, size=4, exception_handler=exception_handler
        ),
        # Progress total is the number of reconstruct requests, not the
        # (possibly larger) number of listed files.
        total=len(reconstruct_reqs),
    ):
        if resp is None or resp.status_code != 200:
            continue
        body = resp.json()
        for term in body["terms"]:
            ret.append(
                {
                    "start": term["range"]["start"],
                    "end": term["range"]["end"],
                    "file_path": reconstruct_files[i],
                    "xorb_id": term["hash"],
                    "unpacked_length": term["unpacked_length"],
                }
            )
    return ret
def list_reconstructions(repos, limit=None):
    """Read cached reconstruction rows for *repos* from the SQLite DB.

    Args:
        repos: iterable of repo identifiers to look up.
        limit: optional per-repo cap on the number of rows returned.

    Returns:
        list of dicts with keys: xorb_id, last_updated_timestamp, repo,
        file_path, unpacked_length, start, end.
    """
    ret = []
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        for repo in repos:
            # Parameterized queries: never splice repo/limit into the SQL
            # string (SQL injection; also breaks on repos containing "'").
            if limit is None:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ?", (repo,)
                )
            else:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ? LIMIT ?",
                    (repo, limit),
                )
            for row in res.fetchall():
                # Column order matches the CREATE TABLE in write_files_to_db:
                # id, xorb_id, last_updated_datetime, repo, file_path,
                # unpacked_length, start, end.
                ret.append(
                    {
                        "xorb_id": row[1],
                        "last_updated_timestamp": row[2],
                        "repo": row[3],
                        "file_path": row[4],
                        "unpacked_length": row[5],
                        "start": row[6],
                        "end": row[7],
                    }
                )
    finally:
        # The original leaked the connection; always release it.
        con.close()
    return ret
def write_files_to_db(repo):
    """Refresh the `reconstructions` table for *repo* from the Hub.

    Ensures the table exists, deletes the repo's stale rows, then inserts
    one row per reconstruction term returned by
    list_reconstructions_from_hub().
    """
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        print("Creating reconstructions table if not exists", file=sys.stderr)
        # "end" is a reserved word in SQLite, so quote it defensively.
        cur.execute(
            "CREATE TABLE IF NOT EXISTS reconstructions ("
            "id INTEGER PRIMARY KEY AUTOINCREMENT, xorb_id TEXT, "
            "last_updated_datetime INTEGER, repo TEXT, file_path TEXT, "
            'unpacked_length INTEGER, start INTEGER, "end" INTEGER)'
        )
        con.commit()
        print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
        # Parameterized: never format user-controlled values into SQL.
        cur.execute("DELETE FROM reconstructions WHERE repo = ?", (repo,))
        con.commit()
        print("Inserting rows from HFFileSystem query", file=sys.stderr)
        # One timestamp for the whole refresh so rows from a single run agree.
        now = int(time.time())
        for reconstruction in list_reconstructions_from_hub(repo):
            cur.execute(
                "INSERT INTO reconstructions VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)",
                (
                    reconstruction["xorb_id"],
                    now,
                    repo,
                    reconstruction["file_path"],
                    reconstruction["unpacked_length"],
                    reconstruction["start"],
                    reconstruction["end"],
                ),
            )
        con.commit()
    finally:
        con.close()
if __name__ == "__main__":
    # Refresh every known repo's reconstruction rows in the local DB,
    # then print a small per-repo sample as pretty JSON on stdout.
    for repo in list_repos.list_repos():
        write_files_to_db(repo)
    print("Done writing to DB. Sample of 5 rows:")
    sample = list_reconstructions(list_repos.list_repos(), limit=5)
    json.dump(sample, sys.stdout, sort_keys=True, indent=4)