import os
import json
import sys
import time
import grequests
import sqlite3
from tqdm import tqdm

import list_files
import list_repos

SQLITE3_DB = "data/reconstructions.sqlite3"
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
XET_CAS_ENDPOINT = os.getenv("XET_CAS_ENDPOINT", "https://cas-server.xethub.hf.co")
RESOLVE_URL_TEMPLATE = HF_ENDPOINT + "/{}/resolve/main"


def exception_handler(req, exc):
    """Report a failed grequests request on stderr.

    The handler returns None, so ``grequests.imap_enumerated`` yields
    ``(index, None)`` for the failed request; callers treat ``None`` as an error.
    """
    print(exc, file=sys.stderr)


def list_reconstructions_from_hub(repo):
    """Fetch Xet reconstruction terms for every LFS file in *repo*.

    For each file from ``list_files.list_lfs_files(repo)``:

    1. Send a HEAD to ``{HF_ENDPOINT}/<repo>/resolve/main/...`` without
       following redirects, and read the ``x-xet-hash`` and
       ``x-xet-access-token`` response headers.
    2. Send a GET to ``{XET_CAS_ENDPOINT}/reconstruction/<xet_hash>``.
    3. Emit one entry per element of the response's ``terms`` list.

    Files whose resolve response has no (or an empty) ``x-xet-hash`` are
    skipped. Failed requests are reported on stderr and skipped.

    Args:
        repo: Repo identifier passed to ``list_files.list_lfs_files``. A leading
            ``models/`` prefix is removed when building resolve URLs.

    Returns:
        A list of dicts with keys ``start``, ``end``, ``file_path``, ``xorb_id``
        and ``unpacked_length``, in request-completion order.
    """
    print(
        "Listing reconstructions using:\nHF Hub Endpoint: {}\nXet CAS Endpoint: {}".format(
            HF_ENDPOINT, XET_CAS_ENDPOINT
        ),
        file=sys.stderr,
    )

    # Model repos are addressed without the "models/" prefix in resolve URLs.
    repo_path = repo[len("models/"):] if repo.startswith("models/") else repo

    # Only authenticate when a token is configured. Previously a missing
    # HF_TOKEN produced the literal header "Bearer None".
    token = os.getenv("HF_TOKEN")
    resolve_headers = {"Authorization": "Bearer {}".format(token)} if token else {}

    files = []
    resolve_reqs = []
    print("Listing files for repo {}".format(repo), file=sys.stderr)
    for file in tqdm(list_files.list_lfs_files(repo)):
        files.append(file["name"])
        url = file["name"].replace(repo_path, RESOLVE_URL_TEMPLATE.format(repo_path), 1)
        resolve_reqs.append(
            grequests.head(url, headers=dict(resolve_headers), allow_redirects=False)
        )
    print("", file=sys.stderr)

    print("Calling /resolve/ for repo {}".format(repo_path), file=sys.stderr)
    err_count = 0
    reconstruct_reqs = []
    # reconstruct_files[k] is the file name that reconstruct_reqs[k] was built for.
    reconstruct_files = []
    for i, resp in tqdm(
        grequests.imap_enumerated(
            resolve_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(resolve_reqs),
    ):
        if resp is None:
            err_count += 1
            continue
        # TODO: use the x-xet-refresh-route header when access_token is expired
        xet_hash = resp.headers.get("x-xet-hash")
        if not xet_hash:
            continue
        access_token = resp.headers.get("x-xet-access-token")
        url = "{}/reconstruction/{}".format(XET_CAS_ENDPOINT, xet_hash)
        headers = {"Authorization": "Bearer {}".format(access_token)}
        reconstruct_reqs.append(grequests.get(url, headers=headers))
        # imap_enumerated yields in completion order; ``i`` is the index into
        # resolve_reqs (and therefore files), so record the file name here
        # rather than reconstructing it from positions later.
        reconstruct_files.append(files[i])
    print("", file=sys.stderr)
    if err_count:
        print("{} /resolve/ request(s) failed".format(err_count), file=sys.stderr)

    print(
        "Calling /reconstruct/ with grequests for repo {}".format(repo_path),
        file=sys.stderr,
    )
    ret = []
    for i, resp in tqdm(
        grequests.imap_enumerated(
            reconstruct_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(reconstruct_reqs),
    ):
        if resp is None:
            continue
        file_path = reconstruct_files[i]
        if resp.status_code != 200:
            print(
                "Reconstruction request for {} failed with HTTP {}".format(
                    file_path, resp.status_code
                ),
                file=sys.stderr,
            )
            continue
        body = resp.json()
        for term in body["terms"]:
            ret.append(
                {
                    "start": term["range"]["start"],
                    "end": term["range"]["end"],
                    "file_path": file_path,
                    "xorb_id": term["hash"],
                    "unpacked_length": term["unpacked_length"],
                }
            )
    return ret


def list_reconstructions(repos, limit=None):
    """Read stored reconstruction rows for each repo in *repos* from SQLITE3_DB.

    Args:
        repos: Iterable of repo identifiers, matched exactly against the
            ``repo`` column.
        limit: Optional maximum number of rows returned *per repo*.

    Returns:
        A list of dicts with keys ``xorb_id``, ``last_updated_timestamp``,
        ``repo``, ``file_path``, ``unpacked_length``, ``start`` and ``end``.
    """
    ret = []
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        for repo in repos:
            # Parameterized so repo names containing quotes cannot break the query.
            if limit is None:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ?", (repo,)
                )
            else:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ? LIMIT ?",
                    (repo, int(limit)),
                )
            # Column order follows the CREATE TABLE in write_files_to_db:
            # id, xorb_id, last_updated_datetime, repo, file_path,
            # unpacked_length, start, end.
            for row in res.fetchall():
                ret.append(
                    {
                        "xorb_id": row[1],
                        "last_updated_timestamp": row[2],
                        "repo": row[3],
                        "file_path": row[4],
                        "unpacked_length": row[5],
                        "start": row[6],
                        "end": row[7],
                    }
                )
    finally:
        con.close()
    return ret


