# AICoverGen / src/webui.py
# Last change: commit ea56cb4 "Update src/webui.py" by SpyC0der77 (verified).
"""Module which defines the code for the "Manage models" tab."""
from collections.abc import Sequence
from functools import partial
import gradio as gr
import pandas as pd
import requests
def search_rvc_models(query):
    """
    Search the Hugging Face Hub for RVC voice models.

    Parameters
    ----------
    query : str
        Free-text search query forwarded to the Hugging Face model
        search API (restricted to the ``rvc`` library).

    Returns
    -------
    pd.DataFrame
        Table with the columns ``id``, ``likes``, ``downloads`` and
        ``downloadUrl``, sorted by ``downloads`` in descending order.
        If the request fails, returns a non-200 status, or yields no
        models, a single-row table whose ``id`` column reads
        "No models found" is returned instead.

    """
    no_results = pd.DataFrame({"id": ["No models found"]})
    try:
        # Passing the query through ``params`` lets requests URL-encode
        # it, so queries containing spaces, "&", "#", etc. are sent intact.
        # The timeout keeps the UI from hanging forever on a stalled
        # connection.
        response = requests.get(
            "https://huggingface.co/api/models",
            params={"search": query, "library": "rvc"},
            timeout=30,
        )
    except requests.RequestException:
        # Report network failures like an unsuccessful search instead of
        # surfacing a traceback in the UI.
        return no_results
    if response.status_code != 200:
        return no_results
    models = response.json()
    if not models:
        # An empty result list would produce a DataFrame without any
        # columns, making the column selection below raise a KeyError.
        return no_results
    df = pd.DataFrame(models)
    # Keep only the columns shown in the results table.
    df = df[["id", "likes", "downloads"]]
    # Link each result to its model page on the Hub.
    df["downloadUrl"] = "https://huggingface.co/" + df["id"]
    # Most-downloaded models first.
    return df.sort_values(by="downloads", ascending=False)
from ultimate_rvc.core.manage.models import (
delete_all_models,
delete_models,
download_model,
filter_public_models_table,
get_public_model_tags,
get_saved_model_names,
upload_model,
)
from ultimate_rvc.web.common import (
PROGRESS_BAR,
confirm_box_js,
confirmation_harness,
exception_harness,
render_msg,
update_dropdowns,
)
from ultimate_rvc.web.typing_extra import DropdownValue
def _update_models(
    num_components: int,
    value: DropdownValue = None,
    value_indices: Sequence[int] | None = None,
) -> gr.Dropdown | tuple[gr.Dropdown, ...]:
    """
    Update the choices of one or more dropdown components to the set of
    currently saved voice models.

    Optionally updates the default value of one or more of these
    components.

    Parameters
    ----------
    num_components : int
        Number of dropdown components to update.
    value : DropdownValue, optional
        New value for dropdown components.
    value_indices : Sequence[int], optional
        Indices of dropdown components to update the value for. If
        omitted, an empty list is passed on.

    Returns
    -------
    gr.Dropdown | tuple[gr.Dropdown, ...]
        Updated dropdown component or components.

    """
    # Build a fresh list per call instead of sharing one mutable default
    # object across every invocation.
    if value_indices is None:
        value_indices = []
    return update_dropdowns(get_saved_model_names, num_components, value, value_indices)
def _filter_public_models_table(tags: Sequence[str], query: str) -> gr.Dataframe:
    """
    Build a Gradio dataframe holding the public voice model metadata
    that matches the given tags and search query.

    Parameters
    ----------
    tags : Sequence[str]
        Tags used to narrow down the metadata table.
    query : str
        Free-text search query used to narrow down the metadata table.

    Returns
    -------
    gr.Dataframe
        Gradio dataframe component whose value is the filtered table.

    """
    return gr.Dataframe(value=filter_public_models_table(tags, query))
