|
import json |
|
import os |
|
import glob |
|
import pprint |
|
import re |
|
from datetime import datetime, timezone |
|
|
|
import click |
|
from colorama import Fore |
|
from huggingface_hub import HfApi, snapshot_download |
|
from huggingface_hub.hf_api import ModelInfo |
|
|
|
# Module-wide Hugging Face Hub client, created once at import time.
# add_model_info calls API.model_info(...) on it to fetch per-model metadata.
API = HfApi()
|
|
|
|
|
# Matches a parameter-count token in a lowercased repo name, e.g. "7b", "1.3b", "560m".
_SIZE_PATTERN = re.compile(r"(\d+\.)?\d+(b|m)")


def _safetensors_total(model_info):
    """Return the parameter total recorded in ``model_info.safetensors``, or None.

    Accepts either an object exposing a ``total`` attribute or a mapping with a
    ``"total"`` key; anything else (missing attribute, None, no total) yields None.
    """
    safetensors = getattr(model_info, "safetensors", None)
    if safetensors is None:
        return None
    total = getattr(safetensors, "total", None)
    if total is None:
        try:
            total = safetensors["total"]
        except (KeyError, TypeError):
            total = None
    return total


def _size_from_name(repo_name):
    """Parse a size token such as "7b" or "560m" from *repo_name*; return billions or None."""
    match = _SIZE_PATTERN.search(repo_name)
    if match is None:
        return None
    token = match.group(0)
    value = float(token[:-1])
    # "b" tokens are already in billions; "m" tokens are millions -> divide by 1e3.
    return round(value if token[-1] == "b" else value / 1e3, 3)


def get_model_size(model_info: "ModelInfo", precision: str):
    """Estimate a model's parameter count, in billions.

    The safetensors metadata total is preferred. When it is unavailable, the
    size is parsed from the repository name (e.g. "llama-7b" -> 7.0,
    "bloom-560m" -> 0.56).

    Args:
        model_info: Hub model metadata; ``safetensors`` and ``modelId``
            (falling back to ``id``) are read from it.
        precision: Requested precision; "GPTQ" (or "gptq" in the repo name)
            enables the x8 correction of the metadata count.

    Returns:
        The size in billions rounded to 3 decimals, or 0 when it cannot be
        determined from either source.
    """
    repo_id = getattr(model_info, "modelId", None) or getattr(model_info, "id", None) or ""
    repo_name = repo_id.split("/")[-1].lower()

    total = _safetensors_total(model_info)
    if total is not None:
        model_size = round(total / 1e9, 3)
        # The x8 factor corrects only the safetensors count; presumably GPTQ
        # packs eight 4-bit weights per stored element, so the metadata
        # undercounts. A size parsed from the repo name is already the real
        # parameter count and must NOT be scaled (7B-GPTQ is not 56B).
        if precision == "GPTQ" or "gptq" in repo_name:
            model_size = 8 * model_size
        return model_size

    name_size = _size_from_name(repo_name)
    # 0 marks an unknown model size.
    return 0 if name_size is None else name_size
|
|
|
|
|
def update_request_files(requests_path):
    """Refresh model metadata in every request JSON file under *requests_path*.

    Files matching ``<requests_path>/*/*.json`` (one directory level deep) are
    processed in reverse-sorted path order. Each file's parsed content is passed
    through ``add_model_info``. The file is rewritten as 4-space-indented JSON
    only when the returned content differs from what was read.
    """
    pattern = os.path.join(requests_path, "*/*.json")

    for request_file in sorted(glob.glob(pattern), reverse=True):
        # Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
        with open(request_file, "r", encoding="utf-8") as f:
            req_content = json.load(f)

        new_req_content = add_model_info(req_content)

        # Skip the write when nothing changed, leaving untouched files as they were.
        if new_req_content != req_content:
            with open(request_file, "w", encoding="utf-8") as f:
                f.write(json.dumps(new_req_content, indent=4))
|
|
|
def add_model_info(entry):
    """Return *entry* enriched with metadata fetched from the Hub.

    ``entry["model"]`` at ``entry["revision"]`` is looked up through the shared
    ``API`` client. If the lookup raises, a message is printed and *entry* is
    returned as-is. Otherwise a shallow copy receives:

    * ``params`` - size estimate from ``get_model_size`` at 'float16' precision,
    * ``likes``  - the model's like count,
    * ``license`` - only if the card data yields one (a message is printed otherwise).

    The enriched copy is printed as indented JSON and returned.
    """
    model, revision = entry["model"], entry["revision"]

    try:
        info = API.model_info(repo_id=model, revision=revision)
    except Exception:
        print(f"Could not get model information for {model} revision {revision}")
        return entry

    enriched = entry.copy()
    enriched["params"] = get_model_size(model_info=info, precision='float16')
    enriched["likes"] = info.likes

    # Card data may be absent or lack a license; either way, keep going without one.
    try:
        enriched["license"] = info.cardData["license"]
    except Exception:
        print(f"No license for {model} revision {revision}")

    print(json.dumps(enriched, indent=4))
    return enriched
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # The requests directory may be passed as the first command-line argument.
    # Without one, the original hard-coded location is used, so existing
    # invocations behave exactly as before.
    default_requests_path = "/Volumes/Data-case-sensitive/requests"
    update_request_files(sys.argv[1] if len(sys.argv) > 1 else default_requests_path)
|
|