def write_files_to_db(repo):
    """Replace all stored reconstruction rows for *repo* with fresh Hub data.

    The Hub/CAS query runs before any rows are touched. The DELETE and the
    INSERTs then run in one transaction, so a failure anywhere leaves the
    previously stored rows for *repo* intact.
    """
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    try:
        print("Creating reconstructions table if not exists", file=sys.stderr)
        with con:
            con.execute(
                "CREATE TABLE IF NOT EXISTS reconstructions (id INTEGER PRIMARY KEY AUTOINCREMENT, xorb_id TEXT, last_updated_datetime INTEGER, repo TEXT, file_path TEXT, unpacked_length INTEGER, start INTEGER, end INTEGER)"
            )

        reconstructions = list_reconstructions_from_hub(repo)
        now = int(time.time())
        rows = [
            (
                r["xorb_id"],
                now,
                repo,
                r["file_path"],
                r["unpacked_length"],
                r["start"],
                r["end"],
            )
            for r in reconstructions
        ]

        # `with con:` commits on success and rolls back on exception, so the
        # delete and the inserts land together or not at all.
        with con:
            print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
            con.execute("DELETE FROM reconstructions WHERE repo = ?", (repo,))
            print("Inserting rows from HFFileSystem query", file=sys.stderr)
            con.executemany(
                "INSERT INTO reconstructions VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)",
                rows,
            )
    finally:
        con.close()


if __name__ == "__main__":
    # Materialize once so both passes see the same repo list.
    repos = list(list_repos.list_repos())
    for repo in repos:
        write_files_to_db(repo)
    print("Done writing to DB. Sample of 5 rows:")
    json.dump(
        list_reconstructions(repos, limit=5),
        sys.stdout,
        sort_keys=True,
        indent=4,
    )