def _autofill_model_name_and_url(
    public_models_table: pd.DataFrame,
    select_event: gr.SelectData,
) -> tuple[gr.Textbox, gr.Textbox]:
    """
    Autofill two textboxes with respectively the name and URL that is
    saved in the currently selected row of the public models table.

    Parameters
    ----------
    public_models_table : pd.DataFrame
        The public models table saved in a Pandas dataframe.
    select_event : gr.SelectData
        Event containing the index of the currently selected row in the
        public models table.

    Returns
    -------
    name : gr.Textbox
        The textbox containing the model name.
    url : gr.Textbox
        The textbox containing the model URL.

    Raises
    ------
    TypeError
        If the index in the provided event is not a sequence, or if the
        name or URL stored in the selected row is not a string.

    """
    event_index = select_event.index
    if not isinstance(event_index, Sequence):
        err_msg = (
            f"Expected a sequence of indices but got {type(event_index)} from the"
            " provided event."
        )
        raise TypeError(err_msg)
    # The first entry of the event index identifies the selected row by its
    # position in the displayed table, so resolve it positionally (iloc)
    # rather than as an index label (loc); the two only coincide when the
    # dataframe has a 0..n-1 RangeIndex.
    row_position = event_index[0]
    columns = public_models_table.columns
    url = public_models_table.iloc[row_position, columns.get_loc("URL")]
    name = public_models_table.iloc[row_position, columns.get_loc("Name")]
    if isinstance(url, str) and isinstance(name, str):
        return gr.Textbox(value=name), gr.Textbox(value=url)
    err_msg = (
        "Expected model name and URL to be strings but got"
        f" {type(name)} and {type(url)} respectively."
    )
    raise TypeError(err_msg)
