Spaces:
Running
Running
import gradio as gr | |
import json | |
from difflib import Differ, unified_diff | |
from itertools import groupby | |
from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
from huggingface_hub import HfApi, CommitOperationAdd | |
from transformers import PreTrainedTokenizerBase | |
from enum import StrEnum | |
from copy import deepcopy | |
# Shared Hugging Face Hub API client. Every call made through it in this file
# passes its own `token=` (the signed-in user's OAuth token, or False), so no
# credentials are bound to the client itself.
hfapi: HfApi = HfApi()
class ModelFiles(StrEnum):
    """Repository file names this editor reads and may commit changes to."""

    # Standalone default chat template file, used when present in the repo.
    TOKENIZER_CHAT_TEMPLATE = "tokenizer_chat_template.jinja"
    # Tokenizer config JSON; its "chat_template" key holds the template(s).
    TOKENIZER_CONFIG = "tokenizer_config.json"
    # Standalone inverse template file (the inverse-template editor is hidden).
    TOKENIZER_INVERSE_TEMPLATE = "inverse_template.jinja"
# Labels shown for the "Template Input" examples; entry i labels example_values[i],
# so the two lists must stay in the same order and have the same length.
example_labels = [
    "Single user message",
    "Single user message with system prompt",
    "Longer conversation",
    "Tool call",
    "Tool call with response",
    "Tool call with multiple responses",
    "Tool call with complex tool definition",
    "RAG call",
]
# Tool definition shared by the three weather examples below (template kwargs JSON).
_WEATHER_TOOL_SETTINGS = """{
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and state, e.g. San Francisco, CA"
            },
            "unit": {
              "type": "string",
              "enum": [ "celsius", "fahrenheit" ]
            }
          },
          "required": [ "location" ]
        }
      }
    }
  ]
}"""

