File size: 3,701 Bytes
eea405a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# pylint: disable=no-member
from urllib.parse import quote

import gradio as gr
import pandas as pd
import plotly.express as px
import requests
from huggingface_hub import HfApi

# Module-level Hub API client; repo_files() and repo_size() both query through it.
HF_API = HfApi()


def format_repo_size(r_size: int) -> str:
    """Render a byte count as a human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit (PB) is reached, then formatted with two decimals followed by
    the unit, e.g. ``1536`` -> ``"1.50 KB"``.

    Args:
        r_size: Size in bytes.

    Returns:
        The scaled size and its unit suffix, separated by a space.
    """
    suffixes = ("B", "KB", "MB", "GB", "TB", "PB")
    last_index = len(suffixes) - 1

    scaled = r_size
    index = 0
    # Stop at the last suffix so oversized values are still reported in PB.
    while index < last_index and scaled >= 1024:
        scaled = scaled / 1024
        index += 1

    return f"{scaled:.2f} {suffixes[index]}"


def repo_files(r_type: str, r_id: str) -> dict:
    """Aggregate the files of a Hub repository by file extension.

    The repository is fetched with ``HF_API.repo_info(..., files_metadata=True)``
    and every entry of ``siblings`` is bucketed by the text after the last
    ``"."`` of its file name.  A file name without any dot is used as its own
    bucket key.

    Args:
        r_type: Repository type passed to ``repo_info`` as ``repo_type``.
        r_id: Repository ID passed to ``repo_info`` as ``repo_id``.

    Returns:
        Mapping ``{extension: {"size": total_size, "count": number_of_files}}``.
        The inner keys are always inserted in the order ``size``, ``count``.
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    files = {}
    # `siblings` may be missing on the returned info object; treat that as "no files".
    for sibling in r_info.siblings or []:
        # Look only at the last path component so that a dot in a directory
        # name (e.g. "v1.0/LICENSE") is not mistaken for an extension separator.
        # Assumes rfilename uses "/" as its separator - TODO confirm.
        base_name = sibling.rfilename.rsplit("/", 1)[-1]
        ext = base_name.split(".")[-1]
        stats = files.setdefault(ext, {"size": 0, "count": 0})
        # Count a file whose size is not reported as 0 bytes instead of
        # failing on `int + None`.
        stats["size"] += sibling.size or 0
        stats["count"] += 1
    return files


def repo_size(r_type: str, r_id: str) -> dict:
    """Return the size reported by the Hub for every branch of a repository.

    Branches are listed through ``HF_API.list_repo_refs`` and each one is then
    looked up on the Hub ``treesize`` HTTP endpoint.  The per-branch lookup is
    best effort: a branch whose request fails, answers with an HTTP error
    status, or yields a payload without a ``size`` field is left out of the
    result instead of aborting the whole call.

    Args:
        r_type: Repository type; passed to ``list_repo_refs`` and, with an
            ``s`` appended, used as the API path segment.
        r_id: Repository ID.

    Returns:
        Mapping of branch name to the ``size`` value from the endpoint
        (presumably bytes - the caller formats it as a byte count).
    """
    r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    repo_sizes = {}
    for branch in r_refs.branches:
        # Percent-encode the branch name (including "/") so that a name such
        # as "feature/x" stays a single path segment of the URL.
        revision = quote(branch.name, safe="")
        url = f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{revision}"
        try:
            # NOTE(review): a 1000 s timeout is very generous for a UI
            # callback; consider lowering it.
            response = requests.get(url, timeout=1000)
            response.raise_for_status()
            payload = response.json()
        except (requests.RequestException, ValueError):
            # Network failure, HTTP error status, or a body that is not JSON:
            # skip this branch and keep going with the others.
            continue

        # Guard against a JSON body that is not an object (e.g. a list).
        size = payload.get("size") if isinstance(payload, dict) else None
        if size is not None:
            repo_sizes[branch.name] = size
    return repo_sizes


def get_repo_info(r_type: str, r_id: str) -> tuple:
    """Collect file and branch statistics for a repository and build the UI updates.

    Calls ``repo_size`` and ``repo_files`` for the given repository, turns
    their results into two tables (files grouped by extension, sorted by total
    size descending; size per branch) plus a pie chart of bytes per extension.

    Args:
        r_type: Repository type forwarded to ``repo_size`` / ``repo_files``.
        r_id: Repository ID; surrounding whitespace is ignored.

    Returns:
        A 4-tuple of Gradio components, all with ``visible=True``:
        a ``gr.Row``, a ``gr.Dataframe`` with the columns Extension / Count /
        Size, a ``gr.Plot`` with the pie chart, and a ``gr.Dataframe`` with
        the columns Branch / Size.

    Raises:
        gr.Error: If ``r_id`` is empty or only whitespace.
    """
    r_id = r_id.strip() if r_id else ""
    if not r_id:
        # Fail early with a readable message instead of sending an empty ID
        # to the Hub.
        raise gr.Error("Please enter a repository ID.")

    repo_sizes = repo_size(r_type, r_id)
    repo_files_info = repo_files(r_type, r_id)

    # Build the frames from explicit records and column names: this does not
    # depend on the key order of the inner dicts, and the frames keep their
    # columns even when there are no rows.
    rf_sizes_df = pd.DataFrame(
        [
            {"Extension": ext, "bytes": stats["size"], "Count": stats["count"]}
            for ext, stats in repo_files_info.items()
        ],
        columns=["Extension", "bytes", "Count"],
    ).sort_values(by="bytes", ascending=False)
    rf_sizes_df["Size"] = rf_sizes_df["bytes"].apply(format_repo_size)

    r_sizes_df = pd.DataFrame(
        [{"Branch": name, "bytes": size} for name, size in repo_sizes.items()],
        columns=["Branch", "bytes"],
    )
    r_sizes_df["Size"] = r_sizes_df["bytes"].apply(format_repo_size)

    rf_sizes_plot = px.pie(
        rf_sizes_df,
        values="bytes",
        names="Extension",
        hover_data=["Size"],
        title=f"File Distribution in {r_id}",
        hole=0.3,
    )
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size"]],
            visible=True,
        ),
        gr.Plot(rf_sizes_plot, visible=True),
        gr.Dataframe(value=r_sizes_df[["Branch", "Size"]], visible=True),
    )


# Page layout: a search form (repository ID, repository type, button) followed
# by a results area that is created hidden and revealed by get_repo_info.
with gr.Blocks(theme="citrus") as demo:
    gr.Markdown("# Repository Information")
    gr.Markdown(
        "Enter a repository ID and repository type and get back information about the repository's files and branches."
    )
    # Search form.
    with gr.Blocks():
        repo_id = gr.Textbox(label="Repository ID", placeholder="123456")
        repo_type = gr.Radio(
            choices=["model", "dataset", "space"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    # Results area: the row and every component in it start with
    # visible=False; get_repo_info returns replacements with visible=True.
    with gr.Blocks():
        with gr.Row(visible=False) as results:
            with gr.Column():
                gr.Markdown("## File Information")
                with gr.Row():
                    file_info = gr.Dataframe(visible=False)
                    file_info_plot = gr.Plot(visible=False)
                gr.Markdown("## Branch Sizes")
                branch_sizes = gr.Dataframe(visible=False)

    # Inputs are passed as (type, ID) to match get_repo_info's parameter
    # order; the four outputs line up with the four components it returns.
    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results, file_info, file_info_plot, branch_sizes],
    )

# Runs at module level, so the app is launched whenever this file is executed
# or imported.
demo.launch()