# trojblue: "Update app.py" (commit c34122e, verified)
import random
import pandas as pd
import gradio as gr
from typing import Dict, Optional
import unibox as ub
# Module-level cache for the dataset most recently loaded: "id" holds the
# dataset id it was loaded under and "df" holds the loaded DataFrame.
# Kept global so the data persists across Gradio calls.
CURRENT_DATASET = dict(id=None, df=None)
# Single-letter codes read from a row's "rating" column, mapped to the
# word that is emitted in the caption.
rating_map = dict(
    g="general",
    s="sensitive",
    q="questionable",
    e="explicit",
)
def load_dataset_if_needed(dataset_id: str):
    """
    Make sure CURRENT_DATASET holds the dataset identified by dataset_id.

    When the cached id already equals dataset_id nothing is done. Otherwise
    the dataset is read through ``ub.loads("hf://<dataset_id>")``, converted
    with ``.to_pandas()``, and both cache entries are overwritten. If the
    load raises, the cache is left as it was.
    """
    if CURRENT_DATASET["id"] == dataset_id:
        return

    frame = ub.loads(f"hf://{dataset_id}").to_pandas()
    CURRENT_DATASET.update(id=dataset_id, df=frame)
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
    """
    Turn a space-separated tag string into a comma-separated one.

    Empty pieces produced by repeated or trailing spaces are dropped, and
    underscores inside each tag become spaces. With shuffle=False:
    "1girl long_hair blush" -> "1girl, long hair, blush".
    When shuffle is True the tags are put in random order before joining.
    """
    non_empty = filter(None, tag_string.split(" "))
    readable = [tag.replace("_", " ") for tag in non_empty]
    if not shuffle:
        return ", ".join(readable)
    random.shuffle(readable)
    return ", ".join(readable)
def _is_missing_cell(value) -> bool:
    """Return True when a row cell holds None/NaN/NA instead of real data."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # Array-like cells have no single truth value; treat them as present.
        return False


def get_tags_dict(df_row: pd.Series) -> dict:
    """
    Collect the caption ingredients of one dataset row.

    Reads the columns "rating", "tag_string_artist", "tag_string_character",
    "tag_string_copyright", "tag_string_general", "tag_string_meta" and
    "score" from df_row (a missing column raises KeyError).

    Returns a dict of strings with the keys "rating_str", "artist_str",
    "character_str", "copyright_str", "general_str", "meta_str" and "score".
    Each value is "" when the source cell is empty or missing:

    * rating_str    - rating code translated through rating_map
                      ("" for unknown codes)
    * artist_str    - the artist cell, unchanged
    * character_str / general_str / meta_str
                    - the cell passed through convert_dbr_tag_string
                      (comma-separated, shuffled)
    * copyright_str - the cell prefixed with "copyright:"
    * score         - str() of the score; a score of 0 is kept as "0"
    """
    def cell(column: str):
        # NaN is truthy, so a plain `if value` check would let a missing
        # cell through and later fail on `.split` or render as "nan".
        value = df_row[column]
        return "" if _is_missing_cell(value) else value

    rating = cell("rating")
    artist = cell("tag_string_artist")
    character = cell("tag_string_character")
    copyright_ = cell("tag_string_copyright")
    general = cell("tag_string_general")
    meta = cell("tag_string_meta")
    score = df_row["score"]

    return {
        "rating_str": rating_map.get(rating, ""),
        "artist_str": artist if artist else "",
        "character_str": convert_dbr_tag_string(character) if character else "",
        "copyright_str": f"copyright:{copyright_}" if copyright_ else "",
        "general_str": convert_dbr_tag_string(general) if general else "",
        "meta_str": convert_dbr_tag_string(meta) if meta else "",
        # Test for "missing", not truthiness, so a real score of 0 survives.
        "score": "" if _is_missing_cell(score) else str(score),
    }
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """
    Join the pieces of tags_dict into one ", "-separated caption string.

    The order is: rating, artist (prefixed with "artist:", and only when
    add_artist_tags is true), character, copyright, general. Empty pieces
    are skipped. The "meta_str" and "score" entries are not used.
    """
    rating = tags_dict["rating_str"]
    artist = tags_dict["artist_str"]
    ordered_pieces = (
        rating,
        f"artist:{artist}" if artist and add_artist_tags else "",
        tags_dict["character_str"],
        tags_dict["copyright_str"],
        tags_dict["general_str"],
    )
    return ", ".join(piece for piece in ordered_pieces if piece)
def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """
    Build one caption per row for the positional slice [start_idx, end_idx).

    Each caption is tags_front, the row's tag string (get_tags_dict followed
    by build_tags_from_tags_dict) and tags_back, joined with ", "; any of the
    three that is empty is left out. Returns the captions as a list in row
    order.
    """
    def caption_for(row) -> str:
        base = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        # Front, base, back: keep only the non-empty parts.
        return ", ".join(filter(None, (tags_front, base, tags_back)))

    window = df.iloc[start_idx:end_idx]
    return [caption_for(row) for _, row in window.iterrows()]
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """
    Return the "large_file_url" values for the positional slice
    [start_idx, end_idx) of df, as a plain list in row order.

    Out-of-range bounds are clipped the way ``iloc`` slicing clips them, so
    a slice past the end yields a shorter (possibly empty) list.
    Raises KeyError if df has no "large_file_url" column.
    """
    # Slice the single column directly instead of building a Series for
    # every row with iterrows() just to read one field from it.
    return df["large_file_url"].iloc[start_idx:end_idx].tolist()
def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """
    Callback behind the "Get Captions & Previews" button.

    1) Loads the dataset if it is not the one already cached.
    2) Clamps the requested window to the dataset length.
    3) Returns a 3-tuple (captions DataFrame, list of preview URLs,
       Markdown info message).

    start_idx and display_count are coerced to int because UI number widgets
    may deliver floats or None, neither of which works for positional
    slicing or range(). On a load failure or an empty dataset the function
    returns an empty DataFrame, an empty list and an explanatory message
    instead of raising.
    """
    # 1) Possibly reload. This is the UI boundary, so report any load
    #    failure through the info message rather than letting it escape.
    try:
        load_dataset_if_needed(dataset_id)
    except Exception as exc:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id} ({exc})"

    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out total length, clamp inputs
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."

    start_idx = int(start_idx or 0)                    # None / 0.0 -> 0
    display_count = max(int(display_count or 0), 0)    # never a negative window
    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + display_count, total_len)

    # 3) Build results
    idxs = range(start_idx, end_idx)
    captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": idxs, "Captions": captions})

    # 4) Build info string. Markdown only renders a hard line break when the
    #    line ends with two spaces, so join the lines with "  \n".
    info_lines = [
        f"**Current dataset:** {CURRENT_DATASET['id']}",
        f"**Dataset length:** {total_len}",
        (
            f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
            f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
            f"**add_artist_tags:** {add_artist_tags}"
        ),
    ]
    info_msg = "  \n".join(info_lines)
    return df_out, previews, info_msg
# UI layout: inputs in the left column, results in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")
    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID"
            )
            # precision=0 makes the widget deliver an integer; the value is
            # used as a positional row index.
            start_idx_input = gr.Number(value=500, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items"
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
            run_button = gr.Button("Get Captions & Previews")
        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    # Input order must match gradio_interface's positional parameters;
    # output order must match its returned 3-tuple.
    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out
        ]
    )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse the helpers) does not launch the UI.
if __name__ == "__main__":
    demo.launch()