# Example inputs for the "Template Input" examples. Each entry is a pair of
# JSON strings:
#   [0] extra template kwargs (e.g. "tools" or "documents"), merged with the
#       model's settings by update_examples;
#   [1] the chat messages list.
# Entry i is labelled by example_labels[i].
example_values = [
    [
        "{}",
        """[
  {
    "role": "user",
    "content": "What is the capital of Norway?"
  }
]""",
    ],
    [
        "{}",
        """[
  {
    "role": "system",
    "content": "You are a somewhat helpful AI."
  },
  {
    "role": "user",
    "content": "What is the capital of Norway?"
  }
]""",
    ],
    [
        "{}",
        """[
  {
    "role": "user",
    "content": "What is the capital of Norway?"
  },
  {
    "role": "assistant",
    "content": "Oslo is the capital of Norway."
  },
  {
    "role": "user",
    "content": "What is the world famous sculpture park there called?"
  },
  {
    "role": "assistant",
    "content": "The world famous sculpture park in Oslo is called Vigelandsparken."
  },
  {
    "role": "user",
    "content": "What is the most famous sculpture in the park?"
  }
]""",
    ],
    [
        _WEATHER_TOOL_SETTINGS,
        """[
  {
    "role": "user",
    "content": "What's the weather like in Oslo?"
  }
]""",
    ],
    [
        _WEATHER_TOOL_SETTINGS,
        """[
  {
    "role": "user",
    "content": "What's the weather like in Oslo?"
  },
  {
    "role": "assistant",
    "content": null,
    "tool_calls": [
      {
        "id": "toolcall1",
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "arguments": {
            "location": "Oslo, Norway",
            "unit": "celsius"
          }
        }
      }
    ]
  },
  {
    "role": "tool",
    "content": "20",
    "tool_call_id": "toolcall1"
  }
]""",
    ],
    [
        _WEATHER_TOOL_SETTINGS,
        """[
  {
    "role": "user",
    "content": "What's the weather like in Oslo and Stockholm?"
  },
  {
    "role": "assistant",
    "content": null,
    "tool_calls": [
      {
        "id": "toolcall1",
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "arguments": {
            "location": "Oslo, Norway",
            "unit": "celsius"
          }
        }
      },
      {
        "id": "toolcall2",
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "arguments": {
            "location": "Stockholm, Sweden",
            "unit": "celsius"
          }
        }
      }
    ]
  },
  {
    "role": "tool",
    "content": "20",
    "tool_call_id": "toolcall1"
  },
  {
    "role": "tool",
    "content": "22",
    "tool_call_id": "toolcall2"
  }
]""",
    ],
    [
        """{
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "create_user",
        "description": "creates a user",
        "parameters": {
          "type": "object",
          "properties": {
            "user": {
              "title": "User",
              "type": "object",
              "properties": {
                "user_id": {
                  "title": "User Id",
                  "description": "The unique identifier for a user",
                  "default": 0,
                  "type": "integer"
                },
                "name": {
                  "title": "Name",
                  "description": "The name of the user",
                  "type": "string"
                },
                "birthday": {
                  "type": "string",
                  "description": "The birthday of the user, e.g. 2022-01-01",
                  "pattern": "^([1-9] |1[0-9]| 2[0-9]|3[0-1])(.|-)([1-9] |1[0-2])(.|-|)20[0-9][0-9]$"
                },
                "email": {
                  "title": "Email",
                  "description": "The email address of the user",
                  "type": "string"
                },
                "friends": {
                  "title": "Friends",
                  "description": "List of friends of the user",
                  "type": "array",
                  "items": {"type": "string"}
                }
              },
              "required": ["name", "email"]
            }
          },
          "required": ["user"],
          "definitions": {}
        }
      }
    }
  ]
}""",
        # The address uses the reserved example.com domain; the scraped copy had an
        # email-obfuscation placeholder here instead of a real address.
        """[
  {
    "role": "user",
    "content": "Create a user for Test User (test.user@example.com), born January 1st 2000, with some random friends."
  }
]""",
    ],
    [
        """{
  "documents": [
    {
      "title": "Much ado about nothing",
      "content": "Dolor sit amet..."
    },
    {
      "title": "Less ado about something",
      "content": "Lorem ipsum..."
    }
  ]
}""",
        """[
  {
    "role": "user",
    "content": "Write a brief summary of the following documents."
  }
]""",
    ],
]
# Initial Markdown body for a commit / pull request description; the empty
# "* " bullet under "Changes" is left for the user to fill in.
pr_description_default = "### Changes\n* \n\n**Updated using [Chat Template Editor](https://huggingface.co/spaces/CISCai/chat-template-editor)**"
class TokenizerConfig(): | |
def __init__(self, tokenizer_config: dict): | |
self._data = deepcopy(tokenizer_config) | |
self.chat_template = self._data.get("chat_template") | |
def chat_template(self) -> str | list | None: | |
templates = [ | |
{ | |
"name": k, | |
"template": v, | |
} | |
for k, v in self.chat_templates.items() if v | |
] | |
if not templates: | |
return None | |
elif len(templates) == 1 and templates[0]["name"] == "default": | |
return templates[0]["template"] | |
else: | |
return templates | |
def chat_template(self, value: str | list | None): | |
if not value: | |
self.chat_templates.clear() | |
elif isinstance(value, str): | |
self.chat_templates = { | |
"default": value, | |
} | |
else: | |
self.chat_templates = { | |
t["name"]: t["template"] | |
for t in value | |
} | |
# @property | |
# def inverse_template(self) -> str | None: | |
# return self._data.get("inverse_template") | |
# @inverse_template.setter | |
# def inverse_template(self, value: str | None): | |
# if value: | |
# self._data["inverse_template"] = value | |
# elif "inverse_template" in self._data: | |
# del self._data["inverse_template"] | |
def json(self, indent: int | str | None = 4) -> str: | |
self._data["chat_template"] = self.chat_template | |
return json.dumps(self._data, ensure_ascii = False, indent = indent) | |
def get_json_indent( | |
json: str, | |
) -> int | str | None: | |
nonl = json.replace('\r', '').replace('\n', '') | |
start = nonl.find("{") | |
first = nonl.find('"') | |
return "\t" if start >= 0 and nonl[start + 1] == "\t" else None if first == json.find('"') else min(max(first - start - 1, 2), 4) | |
def character_diff( | |
diff_title: str | None, | |
str_original: str, | |
str_updated: str, | |
): | |
d = Differ() | |
title = [] if diff_title is None else [ (f"\n@@ {diff_title} @@\n", "@") ] | |
diffs = [ | |
("".join(map(lambda x: x[2:].replace("\t", "\u21e5").replace("\r", "\u240d\r").replace("\n", "\u240a\n") if x[0] != " " else x[2:], tokens)), group if group != " " else None) # .replace(" ", "\u2423") | |
for group, tokens in groupby(d.compare(str_updated, str_original), lambda x: x[0]) | |
] | |
return title + ([("No changes", "?")] if len(diffs) == 1 and diffs[0][1] is None else diffs) | |
# NOTE(review): indentation was lost in this copy; the nesting below is
# reconstructed from the `with` statements and should be confirmed against
# the original file.
# Top-level UI layout. Components are bound to module-level names so the
# event handlers can use them as inputs/outputs and as keys of returned dicts.
with gr.Blocks(
) as blocks:
    # Repository selection row: Hub model search, branch / PR ref, sign-in.
    with gr.Row(
        equal_height = True,
    ):
        hf_search = HuggingfaceHubSearch(
            label = "Search Huggingface Hub",
            placeholder = "Search for models on Huggingface",
            search_type = "model",
            # NOTE(review): "sumbit" spelling left untouched; presumably it matches
            # the component's own keyword name — verify before renaming.
            sumbit_on_select = True,
            scale = 2,
        )
        # Starts empty; choices are filled in per repository.
        hf_branch = gr.Dropdown(
            None,
            label = "Branch",
            scale = 1,
        )
        gr.LoginButton(
            "Sign in for write access or gated/private repos",
            scale = 1,
        )
    gr.Markdown(
        """# Chat Template Editor
Any model repository with chat template(s) is supported (including GGUFs), however do note that all the model info is extracted using the Hugging Face API.
For GGUFs in particular this means that the chat template may deviate from the actual content in any given GGUF file as only the default template from an arbitrary GGUF file is returned.
If you sign in and grant this editor write access you will get the option to create a pull request of your changes (provided you have access to the repository).
You can freely edit and test GGUF chat template(s) (and are encouraged to do so), but you cannot commit any changes, it is recommended to use the [GGUF Editor](https://huggingface.co/spaces/CISCai/gguf-editor) to save the final result to a GGUF.
""",
    )
    # Commit / pull-request panel; hidden until the repo is known to be writable.
    with gr.Accordion("Commit Changes", open = False, visible = False) as pr_group:
        with gr.Tabs() as pr_tabs:
            with gr.Tab("Edit", id = "edit") as pr_edit_tab:
                pr_title = gr.Textbox(
                    placeholder = "Title",
                    show_label = False,
                    max_lines = 1,
                    interactive = True,
                )
                pr_description = gr.Code(
                    label = "Description",
                    language = "markdown",
                    lines = 10,
                    max_lines = 10,
                    interactive = True,
                )
            with gr.Tab("Preview (with diffs)", id = "preview") as pr_preview_tab:
                pr_preview_title = gr.Textbox(
                    show_label = False,
                    max_lines = 1,
                    interactive = False,
                )
                pr_preview_description = gr.Markdown(
                    label = "Description",
                    height = "13rem",
                    container = True,
                )
                # Colours look inverted because the diffs compare new text first:
                # "-" marks edited (new) text, "+" marks original (removed) text.
                pr_preview_diff = gr.HighlightedText(
                    label = "Diff",
                    combine_adjacent = True,
                    color_map = { "+": "red", "-": "green", "@": "blue", "?": "blue" },
                    interactive = False,
                    show_legend = False,
                    show_inline_category = False,
                )
        pr_submit = gr.Button(
            "Create Pull Request",
            variant = "huggingface",
            interactive = False,
        )
        # Disable the button as soon as it is clicked to prevent double submits.
        pr_submit.click(
            lambda: gr.Button(
                interactive = False,
            ),
            outputs = [
                pr_submit,
            ],
            show_api = False,
        )
    with gr.Tabs() as template_tabs:
        with gr.Tab("Edit", id = "edit") as edit_tab:
            with gr.Accordion("Template Input", open = False):
                # Created unrendered so gr.Examples can reference them, then
                # rendered below the examples table.
                chat_settings = gr.Code(
                    label = "Template Settings (kwargs)",
                    language = "json",
                    interactive = True,
                    render = False,
                )
                chat_messages = gr.Code(
                    label = "Template Messages",
                    language = "json",
                    interactive = True,
                    render = False,
                )
                example_input = gr.Examples(
                    examples = example_values,
                    example_labels = example_labels,
                    inputs = [
                        chat_settings,
                        chat_messages,
                    ],
                )
                chat_settings.render()
                chat_messages.render()
            chat_template = gr.Code(
                label = "Chat Template (default)",
                language = "jinja2",
                interactive = True,
            )
            with gr.Accordion("Additional Templates", open = False):
                # Hidden: inverse templates are not loaded or rendered yet.
                inverse_template = gr.Code(
                    label = "Inverse Template",
                    language = "jinja2",
                    interactive = True,
                    visible = False,
                )
                chat_template_tool_use = gr.Code(
                    label = "Chat Template (tool_use)",
                    language = "jinja2",
                    interactive = True,
                )
                chat_template_rag = gr.Code(
                    label = "Chat Template (rag)",
                    language = "jinja2",
                    interactive = True,
                )
        with gr.Tab("Render", id = "render") as render_tab:
            rendered_chat_template = gr.Textbox(
                label = "Chat Prompt (default)",
                interactive = False,
                lines = 20,
                show_copy_button = True,
            )
            with gr.Accordion("Additional Output", open = False):
                rendered_inverse_template = gr.Code(
                    label = "Inverse Chat Messages",
                    language = "json",
                    interactive = False,
                    visible = False,
                )
                rendered_chat_template_tool_use = gr.Textbox(
                    label = "Chat Prompt (tool_use)",
                    interactive = False,
                    lines = 20,
                    show_copy_button = True,
                )
                rendered_chat_template_rag = gr.Textbox(
                    label = "Chat Prompt (rag)",
                    interactive = False,
                    lines = 20,
                    show_copy_button = True,
                )
    # Per-session repo state: flags from the model info, downloaded originals
    # (keyed by ModelFiles) and the "parent_commit" to commit against.
    model_info = gr.State(
        value = {},
    )
