Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import pandas as pd
|
3 |
+
import gradio as gr
|
4 |
+
from typing import Dict, Optional
|
5 |
+
|
6 |
+
import unibox as ub
|
7 |
+
|
8 |
+
# Module-level cache of the dataset loaded most recently. Keeping it in a
# global dict lets it persist from one Gradio callback invocation to the next.
CURRENT_DATASET = dict(id=None, df=None)
|
13 |
+
|
14 |
+
# Maps a single-letter rating code to its full rating name.
rating_map = dict(
    g="general",
    s="sensitive",
    q="questionable",
    e="explicit",
)
|
20 |
+
|
21 |
+
def load_dataset_if_needed(dataset_id: str):
    """
    Make sure CURRENT_DATASET holds the dataset named by dataset_id.

    Nothing happens when that dataset is already the cached one. Otherwise it
    is loaded through unibox from "hf://<dataset_id>", converted to a pandas
    DataFrame, and stored in CURRENT_DATASET together with its id.
    """
    # Already cached: skip the reload.
    if CURRENT_DATASET["id"] == dataset_id:
        return

    loaded_df = ub.loads(f"hf://{dataset_id}").to_pandas()
    # The cache is only touched after the load succeeded.
    CURRENT_DATASET.update(id=dataset_id, df=loaded_df)
|
30 |
+
|
31 |
+
|
32 |
+
def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
    """
    Convert a whitespace-separated tag string into a comma-separated one.

    Underscores inside each tag are replaced with spaces, e.g. (with
    shuffle=False): "1girl long_hair blush" -> "1girl, long hair, blush".

    Args:
        tag_string: Tags separated by whitespace.
        shuffle: When True (the default), the tag order is randomized with
            the module-level `random` generator before joining.

    Returns:
        The tags joined with ", "; an empty string when there are no tags.
    """
    # split() without an argument splits on any run of whitespace (spaces,
    # tabs, newlines) and never yields empty items, so no filtering is needed.
    tags_list = [tag.replace("_", " ") for tag in tag_string.split()]
    if shuffle:
        random.shuffle(tags_list)
    return ", ".join(tags_list)
|
40 |
+
|
41 |
+
|
42 |
+
def _is_missing(value) -> bool:
    """Return True when value is None or a scalar NaN/NA/NaT."""
    return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))


def get_tags_dict(df_row: pd.Series) -> dict:
    """
    Extract the caption building blocks from one dataset row.

    Reads the columns "rating", "tag_string_artist", "tag_string_character",
    "tag_string_copyright", "tag_string_general", "tag_string_meta" and
    "score" from df_row.

    Returns:
        A dict with the keys "rating_str", "artist_str", "character_str",
        "copyright_str", "general_str", "meta_str" and "score". Every value
        is a string; a missing (None/NaN) or empty source cell yields "".
        The copyright value is prefixed with "copyright:", and the
        character/general/meta values are run through
        convert_dbr_tag_string (which shuffles tag order by default).

    Raises:
        KeyError: if df_row lacks one of the columns listed above.
    """
    def text_or_empty(column: str):
        # NaN is truthy, so a plain truthiness test would let missing cells
        # through; check for missing values first, then for empty ones.
        value = df_row[column]
        if _is_missing(value) or not value:
            return ""
        return value

    rating = df_row["rating"]
    artist = text_or_empty("tag_string_artist")
    character = text_or_empty("tag_string_character")
    copyright_ = text_or_empty("tag_string_copyright")
    general = text_or_empty("tag_string_general")
    meta = text_or_empty("tag_string_meta")
    score = df_row["score"]

    return {
        # Unknown rating codes map to an empty string.
        "rating_str": rating_map.get(rating, ""),
        "artist_str": artist,
        "character_str": convert_dbr_tag_string(character) if character else "",
        "copyright_str": f"copyright:{copyright_}" if copyright_ else "",
        "general_str": convert_dbr_tag_string(general) if general else "",
        "meta_str": convert_dbr_tag_string(meta) if meta else "",
        # Test for "missing" rather than truthiness so a score of 0 is kept.
        "score": "" if _is_missing(score) else str(score),
    }
|
72 |
+
|
73 |
+
|
74 |
+
def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
    """
    Assemble the final caption string from the pieces in tags_dict.

    Non-empty pieces are emitted in the fixed order rating, artist,
    character, copyright, general and joined with ", ". The artist piece is
    prefixed with "artist:" and is only included when add_artist_tags is
    true.
    """
    rating_part = tags_dict["rating_str"]
    artist_part = tags_dict["artist_str"]
    ordered_parts = (
        rating_part,
        f"artist:{artist_part}" if artist_part and add_artist_tags else "",
        tags_dict["character_str"],
        tags_dict["copyright_str"],
        tags_dict["general_str"],
    )
    # Empty pieces are skipped so the result has no dangling separators.
    return ", ".join(part for part in ordered_parts if part)
