Spaces:
Runtime error
Runtime error
"""Module which defines the code for the "Manage models" tab.""" | |
from collections.abc import Sequence | |
from functools import partial | |
import gradio as gr | |
import pandas as pd | |
import requests | |
# Function to search for RVC models on Hugging Face
def search_rvc_models(query):
    """
    Search the Hugging Face Hub for models of the RVC library.

    Parameters
    ----------
    query : str
        Free-text search query forwarded to the Hub's model search API.

    Returns
    -------
    pd.DataFrame
        One row per matching model with the columns ``id``, ``likes``,
        ``downloads`` and ``downloadUrl`` (the model's Hub page URL),
        sorted by downloads in descending order. If no model matches, a
        single-row table containing "No models found" is returned; if the
        request fails, a single-row table describing the failure is
        returned instead.

    """
    try:
        # Let requests URL-encode the query instead of interpolating it into
        # the URL, so characters like "&", "#" or spaces cannot corrupt the
        # request. The timeout keeps the UI from hanging on a stalled
        # connection.
        response = requests.get(
            "https://huggingface.co/api/models",
            params={"search": query, "library": "rvc"},
            timeout=30,
        )
    except requests.RequestException as err:
        return pd.DataFrame({"id": [f"Search failed: {err}"]})
    if response.status_code != 200:
        return pd.DataFrame(
            {"id": [f"Search failed (HTTP {response.status_code})"]},
        )
    models = response.json()
    if not models:
        # An empty result list would yield a DataFrame without columns, and
        # selecting "id"/"likes"/"downloads" from it would raise KeyError.
        return pd.DataFrame({"id": ["No models found"]})
    # Keep only the columns of interest. reindex returns a new frame (no
    # chained-assignment warning below) and tolerates entries lacking one
    # of these keys by filling NaN instead of raising KeyError.
    df = pd.DataFrame(models).reindex(columns=["id", "likes", "downloads"])
    # Link to the model's page on the Hub.
    df["downloadUrl"] = "https://huggingface.co/" + df["id"]
    return df.sort_values(by="downloads", ascending=False)
from ultimate_rvc.core.manage.models import ( | |
delete_all_models, | |
delete_models, | |
download_model, | |
filter_public_models_table, | |
get_public_model_tags, | |
get_saved_model_names, | |
upload_model, | |
) | |
from ultimate_rvc.web.common import ( | |
PROGRESS_BAR, | |
confirm_box_js, | |
confirmation_harness, | |
exception_harness, | |
render_msg, | |
update_dropdowns, | |
) | |
from ultimate_rvc.web.typing_extra import DropdownValue | |
def _update_models(
    num_components: int,
    value: DropdownValue = None,
    value_indices: Sequence[int] | None = None,
) -> gr.Dropdown | tuple[gr.Dropdown, ...]:
    """
    Update the choices of one or more dropdown components to the set of
    currently saved voice models.

    Optionally updates the default value of one or more of these
    components.

    Parameters
    ----------
    num_components : int
        Number of dropdown components to update.
    value : DropdownValue, optional
        New value for dropdown components.
    value_indices : Sequence[int], optional
        Indices of dropdown components to update the value for. When
        omitted, no component has its value updated (an empty list is
        passed on).

    Returns
    -------
    gr.Dropdown | tuple[gr.Dropdown, ...]
        Updated dropdown component or components.

    """
    # Build a fresh list per call rather than using a `[]` default in the
    # signature, which would be one list object shared by every call.
    if value_indices is None:
        value_indices = []
    return update_dropdowns(get_saved_model_names, num_components, value, value_indices)
def _filter_public_models_table(tags: Sequence[str], query: str) -> gr.Dataframe:
    """
    Build a Gradio dataframe showing the public voice models that match
    the given tags and search query.

    Parameters
    ----------
    tags : Sequence[str]
        Tags used to filter the public models metadata table.
    query : str
        Search query used to filter the public models metadata table.

    Returns
    -------
    gr.Dataframe
        Gradio dataframe component holding the filtered table.

    """
    return gr.Dataframe(value=filter_public_models_table(tags, query))