def get_branches(
    repo: str,
    oauth_token: gr.OAuthToken | None = None,
):
    """Populate the branch dropdown with the repo's branches and open PR refs.

    Args:
        repo: repository id chosen in the Hub search box (may be empty).
        oauth_token: signed-in user's token, None when not signed in.

    Returns:
        A dict updating ``hf_branch``: the refs found (or None) with "main"
        preselected when present. Lookup failures are reported as a warning
        and keep whatever refs were collected before the failure.
    """
    branches = []
    # Nothing selected yet: just clear the dropdown instead of calling the API.
    if repo:
        token = oauth_token.token if oauth_token else False
        try:
            refs = hfapi.list_repo_refs(
                repo,
                token = token,
            )
            branches = [b.name for b in refs.branches]
            open_prs = hfapi.get_repo_discussions(
                repo,
                discussion_type = "pull_request",
                discussion_status = "open",
                token = token,
            )
            branches += [pr.git_reference for pr in open_prs]
        except Exception as e:
            gr.Warning(
                message = str(e),
                title = "Error listing branches",
            )
    return {
        hf_branch: gr.Dropdown(
            branches or None,
            value = "main" if "main" in branches else None,
        ),
    }
def enable_pr_submit(
    title: str,
):
    """Make the submit button clickable only while a non-empty title is entered."""
    has_title = bool(title)
    return gr.Button(interactive = has_title)
