File size: 11,731 Bytes
eea405a
 
 
 
2b02896
eea405a
 
cc6d57f
adbb8fc
 
eea405a
 
 
 
34ae673
adbb8fc
 
 
 
9d6f412
34ae673
9d6f412
 
adbb8fc
 
 
34ae673
adbb8fc
 
 
 
 
 
 
 
34ae673
adbb8fc
34ae673
adbb8fc
 
 
 
 
 
 
 
 
34ae673
 
adbb8fc
 
 
34ae673
adbb8fc
34ae673
adbb8fc
 
 
 
 
34ae673
 
 
 
 
adbb8fc
 
 
 
34ae673
adbb8fc
 
 
 
34ae673
 
 
 
adbb8fc
 
 
34ae673
 
 
 
adbb8fc
 
 
34ae673
adbb8fc
 
 
 
34ae673
adbb8fc
34ae673
adbb8fc
 
 
 
34ae673
adbb8fc
 
9d6f412
adbb8fc
 
 
 
 
 
 
34ae673
9d6f412
adbb8fc
34ae673
adbb8fc
 
 
 
34ae673
adbb8fc
 
 
34ae673
 
 
 
 
 
 
 
 
 
 
adbb8fc
 
 
 
 
 
34ae673
 
9d6f412
adbb8fc
34ae673
 
adbb8fc
 
34ae673
 
adbb8fc
 
 
2676db7
34ae673
adbb8fc
34ae673
 
adbb8fc
 
34ae673
adbb8fc
 
34ae673
 
 
adbb8fc
 
 
 
 
 
 
eea405a
34ae673
 
 
 
 
 
 
 
 
eea405a
 
 
 
 
 
 
 
 
 
9d6f412
eea405a
 
 
 
 
9d6f412
eea405a
 
 
 
9d6f412
eea405a
adbb8fc
eea405a
 
 
ed3daa4
 
 
 
 
eea405a
 
 
 
 
 
 
 
 
 
ed3daa4
 
 
add58d5
 
eea405a
 
9d6f412
 
 
 
 
eea405a
 
 
 
2b02896
 
adbb8fc
2b02896
 
 
 
 
 
 
 
b4cf22f
2b02896
 
adbb8fc
add58d5
 
 
b4cf22f
add58d5
9d6f412
 
 
add58d5
9d6f412
add58d5
9d6f412
add58d5
b4cf22f
add58d5
9d6f412
 
 
 
eea405a
9d6f412
 
eea405a
 
 
9d6f412
eea405a
 
adbb8fc
 
b4cf22f
add58d5
eea405a
 
 
b4cf22f
9d6f412
eea405a
9d6f412
eea405a
 
cc6d57f
b4cf22f
535e3c5
 
b4cf22f
 
eea405a
7472693
eea405a
 
 
 
 
b4cf22f
eea405a
9d6f412
2676db7
 
 
adbb8fc
b4cf22f
 
9d6f412
 
 
 
b4cf22f
9d6f412
 
 
 
 
 
 
 
eea405a
 
 
 
b4cf22f
eea405a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# pylint: disable=no-member
import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from collections import defaultdict
import numpy as np

HF_API = HfApi()


def apply_power_scaling(sizes: list, exponent=0.2) -> list:
    """Compress the dynamic range of *sizes* by raising each value to *exponent*.

    A ``None`` entry (file with no reported size) maps to 0 so downstream
    plotting never sees a missing value.
    """
    scaled = []
    for size in sizes:
        scaled.append(0 if size is None else size**exponent)
    return scaled


def count_chunks(sizes: list | int) -> list:
    """Count the number of chunks, which are 64KB each in size; always roundup"""
    if isinstance(sizes, int):
        return int(np.ceil(sizes / 64_000))
    return [int(np.ceil(size / 64_000)) if size is not None else 0 for size in sizes]


