|
import random |
|
import pandas as pd |
|
import gradio as gr |
|
from typing import Dict, Optional |
|
|
|
import unibox as ub |
|
|
|
|
|
# Module-level cache of the most recently loaded dataset, filled in by
# load_dataset_if_needed: "id" is the dataset id that was requested (loaded
# from an hf:// URI) and "df" is the pandas DataFrame built from it.
# Both stay None until the first successful load.
CURRENT_DATASET = dict(
    id=None,
    df=None,
)

# Danbooru single-letter rating codes -> the word emitted in captions.
rating_map = dict(
    g="general",
    s="sensitive",
    q="questionable",
    e="explicit",
)
|
|
|
def load_dataset_if_needed(dataset_id: str):
    """Make CURRENT_DATASET hold the DataFrame for ``dataset_id``.

    When ``dataset_id`` equals the cached id nothing is loaded. Otherwise the
    dataset is read through unibox from ``hf://<dataset_id>``, converted to a
    pandas DataFrame, and stored in CURRENT_DATASET together with its id.
    Any exception raised while loading propagates to the caller, and the
    cache is left untouched in that case.
    """
    if CURRENT_DATASET["id"] == dataset_id:
        # Cache hit: the requested dataset is already loaded.
        return

    loaded = ub.loads(f"hf://{dataset_id}")
    frame = loaded.to_pandas()
    # Only update the cache once loading has fully succeeded.
    CURRENT_DATASET.update(id=dataset_id, df=frame)
|
|
|
|
|
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True,
                           rng: Optional[random.Random] = None) -> str:
    """Convert a whitespace-separated Danbooru tag string to a caption.

    Example (``shuffle=False``)::

        "1girl long_hair blush" -> "1girl, long hair, blush"

    Args:
        tag_string: Tags separated by whitespace. Underscores inside each
            tag become spaces. Empty or None input yields ``""``.
        shuffle: Randomly reorder the tags before joining them.
        rng: Optional ``random.Random`` instance used for shuffling, which
            makes the output reproducible. Defaults to the global ``random``
            generator, as before.

    Returns:
        The converted tags joined with ``", "``.
    """
    if not tag_string:
        return ""

    # split() with no argument splits on any run of whitespace (spaces,
    # tabs, newlines) and drops empty pieces.
    # NOTE(review): every "_" is replaced, including inside emoticon tags
    # such as "^_^"; confirm that is acceptable for the captions.
    tags_list = [tag.replace("_", " ") for tag in tag_string.split()]

    if shuffle:
        if rng is None:
            random.shuffle(tags_list)
        else:
            rng.shuffle(tags_list)

    return ", ".join(tags_list)
|
|
|
|
|
def get_tags_dict(df_row: pd.Series) -> dict:
    """Extract the caption-relevant fields of one dataset row as strings.

    Args:
        df_row: One row of the dataset. It must provide the columns
            ``rating``, ``tag_string_artist``, ``tag_string_character``,
            ``tag_string_copyright``, ``tag_string_general``,
            ``tag_string_meta`` and ``score``.

    Returns:
        A dict with the keys ``rating_str``, ``artist_str``,
        ``character_str``, ``copyright_str``, ``general_str``, ``meta_str``
        and ``score``. Missing cells (None / NaN / pd.NA) and empty strings
        map to ``""``.
        - ``rating_str`` is the word from ``rating_map``, or ``""`` for an
          unknown code.
        - ``artist_str`` is the raw artist tag string.
        - ``copyright_str`` is prefixed with ``"copyright:"``.
        - Character, general and meta tags go through
          ``convert_dbr_tag_string``, which shuffles them by default.
        - ``score`` is ``str(score)``; a score of 0 is kept as ``"0"``.
    """

    def _missing(value) -> bool:
        # pandas represents absent cells as None, NaN or pd.NA. NaN is
        # truthy, so a plain `if value` check would let it through.
        return value is None or (pd.api.types.is_scalar(value) and pd.isna(value))

    def _text(value):
        return "" if _missing(value) else value

    rating = df_row["rating"]
    artist = _text(df_row["tag_string_artist"])
    character = _text(df_row["tag_string_character"])
    copyright_ = _text(df_row["tag_string_copyright"])
    general = _text(df_row["tag_string_general"])
    meta = _text(df_row["tag_string_meta"])
    score = df_row["score"]

    rating_str = "" if _missing(rating) else rating_map.get(rating, "")
    artist_str = artist if artist else ""
    character_str = convert_dbr_tag_string(character) if character else ""
    copyright_str = f"copyright:{copyright_}" if copyright_ else ""
    general_str = convert_dbr_tag_string(general) if general else ""
    meta_str = convert_dbr_tag_string(meta) if meta else ""
    # 0 is a legitimate score; only a missing cell produces "".
    score_str = "" if _missing(score) else str(score)

    return {
        "rating_str": rating_str,
        "artist_str": artist_str,
        "character_str": character_str,
        "copyright_str": copyright_str,
        "general_str": general_str,
        "meta_str": meta_str,
        "score": score_str,
    }