def render_pr_preview(
    info: dict,
    title: str,
    description: str,
    template: str,
    template_tool_use: str,
    template_rag: str,
    template_inverse: str,
):
    """Build the pull-request preview: title, description and highlighted diffs.

    Args:
        info: model_info state holding the downloaded originals.
        title, description: passed through unchanged.
        template, template_tool_use, template_rag: edited templates; these are
            written into tokenizer_config.json (None or "" removes the entry).
        template_inverse: edited inverse template.

    Standalone template files are left untouched when their editor is None
    (never populated), matching what submit_pull_request commits.

    Returns:
        (title, description, changes), where ``changes`` is a list of
        (text, category) pairs for gr.HighlightedText: "-"/"+" for changed
        text, "@" for headers, "?" for notices, None for context.
    """
    info = info or {}
    changes = []
    org_template = ""
    org_template_inverse = ""
    org_template_tool_use = ""
    org_template_rag = ""

    tokenizer_file = info.get(ModelFiles.TOKENIZER_CONFIG, {})
    org_config = tokenizer_file.get("data")
    if org_config:
        tokenizer_config = TokenizerConfig(tokenizer_file.get("content"))
        org_template = tokenizer_config.chat_templates.get("default") or ""
        org_template_tool_use = tokenizer_config.chat_templates.get("tool_use") or ""
        org_template_rag = tokenizer_config.chat_templates.get("rag") or ""
        # Empty/None templates are dropped when the config is serialized.
        tokenizer_config.chat_templates["default"] = template or ""
        tokenizer_config.chat_templates["tool_use"] = template_tool_use or ""
        tokenizer_config.chat_templates["rag"] = template_rag or ""
        new_config = tokenizer_config.json(get_json_indent(org_config))
        if org_config.endswith("\n"):
            new_config += "\n"
        # Compared (new, original) on purpose: the preview colour map shows "-"
        # (new text) in green and "+" (original text) in red.
        diff_lines = unified_diff(
            new_config.splitlines(keepends = True),
            org_config.splitlines(keepends = True),
            fromfile = ModelFiles.TOKENIZER_CONFIG,
            tofile = ModelFiles.TOKENIZER_CONFIG,
        )
        for index, line in enumerate(diff_lines):
            if index < 2 or line.startswith("@@"):
                # "---"/"+++" file headers and "@@" hunk headers are shown verbatim.
                changes.append((line, line[0]))
            else:
                body = line[1:].replace("\t", "\u21e5").replace("\r\n", "\u240d\u240a\r\n").replace("\r", "\u240d\r").replace("\n", "\u240a\n")
                changes.append((body, line[0] if line[0] != " " else None))

    # Standalone files take precedence over the config as the "original" text.
    tokenizer_chat_template = info.get(ModelFiles.TOKENIZER_CHAT_TEMPLATE, {})
    org_template = tokenizer_chat_template.get("data", org_template)
    tokenizer_inverse_template = info.get(ModelFiles.TOKENIZER_INVERSE_TEMPLATE, {})
    org_template_inverse = tokenizer_inverse_template.get("data", org_template_inverse)

    edited_template = org_template if template is None and tokenizer_chat_template else template or ""
    edited_inverse = org_template_inverse if template_inverse is None and tokenizer_inverse_template else template_inverse or ""
    edited_tool_use = template_tool_use or ""
    edited_rag = template_rag or ""

    if org_template or edited_template:
        changes += character_diff(f"Default Template{f' ({ModelFiles.TOKENIZER_CHAT_TEMPLATE})' if tokenizer_chat_template else ''}", org_template, edited_template)
    if org_template_inverse or edited_inverse:
        changes += character_diff(f"Inverse Template{f' ({ModelFiles.TOKENIZER_INVERSE_TEMPLATE})' if tokenizer_inverse_template else ''}", org_template_inverse, edited_inverse)
    if org_template_tool_use or edited_tool_use:
        changes += character_diff("Tool Use Template", org_template_tool_use, edited_tool_use)
    if org_template_rag or edited_rag:
        changes += character_diff("RAG Template", org_template_rag, edited_rag)
    return title, description, changes