def render(
    model_delete: gr.Dropdown,
    model_1click: gr.Dropdown,
    model_multi: gr.Dropdown,
) -> None:
    """
    Render "Manage models" tab.

    The tab contains the sub-tabs "Download model", "Upload model",
    "Delete models" and "Search models". After a successful download,
    upload or deletion, the three voice model dropdowns passed in are
    updated to the set of currently saved voice models.

    Parameters
    ----------
    model_delete : gr.Dropdown
        Dropdown for selecting voice models to delete in the
        "Delete models" tab.
    model_1click : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "One-click generation" tab.
    model_multi : gr.Dropdown
        Dropdown for selecting a voice model to use in the
        "Multi-step generation" tab.

    """
    # Hidden checkbox passed as the first input of the confirmation-guarded
    # delete handlers below; presumably it carries the outcome of the JS
    # confirmation dialog (confirm_box_js) -- TODO confirm.
    dummy_checkbox = gr.Checkbox(visible=False)

    # Download tab
    with gr.Tab("Download model"):
        with gr.Accordion("View public models table", open=False):
            gr.Markdown("")
            gr.Markdown("*HOW TO USE*")
            gr.Markdown(
                "- Filter voice models by selecting one or more tags and/or providing a"
                " search query.",
            )
            gr.Markdown(
                "- Select a row in the table to autofill the name and"
                " URL for the given voice model in the form fields below.",
            )
            gr.Markdown("")
            with gr.Row():
                search_query = gr.Textbox(label="Search query")
                tags = gr.CheckboxGroup(
                    value=[],
                    label="Tags",
                    choices=get_public_model_tags(),
                )
            with gr.Row():
                # The table value is computed by calling
                # _filter_public_models_table with the current tags and
                # search query.
                public_models_table = gr.Dataframe(
                    value=_filter_public_models_table,
                    inputs=[tags, search_query],
                    headers=["Name", "Description", "Tags", "Credit", "Added", "URL"],
                    label="Public models table",
                    interactive=False,
                )
        with gr.Row():
            model_url = gr.Textbox(
                label="Model URL",
                info=(
                    "Should point to a zip file containing a .pth model file and"
                    " optionally also an .index file."
                ),
            )
            model_name = gr.Textbox(
                label="Model name",
                info="Enter a unique name for the voice model.",
            )
        with gr.Row(equal_height=True):
            download_btn = gr.Button("Download 🌐", variant="primary", scale=19)
            download_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        # Selecting a row in the public models table fills in the name and
        # URL fields of the download form.
        public_models_table.select(
            _autofill_model_name_and_url,
            inputs=public_models_table,
            outputs=[model_name, model_url],
            show_progress="hidden",
        )
        download_btn_click = download_btn.click(
            partial(
                exception_harness(download_model),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=[model_url, model_name],
            outputs=download_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully downloaded voice model!",
            ),
            inputs=model_name,
            outputs=download_msg,
            show_progress="hidden",
        )

    # Upload tab
    with gr.Tab("Upload model"):
        with gr.Accordion("HOW TO USE"):
            gr.Markdown("")
            gr.Markdown(
                "1. Find the .pth file for a locally trained RVC model (e.g. in your"
                " local weights folder) and optionally also a corresponding .index file"
                " (e.g. in your logs/[name] folder)",
            )
            gr.Markdown(
                "2. Upload the files directly or save them to a folder, then compress"
                " that folder and upload the resulting .zip file",
            )
            gr.Markdown("3. Enter a unique name for the uploaded model")
            gr.Markdown("4. Click 'Upload'")
        with gr.Row():
            model_files = gr.File(
                label="Files",
                file_count="multiple",
                file_types=[".zip", ".pth", ".index"],
            )
            local_model_name = gr.Textbox(label="Model name")
        with gr.Row(equal_height=True):
            upload_btn = gr.Button("Upload", variant="primary", scale=19)
            upload_msg = gr.Textbox(
                label="Output message",
                interactive=False,
                scale=20,
            )
        upload_btn_click = upload_btn.click(
            partial(exception_harness(upload_model), progress_bar=PROGRESS_BAR),
            inputs=[model_files, local_model_name],
            outputs=upload_msg,
        ).success(
            partial(
                render_msg,
                "[+] Successfully uploaded voice model!",
            ),
            inputs=local_model_name,
            outputs=upload_msg,
            show_progress="hidden",
        )

    # Delete tab
    with gr.Tab("Delete models"):
        with gr.Row():
            with gr.Column():
                model_delete.render()
                delete_btn = gr.Button("Delete selected", variant="secondary")
                delete_all_btn = gr.Button("Delete all", variant="primary")
            with gr.Column():
                delete_msg = gr.Textbox(label="Output message", interactive=False)
        delete_btn_click = delete_btn.click(
            partial(confirmation_harness(delete_models), progress_bar=PROGRESS_BAR),
            inputs=[dummy_checkbox, model_delete],
            outputs=delete_msg,
            js=confirm_box_js(
                "Are you sure you want to delete the selected voice models?",
            ),
        ).success(
            partial(render_msg, "[-] Successfully deleted selected voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )
        delete_all_btn_click = delete_all_btn.click(
            partial(
                confirmation_harness(delete_all_models),
                progress_bar=PROGRESS_BAR,
            ),
            inputs=dummy_checkbox,
            outputs=delete_msg,
            js=confirm_box_js("Are you sure you want to delete all voice models?"),
        ).success(
            partial(render_msg, "[-] Successfully deleted all voice models!"),
            outputs=delete_msg,
            show_progress="hidden",
        )

    # Search tab: query the Hugging Face Hub for RVC models.
    with gr.Tab("Search models"):
        query = gr.Textbox(
            label="Search for RVC models",
            placeholder="Enter your search query here",
        )
        search_button = gr.Button("Search")
        results = gr.Dataframe(label="Search Results")
        search_button.click(fn=search_rvc_models, inputs=query, outputs=results)

    # After any operation that changes the set of saved voice models, update
    # the choices of all three voice model dropdowns. The value of the
    # component at index 2 of ``outputs`` (model_delete) is additionally set
    # to an empty list.
    for click_event in [
        download_btn_click,
        upload_btn_click,
        delete_btn_click,
        delete_all_btn_click,
    ]:
        click_event.success(
            partial(_update_models, 3, [], [2]),
            outputs=[model_1click, model_multi, model_delete],
            show_progress="hidden",
        )