def _autofill_model_name_and_url(
    public_models_table: pd.DataFrame,
    select_event: gr.SelectData,
) -> tuple[gr.Textbox, gr.Textbox]:
    """
    Look up the model name and URL stored in the row of the public models
    table that the user selected, and return them as textbox values.

    Parameters
    ----------
    public_models_table : pd.DataFrame
        The public models table as a Pandas dataframe; it must contain the
        columns "Name" and "URL".
    select_event : gr.SelectData
        Selection event; the first element of its ``index`` is used as the
        row label of the selected row.

    Returns
    -------
    name : gr.Textbox
        Textbox pre-filled with the model name.
    url : gr.Textbox
        Textbox pre-filled with the model URL.

    Raises
    ------
    TypeError
        If the event's index is not a sequence, or if the selected name or
        URL is not a string.

    """
    indices = select_event.index
    if not isinstance(indices, Sequence):
        err_msg = (
            f"Expected a sequence of indices but got {type(indices)} from the"
            " provided event."
        )
        raise TypeError(err_msg)

    row = indices[0]
    url = public_models_table.loc[row, "URL"]
    name = public_models_table.loc[row, "Name"]
    if not (isinstance(url, str) and isinstance(name, str)):
        err_msg = (
            "Expected model name and URL to be strings but got"
            f" {type(name)} and {type(url)} respectively."
        )
        raise TypeError(err_msg)
    return gr.Textbox(value=name), gr.Textbox(value=url)