def submit_pull_request(
    repo: str,
    branch: str | None,
    info: dict,
    title: str,
    description: str,
    template: str,
    template_tool_use: str,
    template_rag: str,
    template_inverse: str,
    progress = gr.Progress(track_tqdm = True),
    oauth_token: gr.OAuthToken | None = None,
):
    """Commit the edited templates as a new pull request or into an open one.

    Args:
        repo: repository id.
        branch: selected ref. A "refs/pr/N" ref commits into that open PR;
            anything else (including None) opens a new PR against it.
        info: model_info state with the originals downloaded by
            template_data_from_model_files; updated in place after a commit.
        title, description: commit message and description.
        template, template_tool_use, template_rag: edited templates written
            into tokenizer_config.json (None or "" removes that entry).
        template_inverse: edited inverse template.
        progress: Gradio progress tracker (tqdm tracking only).
        oauth_token: signed-in user's token.

    Standalone template files are left untouched when their editor value is
    None, i.e. the editor was never populated.

    Returns:
        gr.skip() when nothing changed or the commit failed (the user is told
        why); otherwise a dict of component updates.
    """
    token = oauth_token.token if oauth_token else False
    # Committing to an open PR ref updates that PR; any other ref gets a new PR.
    pr_branch = branch if branch and branch.startswith("refs/pr/") else None
    operations = []

    tokenizer_file = info.get(ModelFiles.TOKENIZER_CONFIG, {})
    org_config = tokenizer_file.get("data")
    new_config = None
    if org_config:
        tokenizer_config = TokenizerConfig(tokenizer_file.get("content"))
        tokenizer_config.chat_templates["default"] = template
        tokenizer_config.chat_templates["tool_use"] = template_tool_use
        tokenizer_config.chat_templates["rag"] = template_rag
        new_config = tokenizer_config.json(get_json_indent(org_config))
        if org_config.endswith("\n"):
            new_config += "\n"
        if org_config != new_config:
            operations.append(CommitOperationAdd(ModelFiles.TOKENIZER_CONFIG, new_config.encode("utf-8")))

    tokenizer_chat_template = info.get(ModelFiles.TOKENIZER_CHAT_TEMPLATE, {})
    if tokenizer_chat_template and template is not None and template != tokenizer_chat_template.get("data"):
        operations.append(CommitOperationAdd(ModelFiles.TOKENIZER_CHAT_TEMPLATE, template.encode("utf-8")))
    tokenizer_inverse_template = info.get(ModelFiles.TOKENIZER_INVERSE_TEMPLATE, {})
    if tokenizer_inverse_template and template_inverse is not None and template_inverse != tokenizer_inverse_template.get("data"):
        operations.append(CommitOperationAdd(ModelFiles.TOKENIZER_INVERSE_TEMPLATE, template_inverse.encode("utf-8")))

    if not operations:
        gr.Info("No changes to commit...")
        return gr.skip()

    try:
        commit = hfapi.create_commit(
            repo,
            operations,
            revision = branch,
            commit_message = title,
            commit_description = description,
            create_pr = not pr_branch,
            # Guards against committing over changes made since the files were loaded.
            parent_commit = info.get("parent_commit"),
            token = token,
        )
    except Exception as e:
        gr.Warning(
            message = str(e),
            duration = None,
            title = "Error committing changes",
        )
        return gr.skip()

    # Keep the cached originals in step with what was just committed.
    info["parent_commit"] = commit.oid
    if new_config is not None:
        tokenizer_file["data"] = new_config
        tokenizer_file["content"] = json.loads(new_config)
    if tokenizer_chat_template and template is not None:
        tokenizer_chat_template["data"] = template
    if tokenizer_inverse_template and template_inverse is not None:
        tokenizer_inverse_template["data"] = template_inverse

    branches = []
    refreshed = False
    try:
        refs = hfapi.list_repo_refs(
            repo,
            token = token,
        )
        branches = [b.name for b in refs.branches]
        open_prs = hfapi.get_repo_discussions(
            repo,
            discussion_type = "pull_request",
            discussion_status = "open",
            token = token,
        )
        branches += [pr.git_reference for pr in open_prs]
        refreshed = True
    except Exception:
        # Best effort: the commit already succeeded; without a fresh list the
        # branch dropdown is simply left as it is.
        pass

    pr_created = commit.pr_revision if commit.pr_revision in branches else None
    gr.Info(
        message = "Successfully committed changes.",
        title = f"Pull Request {'Created' if pr_created else 'Updated'}",
    )
    return {
        model_info: info,
        hf_branch: gr.skip() if pr_branch or not refreshed else gr.Dropdown(
            branches or None,
            value = pr_created or branch,
        ),
        pr_title: gr.skip() if pr_branch else gr.Textbox(
            value = None,
            placeholder = "Message" if pr_created else "Title",
            label = commit.commit_message if pr_created else None,
            show_label = True if pr_created else False,
        ),
        pr_preview_title: gr.skip() if pr_branch else gr.Textbox(
            label = commit.commit_message if pr_created else None,
            show_label = True if pr_created else False,
        ),
        pr_description: gr.Code(
            value = pr_description_default,
        ),
        pr_submit: gr.skip() if pr_branch else gr.Button(
            value = f"Commit to PR #{commit.pr_num}" if pr_created else "Create Pull Request",
        ),
        pr_tabs: gr.Tabs(
            selected = "edit",
        ),
    }