|
|
|
|
|
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """Assemble the caption body from a dict produced by ``get_tags_dict``.

    The sections appear in a fixed order: rating, artist (as
    ``artist:<artist_str>``, only when ``add_artist_tags`` is truthy),
    character, copyright, general. Empty sections are skipped, and the rest
    are joined with ``", "``. ``meta_str`` and ``score`` are not used.
    """
    candidates = (
        tags_dict["rating_str"],
        f"artist:{tags_dict['artist_str']}"
        if tags_dict["artist_str"] and add_artist_tags
        else "",
        tags_dict["character_str"],
        tags_dict["copyright_str"],
        tags_dict["general_str"],
    )
    return ", ".join(section for section in candidates if section)
|
|
|
|
|
def get_captions_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """Build one caption per row in the positional range ``[start_idx, end_idx)``.

    Each caption is ``tags_front``, then the row's tags (from
    ``get_tags_dict`` + ``build_tags_from_tags_dict``), then ``tags_back``.
    The parts are joined with ``", "`` and empty parts are skipped.

    Args:
        df: Source dataset.
        start_idx, end_idx: Positional slice bounds passed to ``df.iloc``.
        tags_front, tags_back: Free-text tags typed by the user. Surrounding
            whitespace and stray commas are stripped. None is treated as "".
        add_artist_tags: Forwarded to ``build_tags_from_tags_dict``.

    Returns:
        A list of caption strings, one per selected row.
    """
    # Normalise the user-supplied affixes so that input such as
    # "masterpiece, " or " " does not produce "masterpiece, , ..." or an
    # empty " " part in the caption.
    separators = " ,\t\r\n"
    front = (tags_front or "").strip(separators)
    back = (tags_back or "").strip(separators)

    captions = []
    for _, row in df.iloc[start_idx:end_idx].iterrows():
        caption_base = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        captions.append(", ".join(part for part in (front, caption_base, back) if part))
    return captions
|
|
|
|
|
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """Return the ``large_file_url`` values of rows in ``[start_idx, end_idx)``.

    The column is read directly rather than through ``iterrows``, which
    builds a Series per row and does not preserve column dtypes.

    Raises:
        KeyError: if the selected window is non-empty and ``df`` has no
            ``large_file_url`` column.
    """
    window = df.iloc[start_idx:end_idx]
    if len(window) == 0:
        # An empty window yields [] even when the frame lacks the column
        # (e.g. an empty placeholder DataFrame).
        return []
    return window["large_file_url"].tolist()
|
|
|
|
|
def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """Gradio callback: load the dataset if needed and render a window of rows.

    Steps:
        1) Loads the dataset if needed (cached in CURRENT_DATASET).
        2) Clamps the requested window to the dataset bounds.
        3) Returns ``(DataFrame, Gallery, InfoMessage)``.

    Args:
        dataset_id: Dataset id passed to ``load_dataset_if_needed``.
            Surrounding whitespace is ignored.
        start_idx: First row to show. The UI may deliver a float or None,
            so the value is coerced to int and clamped to
            ``[0, len(dataset) - 1]``.
        display_count: Number of rows to show, coerced to int.
        tags_front, tags_back, add_artist_tags: Forwarded to
            ``get_captions_for_rows``.

    Returns:
        A tuple ``(captions_df, preview_urls, info_markdown)``. On failure,
        the first two are empty and the message starts with "ERROR".
    """
    dataset_id = (dataset_id or "").strip()
    if not dataset_id:
        return pd.DataFrame(), [], "ERROR: Please enter a dataset ID."

    try:
        load_dataset_if_needed(dataset_id)
    except Exception as exc:  # UI boundary: report the failure in the info panel
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}: {exc}"

    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."

    # gr.Number without precision may hand over a float, or None for a
    # cleared field. DataFrame.iloc slicing and range() need ints.
    start_idx = int(start_idx or 0)
    display_count = int(display_count or 0)

    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + max(display_count, 0), total_len)

    idxs = list(range(start_idx, end_idx))
    captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": idxs, "Captions": captions})

    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']} \n"
        f"**Dataset length:** {total_len} \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )

    return df_out, previews, info_msg
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")

    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID",
            )
            # precision=0 makes Gradio round the value and deliver an int,
            # which DataFrame.iloc slicing and range() in the callback need.
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items",
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)

            run_button = gr.Button("Get Captions & Previews")

        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input,
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out,
        ],
    )


# Launch only when run as a script, so importing this module (tests, the
# `gradio` CLI reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()