Spaces:
Running
Running
import gradio as gr | |
import json | |
import posixpath | |
from fastapi import HTTPException, Path, Query, Request | |
from fastapi.responses import StreamingResponse | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from huggingface_hub import HfApi, HfFileSystem | |
from typing import Annotated, Any, NamedTuple | |
from urllib.parse import urlencode | |
from _hf_explorer import FileExplorer | |
from _hf_gguf import standard_metadata, TokenType, LlamaFileType, GGUFValueType, HuggingGGUFstream | |
# Shared Hugging Face Hub API client; used below to list repository refs.
hfapi = HfApi()
class MetadataState(NamedTuple):
    """Per-session editing state for the currently loaded GGUF file."""

    # Miscellaneous session values; the code stores 'repo_file' and 'branch' here.
    var: dict[str, Any]
    # Current metadata, mapping key -> (value type, value), including edits.
    key: dict[str, tuple[int, Any]]
    # Pending additions/changes, mapping key -> (value type, value).
    add: dict[str, Any]
    # Keys scheduled for removal.
    rem: set
def init_state():
    """Return a fresh, empty MetadataState (no variables, keys or pending edits)."""
    # Positional order matches the field order: var, key, add, rem.
    return MetadataState({}, {}, {}, set())
def human_readable_metadata(
    meta: MetadataState,
    key: str,
    typ: int,
    val: Any,
) -> tuple[str, int | str, Any]:
    """Convert one metadata entry into a display-friendly table row.

    Args:
        meta: Current state; its token list is used to resolve ``*_token_id`` values.
        key: Metadata key name.
        typ: Numeric GGUF value type.
        val: Raw metadata value.

    Returns:
        ``(key, type name, display value)``. Lists are shown truncated to their
        first 8 items, dicts are flattened to ``k: v`` pairs, the file type is
        shown by name and token ids are shown together with the token text.
    """
    type_name = GGUFValueType(typ).name
    display = val

    if type_name == 'ARRAY':
        # Nested arrays are not rendered.
        display = '[[...], ...]'
    elif isinstance(val, list):
        type_name = f'[{type_name}][{len(val)}]'

        if len(val) > 8:
            # Drop the closing bracket of the preview and mark the truncation.
            display = str(val[:8])[:-1] + ', ...]'
        else:
            display = str(val)
    elif isinstance(val, dict):
        display = '[' + ', '.join(f'{k}: {v}' for k, v in val.items()) + ']'
    elif key == 'general.file_type':
        try:
            ftype = LlamaFileType(val).name
        except ValueError:
            # Enum lookup of an unrecognised value raises ValueError.
            ftype = 'UNKNOWN'

        display = f'{ftype} ({val})'
    elif key.endswith('_token_id'):
        tokens = meta.key.get('tokenizer.ggml.tokens', (-1, []))[1]

        if isinstance(val, int) and 0 <= val < len(tokens):
            display = f'{tokens[val]} ({val})'

    return key, type_name, display