def switch_to_edit_tabs():
    """Select the "edit" tab in two Tabs components at once (returns a 2-tuple)."""
    return tuple(gr.Tabs(selected = "edit") for _ in range(2))
def switch_to_edit_tab():
    """Select the "edit" tab of a single Tabs component."""
    edit_selection = gr.Tabs(selected = "edit")
    return edit_selection
def template_data_from_model_info(
    repo: str,
    branch: str | None,
    oauth_token: gr.OAuthToken | None = None,
):
    """Fill the editors from the Hub's model info for *repo* at *branch*.

    Returns a 7-tuple matching (model_info, chat_settings, chat_messages,
    chat_template, chat_template_tool_use, chat_template_rag,
    inverse_template):
    - the repo flags (gguf/disabled/gated/private);
    - the template kwargs as JSON;
    - example messages as JSON;
    - the default, tool_use and rag templates, each None when absent;
    - the inverse template, which is always None here.

    On an API error a warning is shown and ({}, None, ..., None) is returned.
    """
    try:
        hub_info = hfapi.model_info(
            repo,
            revision = branch,
            expand = ["config", "disabled", "gated", "gguf", "private", "widgetData"],
            token = oauth_token.token if oauth_token else False,
        )
    except Exception as e:
        gr.Warning(
            message = str(e),
            title = "Error loading model info",
        )
        return {}, None, None, None, None, None, None

    tokenizer_cfg = (hub_info.config or {}).get("tokenizer_config", {})
    # GGUF repos report the template from GGUF metadata instead of the tokenizer config.
    if hub_info.gguf:
        templates = hub_info.gguf.get("chat_template")
    else:
        templates = tokenizer_cfg.get("chat_template")

    repo_flags = {
        "gguf": bool(hub_info.gguf),
        "disabled": hub_info.disabled,
        "gated": hub_info.gated,
        "private": hub_info.private,
    }

    template_kwargs = {
        "add_generation_prompt": True,
        "clean_up_tokenization_spaces": False,
        "bos_token": "<|startoftext|>",
        "eos_token": "<|im_end|>",
    }
    # Every other tokenizer_config entry becomes a kwarg, overriding the defaults above.
    template_kwargs.update(
        (key, value) for key, value in tokenizer_cfg.items() if key != "chat_template"
    )

    # First widget example that carries chat messages, else the built-in example.
    template_messages = next(
        (
            json.dumps(entry["messages"], ensure_ascii = False, indent = 2)
            for entry in (hub_info.widget_data or [])
            if "messages" in entry
        ),
        example_values[0][1],
    )

    template_tool_use = None
    template_rag = None
    template_inverse = None  # inverse templates are not read from model info
    if isinstance(templates, list):
        named = {entry["name"]: entry["template"] for entry in templates}
        template_tool_use = named.get("tool_use")
        template_rag = named.get("rag")
        templates = named.get("default")

    settings_json = json.dumps(template_kwargs, ensure_ascii = False, indent = 2)
    return repo_flags, settings_json, template_messages, templates, template_tool_use, template_rag, template_inverse
def _check_write_access(
    repo: str,
    info: dict,
    token: str | bool,
) -> bool:
    """Return True if committing to *repo* may be attempted; otherwise warn and return False."""
    if info.get("gguf"):
        gr.Warning("Repository contains GGUFs, use GGUF Editor if you want to commit changes...")
        return False
    if info.get("disabled"):
        gr.Warning("Repository is disabled, committing changes is not possible...")
        return False
    gated = info.get("gated")
    private = info.get("private")
    if not gated and not private:
        return True
    # Gated/private repos: confirm the token can actually access the repo.
    try:
        hfapi.auth_check(
            repo,
            token = token,
        )
    except Exception as e:
        if gated:
            gr.Warning(f"Repository is gated with {gated} approval, you must request access to be able to make changes...")
        elif private:
            gr.Warning("Repository is private, you must use proper credentials to be able to make changes...")
        gr.Warning(str(e))
        return False
    return True