def render(
    model_delete: gr.Dropdown,
    model_1click: gr.Dropdown,
    model_multi: gr.Dropdown,
) -> None:
    """
    Render "Manage models" tab.

    The tab has four sub-tabs: "Download model", "Upload model",
    "Delete models" and "Search models". After a successful download,
    upload or deletion, the three voice model dropdowns passed in are
    updated to the set of currently saved voice models.

    Parameters
    ----------
    model_delete : gr.Dropdown
        Dropdown for selecting voice models to delete in the
        "Delete models" tab.
    model_1click : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "One-click generation" tab.
    model_multi : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "Multi-step generation" tab.

    """
    # Hidden checkbox passed as the first input of the confirmation-guarded
    # delete handlers below; presumably it carries the outcome of the JS
    # confirmation dialog (see confirm_box_js / confirmation_harness).
    dummy_checkbox = gr.Checkbox(visible=False)

    # Download tab
    with gr.Tab("Download model"):
        with gr.Accordion("View public models table", open=False):
            gr.Markdown("")
            gr.Markdown("*HOW TO USE*")
            gr.Markdown(
                "- Filter voice models by selecting one or more tags and/or providing a"
                " search query.",
            )
            gr.Markdown(
                "- Select a row in the table to autofill the name and"
                " URL for the given voice model in the form fields below.",
            )
            gr.Markdown("")
            with gr.Row():
                search_query = gr.Textbox(label="Search query")
                tags = gr.CheckboxGroup(
                    value=[],
                    label="Tags",
                    choices=get_public_model_tags(),
                )
            with gr.Row():
                # The value is a callable fed by `inputs`, so the table is
                # computed from the selected tags and the search query.
                public_models_table = gr.Dataframe(
                    value=_filter_public_models_table,
                    inputs=[tags, search_query],
                    headers=["Name", "Description", "Tags", "Credit", "Added", "URL"],
                    label="Public models table",
                    interactive=False,
                )
        with gr.Row():
            model_url = gr.Textbox(
                label="Model URL",
                info=(
                    "Should point to a zip file containing a .pth model file and"
                    " optionally also an .index file."
                ),
            )
            model_name = gr.Textbox(
                label="Model name",
                info="Enter a unique name for the voice model.",
            )
        with gr.Row(equal_height=True):
            download_btn = gr.Button("Download 🌐", variant="primary", scale=19)
            download_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        # Selecting a row of the public models table copies its name and URL
        # into the form fields.
        public_models_table.select(
            _autofill_model_name_and_url,
            inputs=public_models_table,
            outputs=[model_name, model_url],
            show_progress="hidden",
        )
        download_btn_click = download_btn.click(
            partial(
                exception_harness(download_model),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=[model_url, model_name],
            outputs=download_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully downloaded voice model!",
            ),
            inputs=model_name,
            outputs=download_msg,
            show_progress="hidden",
        )

    # Upload tab
    with gr.Tab("Upload model"):
        with gr.Accordion("HOW TO USE"):
            gr.Markdown("")
            gr.Markdown(
                "1. Find the .pth file for a locally trained RVC model (e.g. in your"
                " local weights folder) and optionally also a corresponding .index file"
                " (e.g. in your logs/[name] folder)",
            )
            gr.Markdown(
                "2. Upload the files directly or save them to a folder, then compress"
                " that folder and upload the resulting .zip file",
            )
            gr.Markdown("3. Enter a unique name for the uploaded model")
            gr.Markdown("4. Click 'Upload'")
        with gr.Row():
            model_files = gr.File(
                label="Files",
                file_count="multiple",
                file_types=[".zip", ".pth", ".index"],
            )
            local_model_name = gr.Textbox(label="Model name")
        with gr.Row(equal_height=True):
            upload_btn = gr.Button("Upload", variant="primary", scale=19)
            upload_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        upload_btn_click = upload_btn.click(
            partial(exception_harness(upload_model), progress_bar=PROGRESS_BAR),
            inputs=[model_files, local_model_name],
            outputs=upload_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully uploaded voice model!",
            ),
            inputs=local_model_name,
            outputs=upload_msg,
            show_progress="hidden",
        )

    # Delete tab
    with gr.Tab("Delete models"):
        with gr.Row():
            with gr.Column():
                model_delete.render()
                delete_btn = gr.Button("Delete selected", variant="secondary")
                delete_all_btn = gr.Button("Delete all", variant="primary")
            with gr.Column():
                delete_msg = gr.Textbox(label="Output message", interactive=False)
        delete_btn_click = delete_btn.click(
            partial(confirmation_harness(delete_models), progress_bar=PROGRESS_BAR),
            inputs=[dummy_checkbox, model_delete],
            outputs=delete_msg,
            js=confirm_box_js(
                "Are you sure you want to delete the selected voice models?",
            ),
        ).success(
            partial(render_msg, "[-] Successfully deleted selected voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )
        delete_all_btn_click = delete_all_btn.click(
            partial(
                confirmation_harness(delete_all_models),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=dummy_checkbox,
            outputs=delete_msg,
            js=confirm_box_js("Are you sure you want to delete all voice models?"),
        ).success(
            partial(render_msg, "[-] Successfully deleted all voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )

    # Search tab: look up RVC models on the Hugging Face Hub.
    with gr.Tab("Search models"):
        # Textbox for user to enter search query
        query = gr.Textbox(
            label="Search for RVC models",
            placeholder="Enter your search query here",
        )
        # Button to trigger the search
        search_button = gr.Button("Search")
        # Output for displaying the search results as a DataFrame
        results = gr.Dataframe(label="Search Results")
        search_button.click(fn=search_rvc_models, inputs=query, outputs=results)

    # After any successful download, upload or deletion, refresh the choices
    # of all three voice model dropdowns, and set the value of the third
    # output (model_delete) to an empty selection.
    for click_event in [
        download_btn_click,
        upload_btn_click,
        delete_btn_click,
        delete_all_btn_click,
    ]:
        click_event.success(
            partial(_update_models, 3, [], [2]),
            outputs=[model_1click, model_multi, model_delete],
            show_progress="hidden",
        )