def build_hierarchy(siblings: list) -> dict:
    """Builds a hierarchical structure from the list of RepoSibling objects.

    Each path component becomes a nested dict; leaves map the filename to its
    size in bytes (the LFS pointer's size when the file is LFS-tracked,
    otherwise the plain blob size).
    """
    hierarchy = defaultdict(dict)

    for sibling in siblings:
        *directories, filename = sibling.rfilename.split("/")
        # Prefer the LFS size when present — that is the true stored size.
        size = sibling.lfs.size if sibling.lfs else sibling.size

        node = hierarchy
        for directory in directories:
            node = node.setdefault(directory, {})
        node[filename] = size

    return hierarchy


def calculate_directory_sizes(hierarchy: dict) -> int:
    """Recursively calculates the size of each directory as the sum of its contents.

    Annotates every directory dict in-place with a ``"__size__"`` entry and
    returns the total size of *hierarchy*.

    Fixes over the previous version:
    - Idempotent: a cached ``"__size__"`` entry from an earlier pass is skipped
      and refreshed instead of being summed as if it were a file. Previously,
      calling this twice on the same tree (which happened via
      ``visualize_repo_treemap`` + ``flatten_hierarchy``) double-counted totals.
    - Tolerates ``None`` leaf sizes (files with no reported size), counting
      them as 0, consistent with the rest of the file.

    Args:
        hierarchy: Nested dict of {name: size-or-subdict} from build_hierarchy.

    Returns:
        int: Total size in bytes of everything under *hierarchy*.
    """
    total_size = 0

    for key, value in hierarchy.items():
        if key == "__size__":
            # Cached result of a previous pass — recomputed below, never summed.
            continue
        if isinstance(value, dict):
            dir_size = calculate_directory_sizes(value)
            value["__size__"] = dir_size
            total_size += dir_size
        else:
            # Leaf file; a missing size contributes nothing.
            total_size += value if value is not None else 0

    return total_size


def build_full_path(current_parent, key):
    """Join *key* onto *current_parent* with '/', or return *key* alone at the root."""
    if not current_parent:
        return key
    return f"{current_parent}/{key}"


def flatten_hierarchy(hierarchy, root_name="Repository"):
    """Flatten a nested size hierarchy into parallel treemap lists.

    Returns (labels, parents, sizes, ids) suitable for Plotly's treemap, with
    *root_name* as the synthetic root node. Directory nodes are recognized by
    their ``"__size__"`` entry, which is consumed (popped) during flattening.
    """
    labels = []
    parents = []
    sizes = []
    ids = []

    def _emit(label, parent, size, node_id):
        # Append one node across all four parallel lists.
        labels.append(label)
        parents.append(parent)
        sizes.append(size)
        ids.append(node_id)

    def _walk(level, parent_id):
        # Depth-first traversal in dict insertion order.
        for name, node in level.items():
            node_id = build_full_path(parent_id, name)
            if isinstance(node, dict) and "__size__" in node:
                # Directory: emit it, then descend into its contents.
                _emit(name, parent_id, node.pop("__size__"), node_id)
                _walk(node, node_id)
            else:
                # File leaf.
                _emit(name, parent_id, node, node_id)

    # Root node first, sized as the grand total of the whole hierarchy.
    _emit(root_name, "", calculate_directory_sizes(hierarchy), root_name)
    _walk(hierarchy, root_name)

    return labels, parents, sizes, ids