def _download_template_files(
    repo: str,
    branch: str | None,
    info: dict,
    token: str | bool,
) -> bool:
    """Download tokenizer_config.json and any standalone template files into *info*.

    All files are read at the first commit returned by list_repo_commits
    (presumably the branch head), which is also stored as
    info["parent_commit"]. On any failure a warning is shown, *info* is left
    unchanged and False is returned.
    """
    optional_files = (
        ModelFiles.TOKENIZER_CHAT_TEMPLATE,
        ModelFiles.TOKENIZER_INVERSE_TEMPLATE,
    )
    try:
        commits = hfapi.list_repo_commits(
            repo,
            revision = branch,
            token = token,
        )
        parent_commit = commits[0].commit_id if commits else None
        # Use one revision for every lookup so the files are mutually consistent.
        revision = parent_commit or branch
        contents = {}
        for filename in (ModelFiles.TOKENIZER_CONFIG, *optional_files):
            # The config's existence was already checked by the caller.
            if filename != ModelFiles.TOKENIZER_CONFIG and not hfapi.file_exists(
                repo,
                filename,
                revision = revision,
                token = token,
            ):
                continue
            local_path = hfapi.hf_hub_download(
                repo,
                filename,
                revision = revision,
                token = token,
            )
            with open(local_path, "r", encoding = "utf-8") as fp:
                contents[filename] = fp.read()
        config_content = json.loads(contents[ModelFiles.TOKENIZER_CONFIG])
    except Exception as e:
        gr.Warning(
            message = str(e),
            title = "Error downloading template files",
        )
        return False
    info["parent_commit"] = parent_commit
    info[ModelFiles.TOKENIZER_CONFIG] = {
        "data": contents[ModelFiles.TOKENIZER_CONFIG],
        "content": config_content,
    }
    for filename in optional_files:
        if filename in contents:
            info[filename] = {
                "data": contents[filename],
            }
    return True


def template_data_from_model_files(
    repo: str,
    branch: str | None,
    info: dict,
    progress = gr.Progress(track_tqdm = True),
    oauth_token: gr.OAuthToken | None = None,
):
    """Prepare the commit/PR controls for *repo* at *branch*.

    For a signed-in user with access, downloads the editable files into
    *info* (see _download_template_files). Then shows the "Commit Changes"
    panel only if that succeeded, and sets its title/button texts. The texts
    depend on whether *branch* is an open PR ref ("refs/pr/N").

    Returns:
        A dict of component updates (model_info, pr_group, pr_title,
        pr_preview_title, pr_description, pr_submit).
    """
    token = oauth_token.token if oauth_token else False
    write_access = False
    if info and oauth_token:
        write_access = _check_write_access(repo, info, token)
    if write_access:
        try:
            write_access = hfapi.file_exists(
                repo,
                ModelFiles.TOKENIZER_CONFIG,
                revision = branch,
                token = token,
            )
        except Exception as e:
            gr.Warning(
                message = str(e),
                title = "Error checking repository files",
            )
            write_access = False
        else:
            if not write_access:
                gr.Warning(f"No {ModelFiles.TOKENIZER_CONFIG} found in repository...")
    if write_access:
        # Without the original files nothing can be diffed or committed, so keep the panel hidden.
        write_access = _download_template_files(repo, branch, info, token)

    pr_details = None
    if branch and branch.startswith("refs/pr/"):
        pr_num = branch.split("/")[-1]
        if pr_num.isdigit():
            try:
                pr_details = hfapi.get_discussion_details(
                    repo,
                    int(pr_num),
                    token = token,
                )
            except Exception as e:
                gr.Warning(
                    message = str(e),
                    title = "Error loading pull request details",
                )
    return {
        model_info: info,
        pr_group: gr.Accordion(
            visible = write_access,
        ),
        pr_title: gr.Textbox(
            value = None,
            placeholder = "Message" if pr_details else "Title",
            label = pr_details.title if pr_details else None,
            show_label = True if pr_details else False,
        ),
        pr_preview_title: gr.Textbox(
            label = pr_details.title if pr_details else None,
            show_label = True if pr_details else False,
        ),
        pr_description: gr.Code(
            value = pr_description_default,
        ),
        pr_submit: gr.Button(
            value = f"Commit to PR #{pr_details.num}" if pr_details else "Create Pull Request",
        ),
    }
def update_examples(
    settings: str,
):
    """Merge the current template settings into every example's kwargs.

    Args:
        settings: JSON object text from the settings editor. None or "" is
            treated as {}, which restores the pristine examples.

    Returns:
        A gr.Dataset update with the merged samples. Returns gr.skip(), leaving
        the examples unchanged, when *settings* is not valid JSON or not a JSON
        object; without this check the event chain would crash.
    """
    try:
        overrides = json.loads(settings) if settings else {}
    except json.JSONDecodeError:
        return gr.skip()
    if not isinstance(overrides, dict):
        return gr.skip()
    examples = []
    for example_settings, *rest in example_values:
        merged = json.loads(example_settings)
        merged.update(overrides)
        examples.append([json.dumps(merged, ensure_ascii = False, indent = 2), *rest])
    return gr.Dataset(
        samples = examples,
    )
