import random
from typing import Dict, Optional

import gradio as gr
import pandas as pd
import unibox as ub

# Store the currently loaded dataset in a module-level dict so it persists
# across Gradio calls; it is only re-fetched when the dataset ID changes.
CURRENT_DATASET = {
    "id": None,
    "df": None,
}

# Danbooru single-letter rating codes -> rating tags used in captions.
rating_map = {
    "g": "general",
    "s": "sensitive",
    "q": "questionable",
    "e": "explicit",
}


def load_dataset_if_needed(dataset_id: str) -> None:
    """
    Load ``hf://<dataset_id>`` via unibox unless it is already cached.

    CURRENT_DATASET is only updated after the load and pandas conversion
    succeed, so a failed load leaves the previously cached dataset in place.
    Any exception raised by the loader propagates to the caller.
    """
    if CURRENT_DATASET["id"] != dataset_id:
        df = ub.loads(f"hf://{dataset_id}").to_pandas()
        CURRENT_DATASET["id"] = dataset_id
        CURRENT_DATASET["df"] = df


def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
    """
    Convert a space-separated Danbooru tag string into a comma-separated one.

    Underscores inside each tag become spaces. Empty tokens (e.g. from
    repeated spaces) are dropped. When ``shuffle`` is true, the tag order is
    randomized using the global ``random`` state.

    Example (with shuffle=False):
        "1girl long_hair blush" -> "1girl, long hair, blush"
    """
    tags_list = [tag.replace("_", " ") for tag in tag_string.split(" ") if tag]
    if shuffle:
        random.shuffle(tags_list)
    return ", ".join(tags_list)


def _is_missing(value: object) -> bool:
    """Return True for None and scalar missing markers (NaN, pd.NA, NaT)."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # pd.isna on an array-like returns an array whose truth value is
        # ambiguous; such values are not "missing" scalars.
        return False


def _as_tag_string(value: object) -> str:
    """Normalize a raw dataframe cell to a stripped string ("" if missing)."""
    if _is_missing(value):
        return ""
    return str(value).strip()


def get_tags_dict(df_row: pd.Series) -> dict:
    """
    Extract caption components from one dataset row.

    Reads the columns ``rating``, ``tag_string_artist``, ``tag_string_character``,
    ``tag_string_copyright``, ``tag_string_general``, ``tag_string_meta`` and
    ``score``. Missing cells (None/NaN/NA) are treated as empty instead of
    crashing the tag conversion.

    Returns:
        dict with keys rating_str, artist_str, character_str, copyright_str,
        general_str, meta_str (all str) and score (str; "" when missing).
    """
    rating = _as_tag_string(df_row["rating"])
    artist = _as_tag_string(df_row["tag_string_artist"])
    character = _as_tag_string(df_row["tag_string_character"])
    copyright_ = _as_tag_string(df_row["tag_string_copyright"])
    general = _as_tag_string(df_row["tag_string_general"])
    meta = _as_tag_string(df_row["tag_string_meta"])
    score = df_row["score"]

    # NOTE(review): artist and copyright strings are used verbatim (underscores
    # kept, multiple space-separated tags share one prefix) unlike character /
    # general / meta — confirm this is the intended caption format.
    return {
        "rating_str": rating_map.get(rating, ""),
        "artist_str": artist,
        "character_str": convert_dbr_tag_string(character) if character else "",
        "copyright_str": f"copyright:{copyright_}" if copyright_ else "",
        "general_str": convert_dbr_tag_string(general) if general else "",
        "meta_str": convert_dbr_tag_string(meta) if meta else "",
        # A score of 0 is a real value; only a missing score becomes "".
        "score": "" if _is_missing(score) else str(score),
    }


def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """
    Build the final comma-separated caption from a ``get_tags_dict`` result.

    Order: rating, "artist:<artist>" (only when ``add_artist_tags``), character,
    copyright, general. Empty components are skipped. ``meta_str`` and
    ``score`` are not included in the caption.
    """
    context = []
    if tags_dict["rating_str"]:
        context.append(tags_dict["rating_str"])
    if tags_dict["artist_str"] and add_artist_tags:
        context.append(f"artist:{tags_dict['artist_str']}")
    if tags_dict["character_str"]:
        context.append(tags_dict["character_str"])
    if tags_dict["copyright_str"]:
        context.append(tags_dict["copyright_str"])
    if tags_dict["general_str"]:
        context.append(tags_dict["general_str"])
    return ", ".join(context)


def get_captions_for_rows(
    df: pd.DataFrame,
    start_idx: int = 0,
    end_idx: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True,
) -> list:
    """
    Return one caption per row in ``df.iloc[start_idx:end_idx]``.

    Each caption is ``tags_front``, the built tag caption and ``tags_back``
    joined with ", ", skipping any empty piece.
    """
    captions = []
    for _, row in df.iloc[start_idx:end_idx].iterrows():
        caption_base = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        pieces = [part for part in (tags_front, caption_base, tags_back) if part]
        captions.append(", ".join(pieces))
    return captions


def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """Return the ``large_file_url`` values of ``df.iloc[start_idx:end_idx]``."""
    return df.iloc[start_idx:end_idx]["large_file_url"].tolist()


def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True,
):
    """
    Gradio callback: load the dataset if needed and render a page of rows.

    Args:
        dataset_id: dataset ID, loaded as ``hf://<dataset_id>``; surrounding
            whitespace is ignored.
        start_idx: first row to show. May arrive as a float or None from the
            UI; it is converted to int and clamped to [0, len - 1].
        display_count: number of rows to show, truncated at the dataset end.
        tags_front / tags_back: text prepended / appended to every caption.
        add_artist_tags: whether to include the "artist:" tag.

    Returns:
        (captions DataFrame with "index" and "Captions" columns,
         list of preview URLs, Markdown info string).
        On failure: (empty DataFrame, [], error message).
    """
    dataset_id = (dataset_id or "").strip()
    if not dataset_id:
        return pd.DataFrame(), [], "ERROR: Please provide a dataset ID."

    # 1) Possibly reload. Report loader failures in the UI instead of a traceback.
    try:
        load_dataset_if_needed(dataset_id)
    except Exception as err:  # UI boundary: any loader error becomes a message
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}: {err}"
    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out the total length and clamp the inputs.
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."

    # gr.Number / gr.Slider may pass floats (and an empty Number box gives None);
    # positional iloc slicing and range() require ints.
    start_idx = int(start_idx) if start_idx is not None else 0
    display_count = int(display_count) if display_count is not None else 0
    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + max(display_count, 0), total_len)

    # 3) Build the results.
    captions = get_captions_for_rows(
        dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags
    )
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": range(start_idx, end_idx), "Captions": captions})

    # 4) Build the info string.
    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']}  \n"
        f"**Dataset length:** {total_len}  \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )
    return df_out, previews, info_msg


with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")

    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID",
            )
            # precision=0 makes Gradio round the value and pass it as an int.
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1, label="Number of Items"
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
            run_button = gr.Button("Get Captions & Previews")

        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input,
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out,
        ],
    )

if __name__ == "__main__":
    demo.launch()