|
96 |
+
|
97 |
+
|
98 |
+
def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
                          tags_front: str = "", tags_back: str = "",
                          add_artist_tags: bool = True) -> list:
    """
    Build one caption per row for the positional slice [start_idx, end_idx).

    Each caption is the row's tag string (get_tags_dict followed by
    build_tags_from_tags_dict) with tags_front placed before it and
    tags_back after it; empty parts are left out and the remainder is joined
    with ", ".
    """
    def caption_for(row) -> str:
        base = build_tags_from_tags_dict(get_tags_dict(row), add_artist_tags)
        # filter(None, ...) drops whichever of front/base/back is empty.
        return ", ".join(filter(None, (tags_front, base, tags_back)))

    window = df.iloc[start_idx:end_idx]
    return [caption_for(row) for _, row in window.iterrows()]
|
111 |
+
|
112 |
+
|
113 |
+
def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
    """
    Return the "large_file_url" values for the positional slice
    [start_idx, end_idx) of df.

    Args:
        df: DataFrame containing a "large_file_url" column.
        start_idx: First row position (inclusive).
        end_idx: Last row position (exclusive); positions past the end of
            the frame are simply ignored by the slice.

    Returns:
        A plain Python list with one entry per selected row.

    Raises:
        KeyError: if df has no "large_file_url" column.
    """
    # Select the single column and slice it, instead of materializing a full
    # Series for every row with iterrows() just to read one field.
    return df["large_file_url"].iloc[start_idx:end_idx].tolist()
|
116 |
+
|
117 |
+
|
118 |
+
def gradio_interface(
    dataset_id: str,
    start_idx: int = 0,
    display_count: int = 5,
    tags_front: str = "",
    tags_back: str = "",
    add_artist_tags: bool = True
):
    """
    Gradio callback: load the dataset if needed and build the outputs.

    Args:
        dataset_id: Dataset identifier passed to load_dataset_if_needed.
        start_idx: Position of the first row to show. Clamped into
            [0, len(dataset) - 1]; floats are truncated and None counts as 0.
        display_count: Number of rows to show. Floats are truncated, None and
            negative values count as 0, and the window is cut at the end of
            the dataset.
        tags_front: Text placed in front of every caption.
        tags_back: Text placed after every caption.
        add_artist_tags: Whether captions include the "artist:" piece.

    Returns:
        A tuple (captions DataFrame with columns "index" and "Captions",
        list of preview values, Markdown info message). When no dataset is
        loaded or it is empty, the DataFrame and the list are empty and the
        message explains why.
    """
    # 1) Possibly reload
    load_dataset_if_needed(dataset_id)
    dset_df = CURRENT_DATASET["df"]
    if dset_df is None:
        return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"

    # 2) Figure out total length, clamp inputs
    total_len = len(dset_df)
    if total_len == 0:
        return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."

    # UI number widgets may deliver floats (or None when the field is
    # cleared); range() and positional slicing below need real ints.
    start_idx = int(start_idx or 0)
    # A negative count would turn end_idx into a from-the-end slice bound,
    # making the iloc slice and the index range disagree in length.
    display_count = max(int(display_count or 0), 0)

    start_idx = min(max(start_idx, 0), total_len - 1)
    end_idx = min(start_idx + display_count, total_len)

    # 3) Build results
    idxs = range(start_idx, end_idx)
    captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
    previews = get_previews_for_rows(dset_df, start_idx, end_idx)
    df_out = pd.DataFrame({"index": idxs, "Captions": captions})

    # 4) Build info string
    info_msg = (
        f"**Current dataset:** {CURRENT_DATASET['id']}  \n"
        f"**Dataset length:** {total_len}  \n"
        f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
        f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
        f"**add_artist_tags:** {add_artist_tags}"
    )

    return df_out, previews, info_msg
|
165 |
+
|
166 |
+
|
167 |
+
# UI layout: input controls in the left column, results in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Danbooru2025 Dataset Captions and Previews")

    with gr.Row():
        with gr.Column(scale=1):
            dataset_id_input = gr.Textbox(
                value="dataproc5/test-danbooru2025-tag-balanced-2k",
                label="Dataset ID"
            )
            # precision=0 asks Gradio to hand the callback an integer; the
            # value is used as a row position for slicing and range().
            start_idx_input = gr.Number(value=0, label="Start Index", precision=0)
            display_count_input = gr.Slider(
                value=5, minimum=1, maximum=50, step=1,
                label="Number of Items"
            )
            tags_front_input = gr.Textbox(value="", label="Tags Front")
            tags_back_input = gr.Textbox(value="", label="Tags Back")
            add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)

            run_button = gr.Button("Get Captions & Previews")

        with gr.Column(scale=2):
            captions_df_out = gr.DataFrame(label="Captions")
            previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
            info_textbox_out = gr.Markdown(value="")

    # The input order must match the positional parameters of
    # gradio_interface; the outputs match its returned 3-tuple.
    run_button.click(
        fn=gradio_interface,
        inputs=[
            dataset_id_input,
            start_idx_input,
            display_count_input,
            tags_front_input,
            tags_back_input,
            add_artist_tags_input
        ],
        outputs=[
            captions_df_out,
            previews_gallery_out,
            info_textbox_out
        ]
    )

demo.launch()
|