def visualize_repo_treemap(r_info: dict, r_id: str) -> px.treemap:
    """Visualizes the repository as a treemap with directory sizes and human-readable tooltips.

    Args:
        r_info: Repo info object returned by ``HfApi.repo_info`` (must expose
            ``.siblings``).
        r_id: Repository id, used for the title and the treemap root label.

    Returns:
        A Plotly treemap figure.
    """
    siblings = r_info.siblings
    hierarchy = build_hierarchy(siblings)

    # Flatten the hierarchy for Plotly. flatten_hierarchy computes directory
    # sizes itself (via calculate_directory_sizes), so calling
    # calculate_directory_sizes here as well — as the previous version did —
    # double-counted the root total; the redundant call has been removed.
    labels, parents, sizes, ids = flatten_hierarchy(hierarchy, r_id)

    # Power-scale raw byte sizes so small files stay visible in the viz.
    scaled_sizes = apply_power_scaling(sizes)

    # Human-readable size strings for the hover tooltip.
    formatted_sizes = [
        (format_repo_size(size) if size is not None else None) for size in sizes
    ]

    chunks = count_chunks(sizes)

    # Color by scaled size; the root (index 0) is forced out of the scale and
    # painted lightgrey below.
    colors = scaled_sizes[:]
    colors[0] = -1
    max_value = max(scaled_sizes)
    normalized_colors = [value / max_value if value > 0 else 0 for value in colors]

    # Define the colorscale; mimics the plasma scale
    colorscale = [
        [0.0, "#0d0887"],
        [0.5, "#bd3786"],
        [1.0, "#f0f921"],
    ]

    # Create the treemap
    fig = px.treemap(
        names=labels,
        parents=parents,
        values=scaled_sizes,
        color=normalized_colors,
        color_continuous_scale=colorscale,
        title=f"{r_id} by Chunks",
        custom_data=[formatted_sizes, chunks],
        height=1000,
        ids=ids,
    )

    # Force the root node grey; children keep their normalized colors.
    fig.update_traces(marker={"colors": ["lightgrey"] + normalized_colors[1:]})

    # Add subtitle by updating the layout
    fig.update_layout(
        title={
            "text": f"{r_id} file and chunk treemap<br><span style='font-size:14px;'>Color represents size in bytes/chunks.</span>",
            "x": 0.5,
            "xanchor": "center",
        },
        coloraxis_showscale=False,
    )

    # Customize the hover template
    fig.update_traces(
        hovertemplate=(
            "<b>%{label}</b><br>"
            "Size: %{customdata[0]}<br>"
            "# of Chunks: %{customdata[1]}"
        )
    )
    fig.update_traces(root_color="lightgrey")

    return fig


def format_repo_size(r_size: int) -> str:
    """
    Convert a repository size in bytes to a human-readable string with appropriate units.

    Args:
        r_size (int): The size of the repository in bytes.

    Returns:
        str: The formatted size string with appropriate units (B, KB, MB, GB, TB, PB).
    """
    unit_names = ("B", "KB", "MB", "GB", "TB", "PB")
    magnitude = 0
    size = r_size
    # Divide down by 1024 until under one unit or out of unit names.
    while size >= 1024 and magnitude + 1 < len(unit_names):
        size /= 1024
        magnitude += 1
    return f"{size:.2f} {unit_names[magnitude]}"


