# gevent monkey-patching must happen at the beginning of the script;
# found on https://stackoverflow.com/a/52130355 to fix infinite recursion with ssl
import gevent.monkey
gevent.monkey.patch_all()

import json
import sqlite3
import sys
import time
from datetime import date, datetime

import huggingface_hub
from tqdm import tqdm

import list_repos

fs = huggingface_hub.HfFileSystem()

SQLITE3_DB = "data/files.sqlite3"
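
# Serializer hook for json.dumps: datetime/date values are not JSON-encodable
# natively, so render them as ISO-8601 strings.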
def json_serial(obj):
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))
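
# Recursively yield one entry dict per file in `repo`, as returned by HfFileSystem.ls.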
def list_files_from_hub(repo, replace_model_in_url=True):
    # Remove models/ from the front of repo, since the "default" type of repo
    # is a model: the underlying implementation of fs.ls appends repo as
    # /api/models/<repo>.
    if replace_model_in_url and repo.startswith("models/"):
        repo = repo.replace("models/", "", 1)
    # Implement our own recursive listing: it makes multiple requests, one per
    # ls, each of which is much more likely to succeed. Passing recursive=True
    # (which is undocumented anyway) does it in one request, which really slams
    # the server and can return a 500 error due to hitting some backend timeout.
    items = fs.ls(repo)
    for item in items:
        if item["type"] == "directory":
            yield from list_files_from_hub(item["name"], replace_model_in_url=False)
        else:
            yield item
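
# Refresh the `files` table for a single repo: drop any rows previously stored
# for that repo, then insert one row per file returned by the Hub listing above.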
def write_files_to_db(repo):
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    cur = con.cursor()
    print("Creating files table if not exists", file=sys.stderr)
    cur.execute(
        "CREATE TABLE IF NOT EXISTS files ("
        "name TEXT PRIMARY KEY, last_updated_datetime INTEGER, repo TEXT, "
        "size INTEGER, type TEXT, blob_id TEXT, "
        "is_lfs INTEGER, lfs_size INTEGER, lfs_sha256 TEXT, lfs_pointer_size INTEGER, "
        "last_commit_oid TEXT, last_commit_title TEXT, last_commit_date TEXT)"
    )
    con.commit()
print("Deleting existing rows for repo {}".format(repo), file=sys.stderr) | |
cur.execute("DELETE FROM files WHERE repo = '{}'".format(repo)) | |
con.commit() | |
print("Inserting new rows from HFFileSystem query for repo {}".format(repo), file=sys.stderr) | |
    for file in tqdm(list_files_from_hub(repo)):
        is_lfs = file["lfs"] is not None
        # Use a parameterized INSERT. Building this statement with str.format broke
        # whenever a value contained a single quote (e.g. a commit title like "don't"),
        # which surfaced as: sqlite3.OperationalError: near "t": syntax error
        cur.execute(
            "INSERT INTO files VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                file["name"],
                int(time.time()),
                repo,
                file["size"],
                file["type"],
                file["blob_id"],
                1 if is_lfs else 0,
                file["lfs"]["size"] if is_lfs else None,
                file["lfs"]["sha256"] if is_lfs else None,
                file["lfs"]["pointer_size"] if is_lfs else None,
                file["last_commit"]["oid"],
                file["last_commit"]["title"],
                # store the commit date as text, matching the TEXT column
                str(file["last_commit"]["date"]),
            ),
        )
    con.commit()

def is_lfs(file):
    return file["lfs"] is not None

def list_lfs_files(repo):
    for file in list_files(repo):
        if is_lfs(file):
            yield file
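
# Read the rows for one repo back out of the `files` table and rebuild dicts shaped
# like the Hub listing. Index positions follow the column order of the CREATE TABLE
# statement above; row[2] (the repo) is skipped because the caller already knows it.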
def list_files(repo, limit=None):
    con = sqlite3.connect(SQLITE3_DB)
    cur = con.cursor()
    if limit is None:
        res = cur.execute("SELECT * FROM files WHERE repo = ?", (repo,))
    else:
        res = cur.execute("SELECT * FROM files WHERE repo = ? LIMIT ?", (repo, limit))
    ret = [
        {
            "name": row[0],
            "last_updated_datetime": row[1],
            "size": row[3],
            "type": row[4],
            "blob_id": row[5],
            "lfs": (
                {"size": row[7], "sha256": row[8], "pointer_size": row[9]}
                if row[6]
                else None
            ),
            "last_commit": {"oid": row[10], "title": row[11], "date": row[12]},
        }
        for row in res.fetchall()
    ]
    return ret

if __name__ == "__main__":
    for repo in list_repos.list_repos():
        write_files_to_db(repo)
    print("Done writing to DB. Sample of up to 9 rows:")
    for repo in list_repos.list_repos(limit=3):
        for file in list_files(repo, limit=3):
            print(json.dumps(file, default=json_serial))
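
# Example (hypothetical usage, not exercised above): iterate over only the
# LFS-tracked files recorded for each repo.
#
#   for repo in list_repos.list_repos():
#       for file in list_lfs_files(repo):
#           print(file["name"], file["lfs"]["size"])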