# NOTE(review): indentation was lost in this copy. Event listeners must be
# registered inside the `with gr.Blocks() as blocks:` context; confirm that
# this statement is nested there in the original.
# Loading chain, fired when a repo is picked or the branch is changed:
#   1. template_data_from_model_info fills settings, messages and templates;
#   2. only if step 1 raised no error, update_examples merges the new settings
#      into the examples table;
#   3. then (whether or not step 2 succeeded) template_data_from_model_files
#      fetches the editable files and configures the commit/PR panel.
gr.on(
    fn = template_data_from_model_info,
    triggers = [
        hf_search.submit,
        hf_branch.input,
    ],
    inputs = [
        hf_search,
        hf_branch,
    ],
    outputs = [
        model_info,
        chat_settings,
        chat_messages,
        chat_template,
        chat_template_tool_use,
        chat_template_rag,
        inverse_template,
    ],
).success(
    fn = update_examples,
    inputs = [
        chat_settings,
    ],
    outputs = [
        example_input.dataset,
    ],
    show_api = False,
).then(
    fn = template_data_from_model_files,
    inputs = [
        hf_search,
        hf_branch,
        model_info,
    ],
    outputs = [
        model_info,
        pr_group,
        pr_title,
        pr_preview_title,
        pr_description,
        pr_submit,
        # Template editors are not overwritten with the downloaded files (yet).
        # chat_template,
        # inverse_template,
    ],
    show_api = False,
)
def render_chat_templates(
    settings: str,
    messages: str,
    template: str,
    template_tool_use: str | None = None,
    template_rag: str | None = None,
    template_inverse: str | None = None,
):
    """Render the chat templates against the example settings and messages.

    Args:
        settings: JSON object of template kwargs. "tools", "documents" and
            "add_generation_prompt" are passed to rendering. Keys ending in
            "_side"/"_token"/"_tokens", plus "clean_up_tokenization_spaces",
            configure the tokenizer.
        messages: JSON list of chat messages; the first must have a "role".
        template: default template (always rendered).
        template_tool_use, template_rag, template_inverse: rendered only when
            non-empty.

    Returns:
        (default prompt, tool_use prompt, rag prompt, inverse output as JSON).
        An entry is None when its template is empty or rendering failed; a
        persistent warning is shown on failure. Invalid settings or messages
        show a warning and return gr.skip().
    """
    try:
        settings = json.loads(settings) if settings else {}
    except (TypeError, ValueError) as e:
        gr.Warning(
            message = str(e),
            duration = None,
            title = "Template Settings Error",
        )
        return gr.skip()
    try:
        messages = json.loads(messages) if messages else []
    except (TypeError, ValueError) as e:
        gr.Warning(
            message = str(e),
            duration = None,
            title = "Template Messages Error",
        )
        return gr.skip()
    if not isinstance(settings, dict):
        gr.Warning("Invalid Template Settings!")
        return gr.skip()
    if not messages or not isinstance(messages, list) or not isinstance(messages[0], dict) or "role" not in messages[0]:
        gr.Warning("No Template Messages!")
        return gr.skip()

    tools = settings.get("tools")
    documents = settings.get("documents")
    add_generation_prompt = settings.get("add_generation_prompt")
    # Only tokenizer-level options go to the tokenizer constructor.
    tokenizer_settings = {
        k: v
        for k, v in settings.items()
        if k.endswith(("_side", "_token", "_tokens")) or k == "clean_up_tokenization_spaces"
    }
    try:
        tokenizer = PreTrainedTokenizerBase(**tokenizer_settings)
    except Exception as e:
        # User-supplied token values can be rejected by the tokenizer.
        gr.Warning(
            message = str(e),
            duration = None,
            title = "Template Settings Error",
        )
        return gr.skip()

    def attempt(error_title, render):
        """Run *render*; on any exception show a persistent warning and return None."""
        try:
            return render()
        except Exception as e:
            gr.Warning(
                message = str(e),
                duration = None,
                title = error_title,
            )
            return None

    chat_output = attempt(
        "Chat Template Error",
        lambda: tokenizer.apply_chat_template(messages, tools = tools, documents = documents, chat_template = template, add_generation_prompt = add_generation_prompt, tokenize = False),
    )
    chat_tool_use_output = attempt(
        "Tool Use Template Error",
        lambda: tokenizer.apply_chat_template(messages, tools = tools or [], chat_template = template_tool_use, add_generation_prompt = add_generation_prompt, tokenize = False),
    ) if template_tool_use else None
    chat_rag_output = attempt(
        "RAG Template Error",
        lambda: tokenizer.apply_chat_template(messages, documents = documents or [], chat_template = template_rag, add_generation_prompt = add_generation_prompt, tokenize = False),
    ) if template_rag else None
    # NOTE(review): apply_inverse_template is not a documented transformers API;
    # presumably it comes from a patched build — any failure is shown as a warning.
    inverse_output = attempt(
        "Inverse Template Error",
        lambda: tokenizer.apply_inverse_template(messages, inverse_template = template_inverse),
    ) if template_inverse else None
    return chat_output, chat_tool_use_output, chat_rag_output, json.dumps(inverse_output, ensure_ascii = False, indent = 2) if inverse_output is not None else None
if __name__ == "__main__":
    # Enable the request queue (at most 10 waiting, 10 concurrent per event),
    # then serve the app with server-side rendering disabled. queue() returns
    # the Blocks instance, so the two calls are chained.
    blocks.queue(
        max_size = 10,
        default_concurrency_limit = 10,
    ).launch(ssr_mode = False)