trojblue commited on
Commit
a454c92
·
verified ·
1 Parent(s): 6fac094

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from typing import Dict, Optional
5
+
6
+ import unibox as ub
7
+
8
+ # Store current dataset in a global dict so it persists across Gradio calls.
9
+ CURRENT_DATASET = {
10
+ "id": None,
11
+ "df": None
12
+ }
13
+
14
+ rating_map = {
15
+ "g": "general",
16
+ "s": "sensitive",
17
+ "q": "questionable",
18
+ "e": "explicit"
19
+ }
20
+
21
+ def load_dataset_if_needed(dataset_id: str):
22
+ """
23
+ Checks if dataset_id is different from what's currently loaded.
24
+ If so, loads from HF again and updates CURRENT_DATASET.
25
+ """
26
+ if CURRENT_DATASET["id"] != dataset_id:
27
+ df = ub.loads(f"hf://{dataset_id}").to_pandas()
28
+ CURRENT_DATASET["id"] = dataset_id
29
+ CURRENT_DATASET["df"] = df
30
+
31
+
32
+ def convert_dbr_tag_string(tag_string: str, shuffle: bool = True) -> str:
33
+ """
34
+ 1girl long_hair blush -> 1girl, long_hair, blush
35
+ """
36
+ tags_list = [i.replace("_", " ") for i in tag_string.split(" ") if i]
37
+ if shuffle:
38
+ random.shuffle(tags_list)
39
+ return ", ".join(tags_list)
40
+
41
+
42
+ def get_tags_dict(df_row: pd.Series) -> dict:
43
+ """
44
+ Returns a dict with rating/artist/character/copyright/general/meta
45
+ plus numeric score.
46
+ """
47
+ rating = df_row["rating"]
48
+ artist = df_row["tag_string_artist"]
49
+ character = df_row["tag_string_character"]
50
+ copyright_ = df_row["tag_string_copyright"]
51
+ general = df_row["tag_string_general"]
52
+ meta = df_row["tag_string_meta"]
53
+ score = df_row["score"]
54
+
55
+ rating_str = rating_map.get(rating, "")
56
+ artist_str = artist if artist else ""
57
+ character_str = convert_dbr_tag_string(character) if character else ""
58
+ copyright_str = f"copyright:{copyright_}" if copyright_ else ""
59
+ general_str = convert_dbr_tag_string(general) if general else ""
60
+ meta_str = convert_dbr_tag_string(meta) if meta else ""
61
+ _score = str(score) if score else ""
62
+
63
+ return {
64
+ "rating_str": rating_str,
65
+ "artist_str": artist_str,
66
+ "character_str": character_str,
67
+ "copyright_str": copyright_str,
68
+ "general_str": general_str,
69
+ "meta_str": meta_str,
70
+ "score": _score,
71
+ }
72
+
73
+
74
+ def build_tags_from_tags_dict(tags_dict: dict, add_artist_tags: bool = True) -> str:
75
+ """
76
+ Build a final comma-separated string (rating, artist, character, etc.).
77
+ """
78
+ context = []
79
+
80
+ if tags_dict["rating_str"]:
81
+ context.append(tags_dict["rating_str"])
82
+
83
+ if tags_dict["artist_str"] and add_artist_tags:
84
+ context.append(f"artist:{tags_dict['artist_str']}")
85
+
86
+ if tags_dict["character_str"]:
87
+ context.append(tags_dict["character_str"])
88
+
89
+ if tags_dict["copyright_str"]:
90
+ context.append(tags_dict["copyright_str"])
91
+
92
+ if tags_dict["general_str"]:
93
+ context.append(tags_dict["general_str"])
94
+
95
+ return ", ".join(context)
96
+
97
+
98
+ def get_captions_for_rows(df, start_idx: int = 0, end_idx: int = 5,
99
+ tags_front: str = "", tags_back: str = "",
100
+ add_artist_tags: bool = True) -> list:
101
+ filtered_df = df.iloc[start_idx:end_idx]
102
+ captions = []
103
+ for _, row in filtered_df.iterrows():
104
+ tags = get_tags_dict(row)
105
+ caption_base = build_tags_from_tags_dict(tags, add_artist_tags)
106
+ # Combine front, base, back
107
+ pieces = [part for part in [tags_front, caption_base, tags_back] if part]
108
+ final_caption = ", ".join(pieces)
109
+ captions.append(final_caption)
110
+ return captions
111
+
112
+
113
+ def get_previews_for_rows(df: pd.DataFrame, start_idx: int = 0, end_idx: int = 5) -> list:
114
+ filtered_df = df.iloc[start_idx:end_idx]
115
+ return [row["large_file_url"] for _, row in filtered_df.iterrows()]
116
+
117
+
118
+ def gradio_interface(
119
+ dataset_id: str,
120
+ start_idx: int = 0,
121
+ display_count: int = 5,
122
+ tags_front: str = "",
123
+ tags_back: str = "",
124
+ add_artist_tags: bool = True
125
+ ):
126
+ """
127
+ 1) Loads dataset if needed
128
+ 2) Returns (DataFrame, Gallery, InfoMessage)
129
+ """
130
+ # 1) Possibly reload
131
+ load_dataset_if_needed(dataset_id)
132
+ dset_df = CURRENT_DATASET["df"]
133
+ if dset_df is None:
134
+ return pd.DataFrame(), [], f"ERROR: Could not load dataset {dataset_id}"
135
+
136
+ # 2) Figure out total length, clamp inputs
137
+ total_len = len(dset_df)
138
+ if total_len == 0:
139
+ return pd.DataFrame(), [], f"Dataset {dataset_id} is empty."
140
+
141
+ start_idx = max(start_idx, 0)
142
+ if start_idx >= total_len:
143
+ start_idx = total_len - 1
144
+
145
+ end_idx = start_idx + display_count
146
+ if end_idx > total_len:
147
+ end_idx = total_len
148
+
149
+ # 3) Build results
150
+ idxs = range(start_idx, end_idx)
151
+ captions = get_captions_for_rows(dset_df, start_idx, end_idx, tags_front, tags_back, add_artist_tags)
152
+ previews = get_previews_for_rows(dset_df, start_idx, end_idx)
153
+ df_out = pd.DataFrame({"index": idxs, "Captions": captions})
154
+
155
+ # 4) Build info string
156
+ info_msg = (
157
+ f"**Current dataset:** {CURRENT_DATASET['id']} \n"
158
+ f"**Dataset length:** {total_len} \n"
159
+ f"**start_idx:** {start_idx}, **display_count:** {display_count}, "
160
+ f"**tags_front:** '{tags_front}', **tags_back:** '{tags_back}', "
161
+ f"**add_artist_tags:** {add_artist_tags}"
162
+ )
163
+
164
+ return df_out, previews, info_msg
165
+
166
+
167
+ with gr.Blocks() as demo:
168
+ gr.Markdown("## Danbooru2025 Dataset Captions and Previews")
169
+
170
+ with gr.Row():
171
+ with gr.Column(scale=1):
172
+ dataset_id_input = gr.Textbox(
173
+ value="dataproc5/test-danbooru2025-tag-balanced-2k",
174
+ label="Dataset ID"
175
+ )
176
+ start_idx_input = gr.Number(value=0, label="Start Index")
177
+ display_count_input = gr.Slider(
178
+ value=5, minimum=1, maximum=50, step=1,
179
+ label="Number of Items"
180
+ )
181
+ tags_front_input = gr.Textbox(value="", label="Tags Front")
182
+ tags_back_input = gr.Textbox(value="", label="Tags Back")
183
+ add_artist_tags_input = gr.Checkbox(label="Add artist tags", value=True)
184
+
185
+ run_button = gr.Button("Get Captions & Previews")
186
+
187
+ with gr.Column(scale=2):
188
+ captions_df_out = gr.DataFrame(label="Captions")
189
+ previews_gallery_out = gr.Gallery(label="Previews", type="filepath")
190
+ info_textbox_out = gr.Markdown(value="")
191
+
192
+ run_button.click(
193
+ fn=gradio_interface,
194
+ inputs=[
195
+ dataset_id_input,
196
+ start_idx_input,
197
+ display_count_input,
198
+ tags_front_input,
199
+ tags_back_input,
200
+ add_artist_tags_input
201
+ ],
202
+ outputs=[
203
+ captions_df_out,
204
+ previews_gallery_out,
205
+ info_textbox_out
206
+ ]
207
+ )
208
+
209
+ demo.launch()