def repo_files(r_type: str, r_id: str) -> dict:
    """Aggregate per-extension stats for a repo and build its treemap figure.

    Returns a (stats, figure) tuple where stats maps each file extension to
    {"size": total bytes, "chunks": total chunk count, "count": file count}.
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    fig = visualize_repo_treemap(r_info, r_id)
    by_ext = {}
    for sibling in r_info.siblings:
        # Extension = text after the last dot (whole name when there is none).
        ext = sibling.rfilename.rsplit(".", 1)[-1]
        stats = by_ext.get(ext)
        if stats is None:
            by_ext[ext] = {
                "size": sibling.size,
                "chunks": count_chunks(sibling.size),
                "count": 1,
            }
        else:
            stats["size"] += sibling.size
            stats["chunks"] += count_chunks(sibling.size)
            stats["count"] += 1
    return by_ext, fig


def repo_size(r_type: str, r_id: str) -> dict:
    """Fetch per-branch sizes for a repository via the Hub treesize API.

    Returns a dict mapping branch name -> {"size_in_bytes", "size_in_chunks"},
    or {} when branch information is unavailable (gated/restricted repos, or
    the repo cannot be listed at all).
    """
    try:
        r_refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    except RepositoryNotFoundError:
        # NOTE(review): gated repos appear to surface as RepositoryNotFoundError
        # here — confirm against huggingface_hub behavior.
        gr.Warning(f"Repository is gated, branch information for {r_id} not available.")
        return {}
    repo_sizes = {}
    for branch in r_refs.branches:
        try:
            # Hub endpoint reporting the cumulative tree size for one branch.
            response = requests.get(
                f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{branch.name}",
                timeout=1000,
            )
            response = response.json()
        except Exception:
            # Best-effort: a network/JSON failure for one branch falls through
            # as an empty response rather than aborting the whole lookup.
            response = {}
        # If the API flags the repo as restricted/gated, give up entirely.
        if response.get("error") and (
            "restricted" in response.get("error") or "gated" in response.get("error")
        ):
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}
        size = response.get("size")
        if size is not None:
            repo_sizes[branch.name] = {
                "size_in_bytes": size,
                "size_in_chunks": count_chunks(size),
            }

    return repo_sizes


def get_repo_info(r_type, r_id):
    """Gradio callback: fetch branch sizes and per-extension stats for a repo.

    Returns component updates for (results row, file-type dataframe, treemap
    plot, branch row, branch dataframe). All outputs are hidden when the
    repository cannot be found.
    """
    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info, treemap_fig = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        # Hide every output component on failed lookup.
        return (
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
            gr.Plot(visible=False),
            gr.Row(visible=False),
            gr.Dataframe(visible=False),
        )

    # check if repo_sizes is just {}
    if not repo_sizes:
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)
    else:
        # Branch table: one row per branch with formatted size + chunk count.
        r_sizes_df = pd.DataFrame(repo_sizes).T.reset_index(names="branch")
        r_sizes_df["formatted_size"] = r_sizes_df["size_in_bytes"].apply(
            format_repo_size
        )
        # NOTE(review): this positional rename assumes column order
        # [branch, size_in_bytes, size_in_chunks, formatted_size].
        r_sizes_df.columns = ["Branch", "size_in_bytes", "Chunks", "Size"]
        r_sizes_component = gr.Dataframe(
            value=r_sizes_df[["Branch", "Size", "Chunks"]], visible=True
        )
        b_block = gr.Row(visible=True)

    # File-type table: per-extension totals, largest raw size first.
    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    # NOTE(review): same positional-rename pattern as the branch table above.
    rf_sizes_df.columns = ["Extension", "bytes", "Chunks", "Count", "Size"]
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size", "Chunks"]],
            visible=True,
        ),
        # gr.Plot(rf_sizes_plot, visible=True),
        gr.Plot(treemap_fig, visible=True),
        b_block,
        r_sizes_component,
    )


# Top-level Gradio app: a repo search box and type radio feed get_repo_info,
# which fills the (initially hidden) results components below.
with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Chunking Repos")
    gr.Markdown(
        "Search for a model or dataset repository using the autocomplete below, select the repository type, and get back information about the repository's contents including the [number of chunks each file might be split into with Xet backed storage](https://huggingface.co/blog/from-files-to-chunks)."
    )
    with gr.Blocks():
        # repo_id = gr.Textbox(label="Repository ID", placeholder="123456")
        # Autocomplete search over Hub models and datasets.
        repo_id = HuggingfaceHubSearch(
            label="Hub Repository Search (enter user, organization, or repository name to start searching)",
            placeholder="Search for model or dataset repositories on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    with gr.Blocks():
        # Results start hidden; get_repo_info toggles their visibility.
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## Repo Info")
                gr.Markdown(
                    "Hover over any file or directory to see it's size in bytes and total number of chunks required to store it in Xet storage."
                )
                file_info_plot = gr.Plot(visible=False)
                # Branch-size table — only shown when branch info is available.
                with gr.Row(visible=False) as branch_block:
                    with gr.Column():
                        gr.Markdown("### Branch Sizes")
                        gr.Markdown(
                            "The size of each branch in the repository and how many chunks it might need (assuming no dedupe)."
                        )
                        branch_sizes = gr.Dataframe(visible=False)
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### File Sizes")
                        gr.Markdown(
                            "The cumulative size of each filetype in the repository (in the `main` branch) and how many chunks they might need (assuming no dedupe)."
                        )
                        file_info = gr.Dataframe(visible=False)
                    # file_info_plot = gr.Plot(visible=False)

    # Wire the search button to the callback; outputs match get_repo_info's
    # return tuple order.
    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

demo.launch()