# NOTE(review): the Row/Column nesting below is reconstructed from the order of
# the statements — confirm against the intended layout.
with gr.Blocks() as blocks:
    # --- Repository selection -------------------------------------------------
    with gr.Row():
        hf_search = HuggingfaceHubSearch(
            label="Search Huggingface Hub",
            placeholder="Search for models on Huggingface",
            search_type="model",
            sumbit_on_select=True,  # keyword name kept exactly as written in this file
            scale=2,
        )

        hf_branch = gr.Dropdown(
            None,
            label="Branch",
            scale=1,
        )

        gr.LoginButton(
            "Sign in to access gated/private repos",
            scale=1,
        )

    hf_file = FileExplorer(visible=False)

    # --- Metadata key / type selection -----------------------------------------
    with gr.Row():
        with gr.Column():
            meta_keys = gr.Dropdown(
                None,
                label="Modify Metadata",
                info="Search by metadata key name",
                allow_custom_value=True,
                visible=False,
            )

        with gr.Column():
            meta_types = gr.Dropdown(
                [e.name for e in GGUFValueType],
                label="Metadata Type",
                info="Select data type",
                type="index",
                visible=False,
            )

        with gr.Column():
            btn_delete = gr.Button(
                "Remove Key",
                variant="stop",
                visible=False,
            )

    # --- Value editors (all start hidden) ---------------------------------------
    meta_boolean = gr.Checkbox(
        label="Boolean",
        info="Click to update value",
        visible=False,
    )

    with gr.Row():
        meta_token_select = gr.Dropdown(
            label="Select token",
            info="Search by token name",
            type="index",
            allow_custom_value=True,
            visible=False,
        )

        meta_token_type = gr.Dropdown(
            [e.name for e in TokenType],
            label="Token type",
            info="Select token type",
            type="index",
            visible=False,
        )

    meta_lookup = gr.Dropdown(
        label="Lookup token",
        info="Search by token name",
        type="index",
        allow_custom_value=True,
        visible=False,
    )

    meta_number = gr.Number(
        info="Enter to update value",
        visible=False,
    )

    meta_string = gr.Textbox(
        info="Enter to update value (Shift+Enter for new line)",
        visible=False,
    )

    meta_array = gr.Matrix(
        None,
        label="Unsupported",
        row_count=(1, "fixed"),
        height="1rem",
        interactive=False,
        visible=False,
    )

    # --- Change summary, download and metadata table ----------------------------
    meta_changes = gr.HighlightedText(
        None,
        label="Metadata Changes",
        color_map={"add": "green", "rem": "red"},
        interactive=False,
        visible=False,
    )

    btn_download = gr.Button(
        "Download GGUF",
        variant="primary",
        visible=False,
    )

    file_meta = gr.Matrix(
        None,
        col_count=(3, "fixed"),
        headers=["Metadata Name", "Type", "Value"],
        datatype=["str", "str", "str"],
        column_widths=["35%", "15%", "50%"],
        wrap=True,
        interactive=False,
        visible=False,
    )

    # --- Session state -----------------------------------------------------------
    meta_state = gr.State()  # init_state
    # BUG: For some reason using gr.State initial value turns tuple to list?
    meta_state.value = init_state()

    # Token ids matching the choices currently offered by the token dropdowns.
    token_select_indices = gr.State([])

    # Components refreshed when a different file is selected ...
    file_change_components = [
        meta_changes,
        file_meta,
        meta_keys,
        btn_download,
    ]

    # ... and those refreshed whenever the metadata state changes.
    state_change_components = [meta_state, *file_change_components]
def get_branches(
    repo: str,
    oauth_token: gr.OAuthToken | None = None,
):
    """List the branches of ``repo`` and populate the branch dropdown.

    "main" is preselected when the repository has such a branch. Any failure
    while listing refs is re-raised as ``gr.Error``.
    """
    # ``False`` (not ``None``) is passed as the token when nobody is signed in.
    auth = oauth_token.token if oauth_token else False

    try:
        refs = hfapi.list_repo_refs(repo, token=auth)
        branches = [ref.name for ref in refs.branches]
    except Exception as e:
        raise gr.Error(e)

    default_branch = "main" if "main" in branches else None

    return {
        hf_branch: gr.Dropdown(
            branches or None,
            value=default_branch,
        ),
    }
def get_files(
    repo: str,
    branch: str | None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Show the GGUF file explorer for ``repo``/``branch`` and hide the editor widgets."""
    auth = oauth_token.token if oauth_token else None

    explorer = FileExplorer(
        "**/*.gguf",
        file_count="single",
        root_dir=repo,
        branch=branch,
        token=auth,
        visible=True,
    )

    return {
        hf_file: explorer,
        # Everything that depends on a loaded file is hidden again.
        meta_changes: gr.HighlightedText(None, visible=False),
        # None, # FIXME (see Dataframe bug in load_metadata): value is not reset here
        file_meta: gr.Matrix(visible=False),
        meta_keys: gr.Dropdown(None, visible=False),
        btn_download: gr.Button(visible=False),
    }
def load_metadata(
    repo_file: str | None,
    branch: str | None,
    progress: gr.Progress = gr.Progress(),
    oauth_token: gr.OAuthToken | None = None,
):
    """Read the metadata of ``repo_file`` and publish it to the UI.

    This is a generator: it first yields a reset state (blank table, hidden
    editor widgets) and, when a file is given, afterwards yields the populated
    state, metadata table and key dropdown. Errors while reading are re-raised
    as ``gr.Error``.
    """
    rows = []
    meta = init_state()

    yield {
        meta_state: meta,
        file_meta: gr.Matrix(
            [['', '', '']] * 100, # FIXME: Workaround for Dataframe bug when user has selected data
            visible=True,
        ),
        meta_changes: gr.HighlightedText(None, visible=False),
        meta_keys: gr.Dropdown(None, visible=False),
        btn_download: gr.Button(visible=False),
    }

    if not repo_file:
        return

    fs = HfFileSystem(token=oauth_token.token if oauth_token else None)

    try:
        progress(0, desc='Loading file...')

        with fs.open(
            repo_file,
            "rb",
            revision=branch,
            block_size=8 * 1024 * 1024,
            cache_type="readahead",
        ) as fp:
            progress(0, desc='Reading header...')

            gguf = HuggingGGUFstream(fp)
            num_metadata = gguf.header['metadata'].value
            entries = gguf.read_metadata()

            meta.var['repo_file'] = repo_file
            meta.var['branch'] = branch

            # Rows for *_token_id keys seen before the token list has been read;
            # they are re-rendered once all metadata is available.
            pending = []

            for name, item in progress.tqdm(entries, desc='Reading metadata...', total=num_metadata, unit=f' of {num_metadata} metadata keys...'):
                row = list(human_readable_metadata(meta, name, item.type, item.value))

                if name.endswith('_token_id') and 'tokenizer.ggml.tokens' not in meta.key:
                    pending.append(((name, item.type, item.value), row))

                rows.append(row)
                meta.key[name] = (item.type, item.value)

                # FIXME: incremental table updates (yielding file_meta per key) are disabled

            for raw, row in pending:
                # In-place update so the list already stored in ``rows`` changes too.
                row[:] = human_readable_metadata(meta, *raw)
    except Exception as e:
        raise gr.Error(e)

    yield {
        meta_state: meta,
        file_meta: gr.Matrix(rows),
        meta_keys: gr.Dropdown(
            sorted(meta.key.keys() | standard_metadata.keys()),
            value='',
            visible=True,
        ),
    }
def update_metakey(
    meta: MetadataState,
    key: str | None,
):
    """Update the type dropdown and delete button for the selected metadata key.

    The type is taken from the loaded metadata or the standard metadata table;
    for unknown keys it is inferred from the key name where possible
    (chat templates are strings, ``*_token_id`` keys are UINT32). The type
    dropdown is only editable when no type could be determined.
    """
    type_name = None
    known = meta.key.get(key, standard_metadata.get(key))

    if known is not None:
        type_name = GGUFValueType(known[0]).name
    elif key:
        if key.startswith('tokenizer.chat_template.'):
            type_name = GGUFValueType.STRING.name
        elif key.endswith('_token_id'):
            type_name = GGUFValueType.UINT32.name

    return {
        meta_types: gr.Dropdown(
            value=type_name,
            interactive=type_name is None,
            visible=bool(key),
        ),
        btn_delete: gr.Button(
            visible=key in meta.key,
        ),
    }
def update_metatype( | |
meta: MetadataState, | |
key: str, | |
typ: int, | |
): | |
val = None | |
tokens = meta.key.get('tokenizer.ggml.tokens', (-1, []))[1] | |
if (data := meta.key.get(key, standard_metadata.get(key))) is not None: | |
typ = data[0] | |
val = data[1] | |
elif not key: | |
typ = None | |
do_select_token = False | |
do_lookup_token = False | |
do_token_type = False | |
match key: | |
case 'tokenizer.ggml.scores': | |
do_select_token = True | |
case 'tokenizer.ggml.token_type': | |
do_select_token = True | |
do_token_type = True | |
case s if s.endswith('_token_id'): | |
do_lookup_token = True | |
case _: | |
pass | |
if isinstance(val, list) and not do_select_token: | |
# TODO: Support arrays? | |
typ = GGUFValueType.ARRAY | |
match typ: | |
case GGUFValueType.INT8 | GGUFValueType.INT16 | GGUFValueType.INT32 | GGUFValueType.INT64 | GGUFValueType.UINT8 | GGUFValueType.UINT16 | GGUFValueType.UINT32 | GGUFValueType.UINT64 | GGUFValueType.FLOAT32 | GGUFValueType.FLOAT64: | |
is_number = True | |
case _: | |
is_number = False | |
return { | |
meta_boolean: gr.Checkbox( | |
value = val if typ == GGUFValueType.BOOL and data is not None else False, | |
visible = True if typ == GGUFValueType.BOOL else False, | |
), | |
meta_token_select: gr.Dropdown( | |
None, | |
value = '', | |
visible = True if do_select_token else False, | |
), | |
meta_token_type: gr.Dropdown( | |
interactive = False, | |
visible = True if do_token_type else False, | |
), | |
meta_lookup: gr.Dropdown( | |
None, | |
value = tokens[val] if is_number and data is not None and do_lookup_token and val < len(tokens) else '', | |
visible = True if is_number and do_lookup_token else False, | |
), | |
meta_number: gr.Number( | |
value = val if is_number and data is not None and not do_select_token else 0, | |
precision = 10 if typ == GGUFValueType.FLOAT32 or typ == GGUFValueType.FLOAT64 else 0, | |
interactive = False if do_select_token else True, | |
visible = True if is_number and not do_token_type else False, | |
), | |
meta_string: gr.Textbox( | |
value = val if typ == GGUFValueType.STRING else '', | |
visible = True if typ == GGUFValueType.STRING else False, | |
), | |
meta_array: gr.Matrix( | |
visible = True if typ == GGUFValueType.ARRAY else False, | |
), | |
} | |
# FIXME: Disabled for now due to Dataframe bug when user has selected data | |
# @gr.on( | |
# triggers = [ | |
# file_meta.select, | |
# ], | |
# inputs = [ | |
# ], | |
# outputs = [ | |
# meta_keys, | |
# ], | |
# ) | |
# def select_metakey( | |
# evt: gr.SelectData, | |
# ): | |
# return { | |
# meta_keys: gr.Dropdown( | |
# value = evt.row_value[0] if evt.selected else '', | |
# ), | |
# } | |
def notify_state_change(
    meta: MetadataState,
    request: gr.Request,
):
    """Rebuild everything derived from the metadata state after an edit.

    Produces the highlighted change summary, the refreshed metadata table and
    key list, and a download link that encodes the pending removals/additions
    as ``rem``/``add`` query parameters.
    """
    changes = [(removed, 'rem') for removed in meta.rem]

    for name, entry in meta.add.items():
        shown = human_readable_metadata(meta, name, *entry)[2]
        changes.append((name, 'add'))
        changes.append((str(shown), None))

    rows = [
        list(human_readable_metadata(meta, name, entry[0], entry[1]))
        for name, entry in meta.key.items()
    ]

    url = request.request.url_for('download', repo_file = meta.var['repo_file'])
    link = str(url.include_query_params(branch = meta.var['branch']))

    # Always hand out an https link, even when the request URL is plain http.
    if link.startswith('http:'):
        link = 'https' + link[4:]

    if meta.rem or meta.add:
        added = [
            json.dumps([name, *entry], ensure_ascii = False, separators = (',', ':'))
            for name, entry in meta.add.items()
        ]
        link += '&' + urlencode(
            {
                'rem': meta.rem,
                'add': added,
            },
            doseq = True,
            safe = '[]{}:"\',',
        )

    has_changes = bool(changes)

    return {
        meta_state: meta,
        meta_changes: gr.HighlightedText(
            changes,
            visible = has_changes,
        ),
        file_meta: gr.Matrix(
            rows,
        ),
        meta_keys: gr.Dropdown(
            sorted(meta.key.keys() | standard_metadata.keys()),
            value = '',
        ),
        btn_download: gr.Button(
            link = link,
            visible = has_changes,
        ),
    }
def rem_metadata(
    meta: MetadataState,
    key: str,
    request: gr.Request,
):
    """Drop ``key`` from the current and pending metadata and mark it for removal."""
    # pop() with a default is a no-op when the key is absent.
    meta.add.pop(key, None)
    meta.key.pop(key, None)
    meta.rem.add(key)

    return notify_state_change(meta, request)
def token_search(
    meta: MetadataState,
    name: str,
    limit: int = 7,
):
    """Find tokens whose text contains ``name`` (case-insensitive).

    Args:
        meta: State holding the token list under 'tokenizer.ggml.tokens'.
        name: Substring to look for; ``None`` is treated as an empty string.
        limit: Maximum number of matches to return (default 7).

    Returns:
        Dict mapping token id to token text for the first ``limit`` matches,
        in token-id order.
    """
    found = {}

    if limit <= 0:
        return found

    needle = (name or '').lower()
    tokens = meta.key.get('tokenizer.ggml.tokens', (-1, []))[1]

    for idx, token in enumerate(tokens):
        if needle in token.lower():
            found[idx] = token

            # Stop scanning as soon as enough matches have been collected.
            if len(found) >= limit:
                break

    return found
def token_select(
    meta: MetadataState,
    keyup: gr.KeyUpData,
):
    """Offer tokens matching the typed text in the token-select dropdown."""
    matches = token_search(meta, keyup.input_value)
    token_ids = list(matches.keys())
    token_texts = list(matches.values())

    return {
        meta_token_select: gr.Dropdown(token_texts),
        # Remember which token id each dropdown choice corresponds to.
        token_select_indices: token_ids,
    }
def token_selected(
    meta: MetadataState,
    key: str,
    choice: int | None,
    indices: list[int],
):
    """Load the per-token value for the chosen token into its editor widget.

    Raises:
        gr.Error: when the choice does not map to a valid token, or ``key`` is
            neither 'tokenizer.ggml.scores' nor 'tokenizer.ggml.token_type'.
    """
    if choice is None or not 0 <= choice < len(indices):
        raise gr.Error('Token not found')

    token = indices[choice]
    if token < 0:
        raise gr.Error('Token not found')

    tokens = meta.key.get('tokenizer.ggml.tokens', (-1, []))[1]
    if token >= len(tokens):
        raise gr.Error('Invalid token')

    data = meta.key.get(key, (-1, []))[1]
    has_value = bool(data) and len(data) > token

    if key == 'tokenizer.ggml.scores':
        return {
            meta_number: gr.Number(
                value = data[token] if has_value else 0.0,
                interactive = True,
            ),
        }

    if key == 'tokenizer.ggml.token_type':
        return {
            meta_token_type: gr.Dropdown(
                value = TokenType(data[token]).name if has_value else TokenType.NORMAL.name,
                interactive = True,
            ),
        }

    raise gr.Error('Invalid metadata key')
def token_lookup(
    meta: MetadataState,
    keyup: gr.KeyUpData,
):
    """Offer tokens matching the typed text in the token-lookup dropdown."""
    matches = token_search(meta, keyup.input_value)
    token_ids = list(matches.keys())
    token_texts = list(matches.values())

    return {
        meta_lookup: gr.Dropdown(token_texts),
        # Remember which token id each dropdown choice corresponds to.
        token_select_indices: token_ids,
    }
def add_metadata( | |
meta: MetadataState, | |
key: str, | |
typ: int | None, | |
val: Any, | |
request: gr.Request, | |
choice: int | None = None, | |
indices: list[int] | None = None, | |
): | |
if not key or typ is None: | |
if key: | |
gr.Warning('Missing required value type') | |
return { | |
meta_changes: gr.HighlightedText( | |
), | |
} | |
if key in meta.rem: | |
meta.rem.remove(key) | |
match key: | |
case 'tokenizer.ggml.scores' | 'tokenizer.ggml.token_type': | |
if choice is None or choice < 0 or choice >= len(indices) or (token := indices[choice]) < 0: | |
raise gr.Error('Token not found') | |
tok = meta.add.setdefault(key, (typ, {}))[1] | |
tok[str(token)] = val + 1 if key == 'tokenizer.ggml.token_type' else val | |
data = meta.key.setdefault(key, (typ, [0.0 if key == 'tokenizer.ggml.scores' else int(TokenType.NORMAL)] * len(meta.key.get('tokenizer.ggml.tokens', (-1, []))[1])))[1] | |
if data: | |
for k, v in tok.items(): | |
data[int(k)] = v | |
case _: | |
meta.key[key] = meta.add[key] = (typ, val) | |
if key.startswith('tokenizer.chat_template.'): | |
template = key[24:] | |
if template not in meta.key.get('tokenizer.chat_templates', []): | |
templates = [x[24:] for x in meta.key.keys() if x.startswith('tokenizer.chat_template.')] | |
meta.key['tokenizer.chat_templates'] = meta.add['tokenizer.chat_templates'] = (GGUFValueType.STRING, templates) | |
return notify_state_change( | |
meta, | |
request, | |
) | |
def token_select_to_id(
    choice: int,
    indices: list[int],
):
    """Put the token id behind the selected lookup choice into the number field.

    Raises:
        gr.Error: when ``choice`` is ``None`` or does not map to a valid token id.
    """
    # ``choice is None`` is checked first, as in token_selected/add_metadata,
    # so a missing selection reports 'Token not found' instead of a TypeError.
    if choice is None or choice < 0 or choice >= len(indices) or (token := indices[choice]) < 0:
        raise gr.Error('Token not found')

    return {
        meta_number: gr.Number(
            token,
        ),
    }
# Event wiring for the value editors. The listeners are registered inside an
# explicit ``with blocks:`` so they attach to the app regardless of where this
# section sits relative to the UI-construction block.
with blocks:
    # Picking a token in the lookup dropdown first copies its id into the
    # number field and, if that succeeded, stores it as the new value.
    meta_lookup.input(
        token_select_to_id,
        inputs = [
            meta_lookup,
            token_select_indices,
        ],
        outputs = [
            meta_number,
        ],
    ).success(
        add_metadata,
        inputs = [
            meta_state,
            meta_keys,
            meta_types,
            meta_number,
        ],
        outputs = list(state_change_components),
    )

    # Per-token editors additionally pass the selected token choice and the
    # token ids behind the offered choices.
    token_inputs = [
        meta_token_select,
        token_select_indices,
    ]

    # (event trigger, value component, extra inputs), in registration order.
    for trigger, value_component, extra_inputs in (
        (meta_boolean.input, meta_boolean, []),
        (meta_token_type.input, meta_token_type, token_inputs),
        (meta_number.submit, meta_number, token_inputs),
        (meta_string.submit, meta_string, []),
        (meta_array.input, meta_array, []),
    ):
        trigger(
            add_metadata,
            inputs = [
                meta_state,
                meta_keys,
                meta_types,
                value_component,
                *extra_inputs,
            ],
            outputs = list(state_change_components),
        )
def stream_repo_file(
    repo_file: str,
    branch: str,
    add_meta: list[str] | None,
    rem_meta: list[str] | None,
    token: str | None = None,
):
    """Stream ``repo_file`` with metadata removals/additions applied.

    This is a generator. The first item yielded is ``gguf.filesize`` (read
    after the edits were applied); every following item is a chunk of bytes:
    the header, each metadata entry, then the rest of the file as read from
    the source.

    Args:
        repo_file: Path of the file, opened through ``HfFileSystem``.
        branch: Revision to read from.
        add_meta: JSON-encoded ``[key, type, value]`` triples to add; a dict
            value holds per-token overrides ``{token id: value}``.
        rem_meta: Metadata keys to remove.
        token: Optional access token.
    """
    fs = HfFileSystem(token = token)

    with fs.open(
        repo_file,
        "rb",
        revision = branch,
        block_size = 8 * 1024 * 1024,
        cache_type = "readahead",
    ) as fp:
        removals = rem_meta or []
        additions = add_meta or []

        gguf = HuggingGGUFstream(fp)

        # Consume the metadata iterator completely before editing.
        for _ in gguf.read_metadata():
            pass

        for name in removals:
            gguf.remove_metadata(name)

        tokens = gguf.metadata.get('tokenizer.ggml.tokens')

        for encoded in additions:
            entry = json.loads(encoded)

            # Anything that is not a [key, type, value] triple is ignored.
            if not (isinstance(entry, list) and len(entry) == 3):
                continue

            if isinstance(entry[2], dict):
                if tokens:
                    # Start from the existing per-token array, or from defaults.
                    if (data := gguf.metadata.get(entry[0])):
                        data = data.value
                    else:
                        data = [0.0 if entry[0] == 'tokenizer.ggml.scores' else int(TokenType.NORMAL)] * len(tokens.value)

                    for index, value in entry[2].items():
                        data[int(index)] = value

                    entry[2] = data
                else:
                    entry[2] = []

            gguf.add_metadata(*entry)

        yield gguf.filesize
        yield b''.join(field.data for _, field in gguf.header.items())

        for _, field in gguf.metadata.items():
            yield field.data

        # Remaining file content is passed through in 64 KiB chunks.
        while (chunk := fp.read(65536)):
            yield chunk
if __name__ == "__main__":
    blocks.queue(max_size = 10, default_concurrency_limit = 10)

    # launch() returns immediately (prevent_thread_lock) so the download route
    # can be added to the app before blocking below.
    app, local_url, share_url = blocks.launch(show_api = False, prevent_thread_lock = True)

    async def download(
        request: Request,
        repo_file: Annotated[str, Path()],
        branch: Annotated[str, Query()] = "main",
        add: Annotated[list[str] | None, Query()] = None,
        rem: Annotated[list[str] | None, Query()] = None,
    ):
        """Stream ``repo_file`` with the requested metadata edits applied.

        Responds with 404 when the path is not normalized, contains a
        backslash, starts with '../' or '/', or has fewer than two '/'.
        """
        token = request.session.get('oauth_info', {}).get('access_token')

        is_invalid = (
            posixpath.normpath(repo_file) != repo_file
            or '\\' in repo_file
            or repo_file.startswith(('../', '/'))
            or repo_file.count('/') < 2
        )

        if is_invalid:
            raise HTTPException(status_code = 404, detail = 'Invalid repository')

        stream = stream_repo_file(repo_file, branch, add, rem, token = token)

        # The generator's first item is the total size; the rest is the body.
        size = next(stream)

        return StreamingResponse(
            stream,
            headers = {'Content-Length': str(size)},
            media_type = 'application/octet-stream',
        )

    app.add_api_route(
        "/download/{repo_file:path}",
        download,
        methods = ["GET"],
    )

    # app.openapi_schema = None
    # app.setup()

    blocks.block_thread()