Werli commited on
Commit
415a9c8
·
verified ·
1 Parent(s): 6990f87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1102 -1102
app.py CHANGED
@@ -1,1103 +1,1103 @@
1
- import os
2
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3
- import io
4
- import copy
5
- import requests
6
- import numpy as np
7
- import spaces
8
- import gradio as gr
9
- from transformers import AutoProcessor, AutoModelForCausalLM
10
- from transformers import AutoModelForCausalLM, AutoProcessor
11
- from transformers.dynamic_module_utils import get_imports
12
- from PIL import Image, ImageDraw, ImageFont
13
- import matplotlib.pyplot as plt
14
- import matplotlib.patches as patches
15
- from unittest.mock import patch
16
-
17
- import argparse
18
- import huggingface_hub
19
- import onnxruntime as rt
20
- import pandas as pd
21
- import traceback
22
- import tempfile
23
- import zipfile
24
- import re
25
- import ast
26
- import time
27
- from datetime import datetime, timezone
28
- from collections import defaultdict
29
- from classifyTags import classify_tags
30
- # Add scheduler code here
31
- from apscheduler.schedulers.background import BackgroundScheduler
32
-
33
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
34
- def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
35
- """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
36
- if not str(filename).endswith("/modeling_florence2.py"):
37
- return get_imports(filename)
38
- imports = get_imports(filename)
39
- if "flash_attn" in imports:
40
- imports.remove("flash_attn")
41
- return imports
42
-
43
- @spaces.GPU
44
- def get_device_type():
45
- import torch
46
- if torch.cuda.is_available():
47
- return "cuda"
48
- else:
49
- if (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
50
- return "mps"
51
- else:
52
- return "cpu"
53
-
54
- model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
55
-
56
- import subprocess
57
- device = get_device_type()
58
- if (device == "cuda"):
59
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
60
- model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
61
- processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
62
- model.to(device)
63
- else:
64
- #https://huggingface.co/microsoft/Florence-2-base-ft/discussions/4
65
- with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
66
- model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
67
- processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
68
- model.to(device)
69
-
70
- TITLE = "Multi-Tagger"
71
- DESCRIPTION = """
72
- Multi-Tagger is a powerful and versatile application that integrates two cutting-edge models: Waifu Diffusion and Florence 2. This app is designed to provide comprehensive image analysis and captioning capabilities, making it a valuable tool for AI artists, researchers, and enthusiasts.
73
-
74
- Features:
75
- - Supports batch processing of multiple images.
76
- - Tags images with multiple categories: general tags, character tags, and ratings.
77
- - Tags are categorized into groups (e.g., general, characters, ratings).
78
- - Displays categorized tags in a structured format.
79
- - Integrates Llama3 models to reorganize the tags into a readable English article.
80
- - Includes a separate tab for image captioning using Florence 2.
81
- - Florence 2 supports CUDA, MPS or CPU if one of them is available.
82
- - Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection).
83
- - Displays output text and images for tasks that generate visual outputs.
84
- - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process.
85
-
86
- Example image by [me.](https://huggingface.co/Werli)
87
- """
88
- colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
89
- 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
90
-
91
- # Dataset v3 series of models:
92
- SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
93
- CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
94
- VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
95
- VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
96
- EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
97
-
98
- # Dataset v2 series of models:
99
- MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
100
- SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
101
- CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
102
- CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
103
- VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
104
-
105
- # IdolSankaku series of models:
106
- EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
107
- SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
108
-
109
- # Files to download from the repos
110
- MODEL_FILENAME = "model.onnx"
111
- LABEL_FILENAME = "selected_tags.csv"
112
-
113
- # LLAMA model
114
- META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
115
- META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
116
-
117
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
118
- kaomojis = [
119
- "0_0",
120
- "(o)_(o)",
121
- "+_+",
122
- "+_-",
123
- "._.",
124
- "<o>_<o>",
125
- "<|>_<|>",
126
- "=_=",
127
- ">_<",
128
- "3_3",
129
- "6_9",
130
- ">_o",
131
- "@_@",
132
- "^_^",
133
- "o_o",
134
- "u_u",
135
- "x_x",
136
- "|_|",
137
- "||_||",
138
- ]
139
- def parse_args() -> argparse.Namespace:
140
- parser = argparse.ArgumentParser()
141
- parser.add_argument("--score-slider-step", type=float, default=0.05)
142
- parser.add_argument("--score-general-threshold", type=float, default=0.35)
143
- parser.add_argument("--score-character-threshold", type=float, default=0.85)
144
- parser.add_argument("--share", action="store_true")
145
- return parser.parse_args()
146
- def load_labels(dataframe) -> list[str]:
147
- name_series = dataframe["name"]
148
- name_series = name_series.map(
149
- lambda x: x.replace("_", " ") if x not in kaomojis else x
150
- )
151
- tag_names = name_series.tolist()
152
-
153
- rating_indexes = list(np.where(dataframe["category"] == 9)[0])
154
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
155
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
156
- return tag_names, rating_indexes, general_indexes, character_indexes
157
- def mcut_threshold(probs):
158
- """
159
- Maximum Cut Thresholding (MCut)
160
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
161
- for Multi-label Classification. In 11th International Symposium, IDA 2012
162
- (pp. 172-183).
163
- """
164
- sorted_probs = probs[probs.argsort()[::-1]]
165
- difs = sorted_probs[:-1] - sorted_probs[1:]
166
- t = difs.argmax()
167
- thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
168
- return thresh
169
- class Timer:
170
- def __init__(self):
171
- self.start_time = time.perf_counter() # Record the start time
172
- self.checkpoints = [("Start", self.start_time)] # Store checkpoints
173
-
174
- def checkpoint(self, label="Checkpoint"):
175
- """Record a checkpoint with a given label."""
176
- now = time.perf_counter()
177
- self.checkpoints.append((label, now))
178
-
179
- def report(self, is_clear_checkpoints = True):
180
- # Determine the max label width for alignment
181
- max_label_length = max(len(label) for label, _ in self.checkpoints)
182
-
183
- prev_time = self.checkpoints[0][1]
184
- for label, curr_time in self.checkpoints[1:]:
185
- elapsed = curr_time - prev_time
186
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
187
- prev_time = curr_time
188
-
189
- if is_clear_checkpoints:
190
- self.checkpoints.clear()
191
- self.checkpoint() # Store checkpoints
192
-
193
- def report_all(self):
194
- """Print all recorded checkpoints and total execution time with aligned formatting."""
195
- print("\n> Execution Time Report:")
196
-
197
- # Determine the max label width for alignment
198
- max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
199
-
200
- prev_time = self.start_time
201
- for label, curr_time in self.checkpoints[1:]:
202
- elapsed = curr_time - prev_time
203
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
204
- prev_time = curr_time
205
-
206
- total_time = self.checkpoints[-1][1] - self.start_time
207
- print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
208
-
209
- self.checkpoints.clear()
210
-
211
- def restart(self):
212
- self.start_time = time.perf_counter() # Record the start time
213
- self.checkpoints = [("Start", self.start_time)] # Store checkpoints
214
-
215
- class Llama3Reorganize:
216
- def __init__(
217
- self,
218
- repoId: str,
219
- device: str = None,
220
- loadModel: bool = False,
221
- ):
222
- """Initializes the Llama model.
223
-
224
- Args:
225
- repoId: LLAMA model repo.
226
- device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
227
- ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
228
- localFilesOnly: If True, avoid downloading the file and return the path to the
229
- local cached file if it exists.
230
- """
231
- self.modelPath = self.download_model(repoId)
232
-
233
- if device is None:
234
- import torch
235
- self.totalVram = 0
236
- if torch.cuda.is_available():
237
- try:
238
- deviceId = torch.cuda.current_device()
239
- self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory/(1024*1024*1024)
240
- except Exception as e:
241
- print(traceback.format_exc())
242
- print("Error detect vram: " + str(e))
243
- device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
244
- else:
245
- device = "cpu"
246
-
247
- self.device = device
248
- self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
249
-
250
- if loadModel:
251
- self.load_model()
252
-
253
- def download_model(self, repoId):
254
- import warnings
255
- import requests
256
- allowPatterns = [
257
- "config.json",
258
- "generation_config.json",
259
- "model.bin",
260
- "pytorch_model.bin",
261
- "pytorch_model.bin.index.json",
262
- "pytorch_model-*.bin",
263
- "sentencepiece.bpe.model",
264
- "tokenizer.json",
265
- "tokenizer_config.json",
266
- "shared_vocabulary.txt",
267
- "shared_vocabulary.json",
268
- "special_tokens_map.json",
269
- "spiece.model",
270
- "vocab.json",
271
- "model.safetensors",
272
- "model-*.safetensors",
273
- "model.safetensors.index.json",
274
- "quantize_config.json",
275
- "tokenizer.model",
276
- "vocabulary.json",
277
- "preprocessor_config.json",
278
- "added_tokens.json"
279
- ]
280
-
281
- kwargs = {"allow_patterns": allowPatterns,}
282
-
283
- try:
284
- return huggingface_hub.snapshot_download(repoId, **kwargs)
285
- except (
286
- huggingface_hub.utils.HfHubHTTPError,
287
- requests.exceptions.ConnectionError,
288
- ) as exception:
289
- warnings.warn(
290
- "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
291
- repoId,
292
- exception,
293
- )
294
- warnings.warn(
295
- "Trying to load the model directly from the local cache, if it exists."
296
- )
297
-
298
- kwargs["local_files_only"] = True
299
- return huggingface_hub.snapshot_download(repoId, **kwargs)
300
-
301
-
302
- def load_model(self):
303
- import ctranslate2
304
- import transformers
305
- try:
306
- print('\n\nLoading model: %s\n\n' % self.modelPath)
307
- kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
308
- kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
309
- self.roleSystem = {"role": "system", "content": self.system_prompt}
310
- self.Model = ctranslate2.Generator(**kwargsModel)
311
-
312
- self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
313
- self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
314
-
315
- except Exception as e:
316
- self.release_vram()
317
- raise e
318
-
319
-
320
- def release_vram(self):
321
- try:
322
- import torch
323
- if torch.cuda.is_available():
324
- if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
325
- self.Model.unload_model()
326
-
327
- if getattr(self, "Tokenizer", None) is not None:
328
- del self.Tokenizer
329
- if getattr(self, "Model", None) is not None:
330
- del self.Model
331
- import gc
332
- gc.collect()
333
- try:
334
- torch.cuda.empty_cache()
335
- except Exception as e:
336
- print(traceback.format_exc())
337
- print("\tcuda empty cache, error: " + str(e))
338
- print("release vram end.")
339
- except Exception as e:
340
- print(traceback.format_exc())
341
- print("Error release vram: " + str(e))
342
-
343
- def reorganize(self, text: str, max_length: int = 400):
344
- output = None
345
- result = None
346
- try:
347
- input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
348
- source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
349
- output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
350
- target = output[0]
351
- result = self.Tokenizer.decode(target.sequences_ids[0])
352
-
353
- if len(result) > 2:
354
- if result[0] == "\"" and result[len(result) - 1] == "\"":
355
- result = result[1:-1]
356
- elif result[0] == "'" and result[len(result) - 1] == "'":
357
- result = result[1:-1]
358
- elif result[0] == "「" and result[len(result) - 1] == "」":
359
- result = result[1:-1]
360
- elif result[0] == "『" and result[len(result) - 1] == "』":
361
- result = result[1:-1]
362
- except Exception as e:
363
- print(traceback.format_exc())
364
- print("Error reorganize text: " + str(e))
365
-
366
- return result
367
-
368
-
369
- class Predictor:
370
- def __init__(self):
371
- self.model_target_size = None
372
- self.last_loaded_repo = None
373
- def download_model(self, model_repo):
374
- csv_path = huggingface_hub.hf_hub_download(
375
- model_repo,
376
- LABEL_FILENAME,
377
- )
378
- model_path = huggingface_hub.hf_hub_download(
379
- model_repo,
380
- MODEL_FILENAME,
381
- )
382
- return csv_path, model_path
383
- def load_model(self, model_repo):
384
- if model_repo == self.last_loaded_repo:
385
- return
386
-
387
- csv_path, model_path = self.download_model(model_repo)
388
-
389
- tags_df = pd.read_csv(csv_path)
390
- sep_tags = load_labels(tags_df)
391
-
392
- self.tag_names = sep_tags[0]
393
- self.rating_indexes = sep_tags[1]
394
- self.general_indexes = sep_tags[2]
395
- self.character_indexes = sep_tags[3]
396
-
397
- model = rt.InferenceSession(model_path)
398
- _, height, width, _ = model.get_inputs()[0].shape
399
- self.model_target_size = height
400
-
401
- self.last_loaded_repo = model_repo
402
- self.model = model
403
- def prepare_image(self, path):
404
- image = Image.open(path)
405
- image = image.convert("RGBA")
406
- target_size = self.model_target_size
407
-
408
- canvas = Image.new("RGBA", image.size, (255, 255, 255))
409
- canvas.alpha_composite(image)
410
- image = canvas.convert("RGB")
411
-
412
- # Pad image to square
413
- image_shape = image.size
414
- max_dim = max(image_shape)
415
- pad_left = (max_dim - image_shape[0]) // 2
416
- pad_top = (max_dim - image_shape[1]) // 2
417
-
418
- padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
419
- padded_image.paste(image, (pad_left, pad_top))
420
-
421
- # Resize
422
- if max_dim != target_size:
423
- padded_image = padded_image.resize(
424
- (target_size, target_size),
425
- Image.BICUBIC,
426
- )
427
- # Convert to numpy array
428
- image_array = np.asarray(padded_image, dtype=np.float32)
429
-
430
- # Convert PIL-native RGB to BGR
431
- image_array = image_array[:, :, ::-1]
432
-
433
- return np.expand_dims(image_array, axis=0)
434
-
435
- def create_file(self, text: str, directory: str, fileName: str) -> str:
436
- # Write the text to a file
437
- with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
438
- file.write(text)
439
-
440
- return file.name
441
-
442
- def predict(
443
- self,
444
- gallery,
445
- model_repo,
446
- general_thresh,
447
- general_mcut_enabled,
448
- character_thresh,
449
- character_mcut_enabled,
450
- characters_merge_enabled,
451
- llama3_reorganize_model_repo,
452
- additional_tags_prepend,
453
- additional_tags_append,
454
- tag_results,
455
- progress=gr.Progress()
456
- ):
457
- gallery_len = len(gallery)
458
- print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
459
-
460
- timer = Timer() # Create a timer
461
- progressRatio = 0.5 if llama3_reorganize_model_repo else 1
462
- progressTotal = gallery_len + 1
463
- current_progress = 0
464
-
465
- self.load_model(model_repo)
466
- current_progress += progressRatio/progressTotal;
467
- progress(current_progress, desc="Initialize wd model finished")
468
- timer.checkpoint(f"Initialize wd model")
469
-
470
- # Result
471
- txt_infos = []
472
- output_dir = tempfile.mkdtemp()
473
- if not os.path.exists(output_dir):
474
- os.makedirs(output_dir)
475
-
476
- sorted_general_strings = ""
477
- rating = None
478
- character_res = None
479
- general_res = None
480
-
481
- if llama3_reorganize_model_repo:
482
- print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
483
- llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
484
- current_progress += progressRatio/progressTotal;
485
- progress(current_progress, desc="Initialize llama3 model finished")
486
- timer.checkpoint(f"Initialize llama3 model")
487
-
488
- timer.report()
489
-
490
- prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
491
- append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
492
- if prepend_list and append_list:
493
- append_list = [item for item in append_list if item not in prepend_list]
494
-
495
- # Dictionary to track counters for each filename
496
- name_counters = defaultdict(int)
497
- # New code to create categorized output string
498
- categorized_output_strings = []
499
- for idx, value in enumerate(gallery):
500
- try:
501
- image_path = value[0]
502
- image_name = os.path.splitext(os.path.basename(image_path))[0]
503
-
504
- # Increment the counter for the current name
505
- name_counters[image_name] += 1
506
-
507
- if name_counters[image_name] > 1:
508
- image_name = f"{image_name}_{name_counters[image_name]:02d}"
509
-
510
- image = self.prepare_image(image_path)
511
-
512
- input_name = self.model.get_inputs()[0].name
513
- label_name = self.model.get_outputs()[0].name
514
- print(f"Gallery {idx:02d}: Starting run wd model...")
515
- preds = self.model.run([label_name], {input_name: image})[0]
516
-
517
- labels = list(zip(self.tag_names, preds[0].astype(float)))
518
-
519
- # First 4 labels are actually ratings: pick one with argmax
520
- ratings_names = [labels[i] for i in self.rating_indexes]
521
- rating = dict(ratings_names)
522
-
523
- # Then we have general tags: pick any where prediction confidence > threshold
524
- general_names = [labels[i] for i in self.general_indexes]
525
-
526
- if general_mcut_enabled:
527
- general_probs = np.array([x[1] for x in general_names])
528
- general_thresh = mcut_threshold(general_probs)
529
-
530
- general_res = [x for x in general_names if x[1] > general_thresh]
531
- general_res = dict(general_res)
532
-
533
- # Everything else is characters: pick any where prediction confidence > threshold
534
- character_names = [labels[i] for i in self.character_indexes]
535
-
536
- if character_mcut_enabled:
537
- character_probs = np.array([x[1] for x in character_names])
538
- character_thresh = mcut_threshold(character_probs)
539
- character_thresh = max(0.15, character_thresh)
540
-
541
- character_res = [x for x in character_names if x[1] > character_thresh]
542
- character_res = dict(character_res)
543
- character_list = list(character_res.keys())
544
-
545
- sorted_general_list = sorted(
546
- general_res.items(),
547
- key=lambda x: x[1],
548
- reverse=True,
549
- )
550
- sorted_general_list = [x[0] for x in sorted_general_list]
551
- #Remove values from character_list that already exist in sorted_general_list
552
- character_list = [item for item in character_list if item not in sorted_general_list]
553
- #Remove values from sorted_general_list that already exist in prepend_list or append_list
554
- if prepend_list:
555
- sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
556
- if append_list:
557
- sorted_general_list = [item for item in sorted_general_list if item not in append_list]
558
-
559
- sorted_general_list = prepend_list + sorted_general_list + append_list
560
-
561
- sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
562
-
563
- classified_tags, unclassified_tags = classify_tags(sorted_general_list)
564
-
565
- # Create a single string of all categorized tags
566
- categorized_output_string = ', '.join([', '.join(tags) for tags in classified_tags.values()])
567
- categorized_output_strings.append(categorized_output_string)
568
-
569
- current_progress += progressRatio/progressTotal;
570
- progress(current_progress, desc=f"image{idx:02d}, predict finished")
571
- timer.checkpoint(f"image{idx:02d}, predict finished")
572
-
573
- if llama3_reorganize_model_repo:
574
- print(f"Starting reorganize with llama3...")
575
- reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
576
- reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
577
- reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
578
- reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
579
- sorted_general_strings += "," + reorganize_strings
580
-
581
- current_progress += progressRatio/progressTotal;
582
- progress(current_progress, desc=f"image{idx:02d}, llama3 reorganize finished")
583
- timer.checkpoint(f"image{idx:02d}, llama3 reorganize finished")
584
-
585
- txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
586
- txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
587
-
588
- tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
589
- timer.report()
590
- except Exception as e:
591
- print(traceback.format_exc())
592
- print("Error predict: " + str(e))
593
- # Result
594
- download = []
595
- if txt_infos is not None and len(txt_infos) > 0:
596
- downloadZipPath = os.path.join(output_dir, "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
597
- with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
598
- for info in txt_infos:
599
- # Get file name from lookup
600
- taggers_zip.write(info["path"], arcname=info["name"])
601
- download.append(downloadZipPath)
602
-
603
- if llama3_reorganize_model_repo:
604
- llama3_reorganize.release_vram()
605
- del llama3_reorganize
606
-
607
- progress(1, desc=f"Predict completed")
608
- timer.report_all() # Print all recorded times
609
- print("Predict is complete.")
610
-
611
- # Collect all categorized output strings into a single string
612
- final_categorized_output = ', '.join(categorized_output_strings)
613
-
614
- return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results, final_categorized_output
615
- # END
616
-
617
- def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
618
- if not selected_state:
619
- return selected_state
620
-
621
- tag_result = { "strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}" }
622
- if selected_state.value["image"]["path"] in tag_results:
623
- tag_result = tag_results[selected_state.value["image"]["path"]]
624
-
625
- return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
626
-
627
- def append_gallery(gallery: list, image: str):
628
- if gallery is None:
629
- gallery = []
630
- if not image:
631
- return gallery, None
632
-
633
- gallery.append(image)
634
-
635
- return gallery, None
636
-
637
-
638
- def extend_gallery(gallery: list, images):
639
- if gallery is None:
640
- gallery = []
641
- if not images:
642
- return gallery
643
-
644
- # Combine the new images with the existing gallery images
645
- gallery.extend(images)
646
-
647
- return gallery
648
-
649
- def remove_image_from_gallery(gallery: list, selected_image: str):
650
- if not gallery or not selected_image:
651
- return gallery
652
-
653
- selected_image = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
654
- # Remove the selected image from the gallery
655
- if selected_image in gallery:
656
- gallery.remove(selected_image)
657
- return gallery
658
-
659
- # END
660
-
661
- def fig_to_pil(fig):
662
- buf = io.BytesIO()
663
- fig.savefig(buf, format='png')
664
- buf.seek(0)
665
- return Image.open(buf)
666
-
667
- @spaces.GPU
668
- def run_example(task_prompt, image, text_input=None):
669
- if text_input is None:
670
- prompt = task_prompt
671
- else:
672
- prompt = task_prompt + text_input
673
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
674
- generated_ids = model.generate(
675
- input_ids=inputs["input_ids"],
676
- pixel_values=inputs["pixel_values"],
677
- max_new_tokens=1024,
678
- early_stopping=False,
679
- do_sample=False,
680
- num_beams=3,
681
- )
682
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
683
- parsed_answer = processor.post_process_generation(
684
- generated_text,
685
- task=task_prompt,
686
- image_size=(image.width, image.height)
687
- )
688
- return parsed_answer
689
-
690
- def plot_bbox(image, data):
691
- fig, ax = plt.subplots()
692
- ax.imshow(image)
693
- for bbox, label in zip(data['bboxes'], data['labels']):
694
- x1, y1, x2, y2 = bbox
695
- rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
696
- ax.add_patch(rect)
697
- plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
698
- ax.axis('off')
699
- return fig
700
-
701
- def draw_polygons(image, prediction, fill_mask=False):
702
- draw = ImageDraw.Draw(image)
703
- scale = 1
704
- for polygons, label in zip(prediction['polygons'], prediction['labels']):
705
- color = random.choice(colormap)
706
- fill_color = random.choice(colormap) if fill_mask else None
707
- for _polygon in polygons:
708
- _polygon = np.array(_polygon).reshape(-1, 2)
709
- if len(_polygon) < 3:
710
- print('Invalid polygon:', _polygon)
711
- continue
712
- _polygon = (_polygon * scale).reshape(-1).tolist()
713
- if fill_mask:
714
- draw.polygon(_polygon, outline=color, fill=fill_color)
715
- else:
716
- draw.polygon(_polygon, outline=color)
717
- draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
718
- return image
719
-
720
- def convert_to_od_format(data):
721
- bboxes = data.get('bboxes', [])
722
- labels = data.get('bboxes_labels', [])
723
- od_results = {
724
- 'bboxes': bboxes,
725
- 'labels': labels
726
- }
727
- return od_results
728
-
729
- def draw_ocr_bboxes(image, prediction):
730
- scale = 1
731
- draw = ImageDraw.Draw(image)
732
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
733
- for box, label in zip(bboxes, labels):
734
- color = random.choice(colormap)
735
- new_box = (np.array(box) * scale).tolist()
736
- draw.polygon(new_box, width=3, outline=color)
737
- draw.text((new_box[0]+8, new_box[1]+2),
738
- "{}".format(label),
739
- align="right",
740
- fill=color)
741
- return image
742
-
743
- def convert_to_od_format(data):
744
- bboxes = data.get('bboxes', [])
745
- labels = data.get('bboxes_labels', [])
746
- od_results = {
747
- 'bboxes': bboxes,
748
- 'labels': labels
749
- }
750
- return od_results
751
-
752
- def draw_ocr_bboxes(image, prediction):
753
- scale = 1
754
- draw = ImageDraw.Draw(image)
755
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
756
- for box, label in zip(bboxes, labels):
757
- color = random.choice(colormap)
758
- new_box = (np.array(box) * scale).tolist()
759
- draw.polygon(new_box, width=3, outline=color)
760
- draw.text((new_box[0]+8, new_box[1]+2),
761
- "{}".format(label),
762
- align="right",
763
- fill=color)
764
- return image
765
- def process_image(image, task_prompt, text_input=None):
766
- # Test
767
- if isinstance(image, str): # If image is a file path
768
- image = Image.open(image) # Load image from file path
769
- else: # If image is a NumPy array
770
- image = Image.fromarray(image) # Convert NumPy array to PIL Image
771
- if task_prompt == 'Caption':
772
- task_prompt = '<CAPTION>'
773
- results = run_example(task_prompt, image)
774
- return results[task_prompt], None
775
- elif task_prompt == 'Detailed Caption':
776
- task_prompt = '<DETAILED_CAPTION>'
777
- results = run_example(task_prompt, image)
778
- return results[task_prompt], None
779
- elif task_prompt == 'More Detailed Caption':
780
- task_prompt = '<MORE_DETAILED_CAPTION>'
781
- results = run_example(task_prompt, image)
782
- return results[task_prompt], plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
783
- elif task_prompt == 'Caption + Grounding':
784
- task_prompt = '<CAPTION>'
785
- results = run_example(task_prompt, image)
786
- text_input = results[task_prompt]
787
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
788
- results = run_example(task_prompt, image, text_input)
789
- results['<CAPTION>'] = text_input
790
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
791
- return results, fig_to_pil(fig)
792
- elif task_prompt == 'Detailed Caption + Grounding':
793
- task_prompt = '<DETAILED_CAPTION>'
794
- results = run_example(task_prompt, image)
795
- text_input = results[task_prompt]
796
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
797
- results = run_example(task_prompt, image, text_input)
798
- results['<DETAILED_CAPTION>'] = text_input
799
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
800
- return results, fig_to_pil(fig)
801
- elif task_prompt == 'More Detailed Caption + Grounding':
802
- task_prompt = '<MORE_DETAILED_CAPTION>'
803
- results = run_example(task_prompt, image)
804
- text_input = results[task_prompt]
805
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
806
- results = run_example(task_prompt, image, text_input)
807
- results['<MORE_DETAILED_CAPTION>'] = text_input
808
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
809
- return results, fig_to_pil(fig)
810
- elif task_prompt == 'Object Detection':
811
- task_prompt = '<OD>'
812
- results = run_example(task_prompt, image)
813
- fig = plot_bbox(image, results['<OD>'])
814
- return results, fig_to_pil(fig)
815
- elif task_prompt == 'Dense Region Caption':
816
- task_prompt = '<DENSE_REGION_CAPTION>'
817
- results = run_example(task_prompt, image)
818
- fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
819
- return results, fig_to_pil(fig)
820
- elif task_prompt == 'Region Proposal':
821
- task_prompt = '<REGION_PROPOSAL>'
822
- results = run_example(task_prompt, image)
823
- fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
824
- return results, fig_to_pil(fig)
825
- elif task_prompt == 'Caption to Phrase Grounding':
826
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
827
- results = run_example(task_prompt, image, text_input)
828
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
829
- return results, fig_to_pil(fig)
830
- elif task_prompt == 'Referring Expression Segmentation':
831
- task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
832
- results = run_example(task_prompt, image, text_input)
833
- output_image = copy.deepcopy(image)
834
- output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
835
- return results, output_image
836
- elif task_prompt == 'Region to Segmentation':
837
- task_prompt = '<REGION_TO_SEGMENTATION>'
838
- results = run_example(task_prompt, image, text_input)
839
- output_image = copy.deepcopy(image)
840
- output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
841
- return results, output_image
842
- elif task_prompt == 'Open Vocabulary Detection':
843
- task_prompt = '<OPEN_VOCABULARY_DETECTION>'
844
- results = run_example(task_prompt, image, text_input)
845
- bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
846
- fig = plot_bbox(image, bbox_results)
847
- return results, fig_to_pil(fig)
848
- elif task_prompt == 'Region to Category':
849
- task_prompt = '<REGION_TO_CATEGORY>'
850
- results = run_example(task_prompt, image, text_input)
851
- return results, None
852
- elif task_prompt == 'Region to Description':
853
- task_prompt = '<REGION_TO_DESCRIPTION>'
854
- results = run_example(task_prompt, image, text_input)
855
- return results, None
856
- elif task_prompt == 'OCR':
857
- task_prompt = '<OCR>'
858
- results = run_example(task_prompt, image)
859
- return results, None
860
- elif task_prompt == 'OCR with Region':
861
- task_prompt = '<OCR_WITH_REGION>'
862
- results = run_example(task_prompt, image)
863
- output_image = copy.deepcopy(image)
864
- output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
865
- return results, output_image
866
- else:
867
- return "", None # Return empty string and None for unknown task prompts
868
- ##############
869
- # Custom CSS to set the height of the gr.Dropdown menu
870
- css = """
871
- div.progress-level div.progress-level-inner {
872
- text-align: left !important;
873
- width: 55.5% !important;
874
- #output {
875
- height: 500px;
876
- overflow: auto;
877
- border: 1px solid #ccc;
878
- }
879
- """
880
- single_task_list =[
881
- 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
882
- 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
883
- 'Referring Expression Segmentation', 'Region to Segmentation',
884
- 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
885
- 'OCR', 'OCR with Region'
886
- ]
887
- cascaded_task_list =[
888
- 'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
889
- ]
890
- def update_task_dropdown(choice):
891
- if choice == 'Cascaded task':
892
- return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
893
- else:
894
- return gr.Dropdown(choices=single_task_list, value='Caption')
895
-
896
- args = parse_args()
897
-
898
- predictor = Predictor()
899
-
900
- dropdown_list = [
901
- EVA02_LARGE_MODEL_DSV3_REPO,
902
- SWINV2_MODEL_DSV3_REPO,
903
- CONV_MODEL_DSV3_REPO,
904
- VIT_MODEL_DSV3_REPO,
905
- VIT_LARGE_MODEL_DSV3_REPO,
906
- # ---
907
- MOAT_MODEL_DSV2_REPO,
908
- SWIN_MODEL_DSV2_REPO,
909
- CONV_MODEL_DSV2_REPO,
910
- CONV2_MODEL_DSV2_REPO,
911
- VIT_MODEL_DSV2_REPO,
912
- # ---
913
- SWINV2_MODEL_IS_DSV1_REPO,
914
- EVA02_LARGE_MODEL_IS_DSV1_REPO,
915
- ]
916
- llama_list = [
917
- META_LLAMA_3_3B_REPO,
918
- META_LLAMA_3_8B_REPO,
919
- ]
920
-
921
- # This is workaround will make the space restart every 2 days. (for test).
922
- def _restart_space():
923
- HF_TOKEN = os.getenv("HF_TOKEN")
924
- if not HF_TOKEN:
925
- raise ValueError("HF_TOKEN environment variable is not set.")
926
- huggingface_hub.HfApi().restart_space(repo_id="Werli/Multi-Tagger", token=HF_TOKEN, factory_reboot=False)
927
- scheduler = BackgroundScheduler()
928
- # Add a job to restart the space every 2 days (172800 seconds)
929
- restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
930
- # Start the scheduler
931
- scheduler.start()
932
- next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
933
- NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
934
-
935
- # Using "reilnuud/polite" theme
936
- with gr.Blocks(title=TITLE, css=css, theme="Werli/wd-tagger-images", fill_width=True) as demo:
937
- gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
938
- gr.Markdown(value=DESCRIPTION)
939
- gr.Markdown(NEXT_RESTART)
940
- with gr.Tab(label="Waifu Diffusion"):
941
- with gr.Row():
942
- with gr.Column():
943
- submit = gr.Button(value="Submit", variant="primary", size="lg")
944
- with gr.Column(variant="panel"):
945
- # Create an Image component for uploading images
946
- image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
947
- with gr.Row():
948
- upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
949
- remove_button = gr.Button("Remove Selected Image", size="sm")
950
- gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
951
-
952
- model_repo = gr.Dropdown(
953
- dropdown_list,
954
- value=EVA02_LARGE_MODEL_DSV3_REPO,
955
- label="Model",
956
- )
957
- with gr.Row():
958
- general_thresh = gr.Slider(
959
- 0,
960
- 1,
961
- step=args.score_slider_step,
962
- value=args.score_general_threshold,
963
- label="General Tags Threshold",
964
- scale=3,
965
- )
966
- general_mcut_enabled = gr.Checkbox(
967
- value=False,
968
- label="Use MCut threshold",
969
- scale=1,
970
- )
971
- with gr.Row():
972
- character_thresh = gr.Slider(
973
- 0,
974
- 1,
975
- step=args.score_slider_step,
976
- value=args.score_character_threshold,
977
- label="Character Tags Threshold",
978
- scale=3,
979
- )
980
- character_mcut_enabled = gr.Checkbox(
981
- value=False,
982
- label="Use MCut threshold",
983
- scale=1,
984
- )
985
- with gr.Row():
986
- characters_merge_enabled = gr.Checkbox(
987
- value=True,
988
- label="Merge characters into the string output",
989
- scale=1,
990
- )
991
- with gr.Row():
992
- llama3_reorganize_model_repo = gr.Dropdown(
993
- [None] + llama_list,
994
- value=None,
995
- label="Llama3 Model",
996
- info="Use the Llama3 model to reorganize the article (Note: very slow)",
997
- )
998
- with gr.Row():
999
- additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
1000
- additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
1001
- with gr.Row():
1002
- clear = gr.ClearButton(
1003
- components=[
1004
- gallery,
1005
- model_repo,
1006
- general_thresh,
1007
- general_mcut_enabled,
1008
- character_thresh,
1009
- character_mcut_enabled,
1010
- characters_merge_enabled,
1011
- llama3_reorganize_model_repo,
1012
- additional_tags_prepend,
1013
- additional_tags_append,
1014
- ],
1015
- variant="secondary",
1016
- size="lg",
1017
- )
1018
- with gr.Column(variant="panel"):
1019
- download_file = gr.File(label="Output (Download)")
1020
- sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True)
1021
- categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True)
1022
- categorized = gr.JSON(label="Categorized (tags)")
1023
- rating = gr.Label(label="Rating")
1024
- character_res = gr.Label(label="Output (characters)")
1025
- general_res = gr.Label(label="Output (tags)")
1026
- unclassified = gr.JSON(label="Unclassified (tags)")
1027
- clear.add(
1028
- [
1029
- download_file,
1030
- sorted_general_strings,
1031
- categorized,
1032
- rating,
1033
- character_res,
1034
- general_res,
1035
- unclassified,
1036
- ]
1037
- )
1038
- tag_results = gr.State({})
1039
- # Define the event listener to add the uploaded image to the gallery
1040
- image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
1041
- # When the upload button is clicked, add the new images to the gallery
1042
- upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
1043
- # Event to update the selected image when an image is clicked in the gallery
1044
- selected_image = gr.Textbox(label="Selected Image", visible=False)
1045
- gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
1046
- # Event to remove a selected image from the gallery
1047
- remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
1048
- submit.click(
1049
- predictor.predict,
1050
- inputs=[
1051
- gallery,
1052
- model_repo,
1053
- general_thresh,
1054
- general_mcut_enabled,
1055
- character_thresh,
1056
- character_mcut_enabled,
1057
- characters_merge_enabled,
1058
- llama3_reorganize_model_repo,
1059
- additional_tags_prepend,
1060
- additional_tags_append,
1061
- tag_results,
1062
- ],
1063
- outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results, categorized_output,],
1064
- )
1065
- gr.Examples(
1066
- [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
1067
- inputs=[
1068
- image_input,
1069
- model_repo,
1070
- general_thresh,
1071
- general_mcut_enabled,
1072
- character_thresh,
1073
- character_mcut_enabled,
1074
- ],
1075
- )
1076
- with gr.Tab(label="Florence 2 Image Captioning"):
1077
- with gr.Row():
1078
- with gr.Column(variant="panel"):
1079
- input_img = gr.Image(label="Input Picture")
1080
- task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
1081
- task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
1082
- task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
1083
- text_input = gr.Textbox(label="Text Input (optional)")
1084
- submit_btn = gr.Button(value="Submit")
1085
- with gr.Column(variant="panel"):
1086
- #OUTPUT
1087
- output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8) # Here is the problem!
1088
- output_img = gr.Image(label="Output Image")
1089
- gr.Examples(
1090
- examples=[
1091
- ["images/image1.png", 'Object Detection'],
1092
- ["images/image2.png", 'OCR with Region']
1093
- ],
1094
- inputs=[input_img, task_prompt],
1095
- outputs=[output_text, output_img],
1096
- fn=process_image,
1097
- cache_examples=False,
1098
- label='Try examples'
1099
- )
1100
- submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
1101
-
1102
- demo.queue(max_size=2)
1103
  demo.launch(debug=True) # test
 
1
+ import os
2
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3
+ import io
4
+ import copy
5
+ import requests
6
+ import numpy as np
7
+ import spaces
8
+ import gradio as gr
9
+ from transformers import AutoProcessor, AutoModelForCausalLM
10
+ from transformers import AutoModelForCausalLM, AutoProcessor
11
+ from transformers.dynamic_module_utils import get_imports
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.patches as patches
15
+ from unittest.mock import patch
16
+
17
+ import argparse
18
+ import huggingface_hub
19
+ import onnxruntime as rt
20
+ import pandas as pd
21
+ import traceback
22
+ import tempfile
23
+ import zipfile
24
+ import re
25
+ import ast
26
+ import time
27
+ from datetime import datetime, timezone
28
+ from collections import defaultdict
29
+ from classifyTags import classify_tags
30
+ # Add scheduler code here
31
+ from apscheduler.schedulers.background import BackgroundScheduler
32
+
33
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
34
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
35
+ """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
36
+ if not str(filename).endswith("/modeling_florence2.py"):
37
+ return get_imports(filename)
38
+ imports = get_imports(filename)
39
+ if "flash_attn" in imports:
40
+ imports.remove("flash_attn")
41
+ return imports
42
+
43
+ @spaces.GPU
44
+ def get_device_type():
45
+ import torch
46
+ if torch.cuda.is_available():
47
+ return "cuda"
48
+ else:
49
+ if (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
50
+ return "mps"
51
+ else:
52
+ return "cpu"
53
+
54
+ model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
55
+
56
+ import subprocess
57
+ device = get_device_type()
58
+ if (device == "cuda"):
59
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
60
+ model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
61
+ processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
62
+ model.to(device)
63
+ else:
64
+ #https://huggingface.co/microsoft/Florence-2-base-ft/discussions/4
65
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
66
+ model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
67
+ processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
68
+ model.to(device)
69
+
70
+ TITLE = "Multi-Tagger"
71
+ DESCRIPTION = """
72
+ Multi-Tagger is a powerful and versatile application that integrates two cutting-edge models: Waifu Diffusion and Florence 2. This app is designed to provide comprehensive image analysis and captioning capabilities, making it a valuable tool for AI artists, researchers, and enthusiasts.
73
+
74
+ Features:
75
+ - Supports batch processing of multiple images.
76
+ - Tags images with multiple categories: general tags, character tags, and ratings.
77
+ - Tags are categorized into groups (e.g., general, characters, ratings).
78
+ - Displays categorized tags in a structured format.
79
+ - Integrates Llama3 models to reorganize the tags into a readable English article.
80
+ - Includes a separate tab for image captioning using Florence 2.
81
+ - Florence 2 supports CUDA, MPS or CPU if one of them is available.
82
+ - Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection).
83
+ - Displays output text and images for tasks that generate visual outputs.
84
+ - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process.
85
+
86
+ Example image by [me.](https://huggingface.co/Werli)
87
+ """
88
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
89
+ 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
90
+
91
+ # Dataset v3 series of models:
92
+ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
93
+ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
94
+ VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
95
+ VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
96
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
97
+
98
+ # Dataset v2 series of models:
99
+ MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
100
+ SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
101
+ CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
102
+ CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
103
+ VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
104
+
105
+ # IdolSankaku series of models:
106
+ EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
107
+ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
108
+
109
+ # Files to download from the repos
110
+ MODEL_FILENAME = "model.onnx"
111
+ LABEL_FILENAME = "selected_tags.csv"
112
+
113
+ # LLAMA model
114
+ META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
115
+ META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
116
+
117
+ # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
118
+ kaomojis = [
119
+ "0_0",
120
+ "(o)_(o)",
121
+ "+_+",
122
+ "+_-",
123
+ "._.",
124
+ "<o>_<o>",
125
+ "<|>_<|>",
126
+ "=_=",
127
+ ">_<",
128
+ "3_3",
129
+ "6_9",
130
+ ">_o",
131
+ "@_@",
132
+ "^_^",
133
+ "o_o",
134
+ "u_u",
135
+ "x_x",
136
+ "|_|",
137
+ "||_||",
138
+ ]
139
+ def parse_args() -> argparse.Namespace:
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("--score-slider-step", type=float, default=0.05)
142
+ parser.add_argument("--score-general-threshold", type=float, default=0.35)
143
+ parser.add_argument("--score-character-threshold", type=float, default=0.85)
144
+ parser.add_argument("--share", action="store_true")
145
+ return parser.parse_args()
146
+ def load_labels(dataframe) -> list[str]:
147
+ name_series = dataframe["name"]
148
+ name_series = name_series.map(
149
+ lambda x: x.replace("_", " ") if x not in kaomojis else x
150
+ )
151
+ tag_names = name_series.tolist()
152
+
153
+ rating_indexes = list(np.where(dataframe["category"] == 9)[0])
154
+ general_indexes = list(np.where(dataframe["category"] == 0)[0])
155
+ character_indexes = list(np.where(dataframe["category"] == 4)[0])
156
+ return tag_names, rating_indexes, general_indexes, character_indexes
157
+ def mcut_threshold(probs):
158
+ """
159
+ Maximum Cut Thresholding (MCut)
160
+ Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
161
+ for Multi-label Classification. In 11th International Symposium, IDA 2012
162
+ (pp. 172-183).
163
+ """
164
+ sorted_probs = probs[probs.argsort()[::-1]]
165
+ difs = sorted_probs[:-1] - sorted_probs[1:]
166
+ t = difs.argmax()
167
+ thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
168
+ return thresh
169
+ class Timer:
170
+ def __init__(self):
171
+ self.start_time = time.perf_counter() # Record the start time
172
+ self.checkpoints = [("Start", self.start_time)] # Store checkpoints
173
+
174
+ def checkpoint(self, label="Checkpoint"):
175
+ """Record a checkpoint with a given label."""
176
+ now = time.perf_counter()
177
+ self.checkpoints.append((label, now))
178
+
179
+ def report(self, is_clear_checkpoints = True):
180
+ # Determine the max label width for alignment
181
+ max_label_length = max(len(label) for label, _ in self.checkpoints)
182
+
183
+ prev_time = self.checkpoints[0][1]
184
+ for label, curr_time in self.checkpoints[1:]:
185
+ elapsed = curr_time - prev_time
186
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
187
+ prev_time = curr_time
188
+
189
+ if is_clear_checkpoints:
190
+ self.checkpoints.clear()
191
+ self.checkpoint() # Store checkpoints
192
+
193
+ def report_all(self):
194
+ """Print all recorded checkpoints and total execution time with aligned formatting."""
195
+ print("\n> Execution Time Report:")
196
+
197
+ # Determine the max label width for alignment
198
+ max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
199
+
200
+ prev_time = self.start_time
201
+ for label, curr_time in self.checkpoints[1:]:
202
+ elapsed = curr_time - prev_time
203
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
204
+ prev_time = curr_time
205
+
206
+ total_time = self.checkpoints[-1][1] - self.start_time
207
+ print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
208
+
209
+ self.checkpoints.clear()
210
+
211
+ def restart(self):
212
+ self.start_time = time.perf_counter() # Record the start time
213
+ self.checkpoints = [("Start", self.start_time)] # Store checkpoints
214
+
215
+ class Llama3Reorganize:
216
+ def __init__(
217
+ self,
218
+ repoId: str,
219
+ device: str = None,
220
+ loadModel: bool = False,
221
+ ):
222
+ """Initializes the Llama model.
223
+
224
+ Args:
225
+ repoId: LLAMA model repo.
226
+ device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
227
+ ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
228
+ localFilesOnly: If True, avoid downloading the file and return the path to the
229
+ local cached file if it exists.
230
+ """
231
+ self.modelPath = self.download_model(repoId)
232
+
233
+ if device is None:
234
+ import torch
235
+ self.totalVram = 0
236
+ if torch.cuda.is_available():
237
+ try:
238
+ deviceId = torch.cuda.current_device()
239
+ self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory/(1024*1024*1024)
240
+ except Exception as e:
241
+ print(traceback.format_exc())
242
+ print("Error detect vram: " + str(e))
243
+ device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
244
+ else:
245
+ device = "cpu"
246
+
247
+ self.device = device
248
+ self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
249
+
250
+ if loadModel:
251
+ self.load_model()
252
+
253
+ def download_model(self, repoId):
254
+ import warnings
255
+ import requests
256
+ allowPatterns = [
257
+ "config.json",
258
+ "generation_config.json",
259
+ "model.bin",
260
+ "pytorch_model.bin",
261
+ "pytorch_model.bin.index.json",
262
+ "pytorch_model-*.bin",
263
+ "sentencepiece.bpe.model",
264
+ "tokenizer.json",
265
+ "tokenizer_config.json",
266
+ "shared_vocabulary.txt",
267
+ "shared_vocabulary.json",
268
+ "special_tokens_map.json",
269
+ "spiece.model",
270
+ "vocab.json",
271
+ "model.safetensors",
272
+ "model-*.safetensors",
273
+ "model.safetensors.index.json",
274
+ "quantize_config.json",
275
+ "tokenizer.model",
276
+ "vocabulary.json",
277
+ "preprocessor_config.json",
278
+ "added_tokens.json"
279
+ ]
280
+
281
+ kwargs = {"allow_patterns": allowPatterns,}
282
+
283
+ try:
284
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
285
+ except (
286
+ huggingface_hub.utils.HfHubHTTPError,
287
+ requests.exceptions.ConnectionError,
288
+ ) as exception:
289
+ warnings.warn(
290
+ "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
291
+ repoId,
292
+ exception,
293
+ )
294
+ warnings.warn(
295
+ "Trying to load the model directly from the local cache, if it exists."
296
+ )
297
+
298
+ kwargs["local_files_only"] = True
299
+ return huggingface_hub.snapshot_download(repoId, **kwargs)
300
+
301
+
302
+ def load_model(self):
303
+ import ctranslate2
304
+ import transformers
305
+ try:
306
+ print('\n\nLoading model: %s\n\n' % self.modelPath)
307
+ kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
308
+ kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
309
+ self.roleSystem = {"role": "system", "content": self.system_prompt}
310
+ self.Model = ctranslate2.Generator(**kwargsModel)
311
+
312
+ self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
313
+ self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
314
+
315
+ except Exception as e:
316
+ self.release_vram()
317
+ raise e
318
+
319
+
320
+ def release_vram(self):
321
+ try:
322
+ import torch
323
+ if torch.cuda.is_available():
324
+ if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
325
+ self.Model.unload_model()
326
+
327
+ if getattr(self, "Tokenizer", None) is not None:
328
+ del self.Tokenizer
329
+ if getattr(self, "Model", None) is not None:
330
+ del self.Model
331
+ import gc
332
+ gc.collect()
333
+ try:
334
+ torch.cuda.empty_cache()
335
+ except Exception as e:
336
+ print(traceback.format_exc())
337
+ print("\tcuda empty cache, error: " + str(e))
338
+ print("release vram end.")
339
+ except Exception as e:
340
+ print(traceback.format_exc())
341
+ print("Error release vram: " + str(e))
342
+
343
+ def reorganize(self, text: str, max_length: int = 400):
344
+ output = None
345
+ result = None
346
+ try:
347
+ input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
348
+ source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
349
+ output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
350
+ target = output[0]
351
+ result = self.Tokenizer.decode(target.sequences_ids[0])
352
+
353
+ if len(result) > 2:
354
+ if result[0] == "\"" and result[len(result) - 1] == "\"":
355
+ result = result[1:-1]
356
+ elif result[0] == "'" and result[len(result) - 1] == "'":
357
+ result = result[1:-1]
358
+ elif result[0] == "「" and result[len(result) - 1] == "」":
359
+ result = result[1:-1]
360
+ elif result[0] == "『" and result[len(result) - 1] == "』":
361
+ result = result[1:-1]
362
+ except Exception as e:
363
+ print(traceback.format_exc())
364
+ print("Error reorganize text: " + str(e))
365
+
366
+ return result
367
+
368
+
369
+ class Predictor:
370
+ def __init__(self):
371
+ self.model_target_size = None
372
+ self.last_loaded_repo = None
373
+ def download_model(self, model_repo):
374
+ csv_path = huggingface_hub.hf_hub_download(
375
+ model_repo,
376
+ LABEL_FILENAME,
377
+ )
378
+ model_path = huggingface_hub.hf_hub_download(
379
+ model_repo,
380
+ MODEL_FILENAME,
381
+ )
382
+ return csv_path, model_path
383
+ def load_model(self, model_repo):
384
+ if model_repo == self.last_loaded_repo:
385
+ return
386
+
387
+ csv_path, model_path = self.download_model(model_repo)
388
+
389
+ tags_df = pd.read_csv(csv_path)
390
+ sep_tags = load_labels(tags_df)
391
+
392
+ self.tag_names = sep_tags[0]
393
+ self.rating_indexes = sep_tags[1]
394
+ self.general_indexes = sep_tags[2]
395
+ self.character_indexes = sep_tags[3]
396
+
397
+ model = rt.InferenceSession(model_path)
398
+ _, height, width, _ = model.get_inputs()[0].shape
399
+ self.model_target_size = height
400
+
401
+ self.last_loaded_repo = model_repo
402
+ self.model = model
403
+ def prepare_image(self, path):
404
+ image = Image.open(path)
405
+ image = image.convert("RGBA")
406
+ target_size = self.model_target_size
407
+
408
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
409
+ canvas.alpha_composite(image)
410
+ image = canvas.convert("RGB")
411
+
412
+ # Pad image to square
413
+ image_shape = image.size
414
+ max_dim = max(image_shape)
415
+ pad_left = (max_dim - image_shape[0]) // 2
416
+ pad_top = (max_dim - image_shape[1]) // 2
417
+
418
+ padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
419
+ padded_image.paste(image, (pad_left, pad_top))
420
+
421
+ # Resize
422
+ if max_dim != target_size:
423
+ padded_image = padded_image.resize(
424
+ (target_size, target_size),
425
+ Image.BICUBIC,
426
+ )
427
+ # Convert to numpy array
428
+ image_array = np.asarray(padded_image, dtype=np.float32)
429
+
430
+ # Convert PIL-native RGB to BGR
431
+ image_array = image_array[:, :, ::-1]
432
+
433
+ return np.expand_dims(image_array, axis=0)
434
+
435
+ def create_file(self, text: str, directory: str, fileName: str) -> str:
436
+ # Write the text to a file
437
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
438
+ file.write(text)
439
+
440
+ return file.name
441
+
442
+ def predict(
443
+ self,
444
+ gallery,
445
+ model_repo,
446
+ general_thresh,
447
+ general_mcut_enabled,
448
+ character_thresh,
449
+ character_mcut_enabled,
450
+ characters_merge_enabled,
451
+ llama3_reorganize_model_repo,
452
+ additional_tags_prepend,
453
+ additional_tags_append,
454
+ tag_results,
455
+ progress=gr.Progress()
456
+ ):
457
+ gallery_len = len(gallery)
458
+ print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
459
+
460
+ timer = Timer() # Create a timer
461
+ progressRatio = 0.5 if llama3_reorganize_model_repo else 1
462
+ progressTotal = gallery_len + 1
463
+ current_progress = 0
464
+
465
+ self.load_model(model_repo)
466
+ current_progress += progressRatio/progressTotal;
467
+ progress(current_progress, desc="Initialize wd model finished")
468
+ timer.checkpoint(f"Initialize wd model")
469
+
470
+ # Result
471
+ txt_infos = []
472
+ output_dir = tempfile.mkdtemp()
473
+ if not os.path.exists(output_dir):
474
+ os.makedirs(output_dir)
475
+
476
+ sorted_general_strings = ""
477
+ rating = None
478
+ character_res = None
479
+ general_res = None
480
+
481
+ if llama3_reorganize_model_repo:
482
+ print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
483
+ llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
484
+ current_progress += progressRatio/progressTotal;
485
+ progress(current_progress, desc="Initialize llama3 model finished")
486
+ timer.checkpoint(f"Initialize llama3 model")
487
+
488
+ timer.report()
489
+
490
+ prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
491
+ append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
492
+ if prepend_list and append_list:
493
+ append_list = [item for item in append_list if item not in prepend_list]
494
+
495
+ # Dictionary to track counters for each filename
496
+ name_counters = defaultdict(int)
497
+ # New code to create categorized output string
498
+ categorized_output_strings = []
499
+ for idx, value in enumerate(gallery):
500
+ try:
501
+ image_path = value[0]
502
+ image_name = os.path.splitext(os.path.basename(image_path))[0]
503
+
504
+ # Increment the counter for the current name
505
+ name_counters[image_name] += 1
506
+
507
+ if name_counters[image_name] > 1:
508
+ image_name = f"{image_name}_{name_counters[image_name]:02d}"
509
+
510
+ image = self.prepare_image(image_path)
511
+
512
+ input_name = self.model.get_inputs()[0].name
513
+ label_name = self.model.get_outputs()[0].name
514
+ print(f"Gallery {idx:02d}: Starting run wd model...")
515
+ preds = self.model.run([label_name], {input_name: image})[0]
516
+
517
+ labels = list(zip(self.tag_names, preds[0].astype(float)))
518
+
519
+ # First 4 labels are actually ratings: pick one with argmax
520
+ ratings_names = [labels[i] for i in self.rating_indexes]
521
+ rating = dict(ratings_names)
522
+
523
+ # Then we have general tags: pick any where prediction confidence > threshold
524
+ general_names = [labels[i] for i in self.general_indexes]
525
+
526
+ if general_mcut_enabled:
527
+ general_probs = np.array([x[1] for x in general_names])
528
+ general_thresh = mcut_threshold(general_probs)
529
+
530
+ general_res = [x for x in general_names if x[1] > general_thresh]
531
+ general_res = dict(general_res)
532
+
533
+ # Everything else is characters: pick any where prediction confidence > threshold
534
+ character_names = [labels[i] for i in self.character_indexes]
535
+
536
+ if character_mcut_enabled:
537
+ character_probs = np.array([x[1] for x in character_names])
538
+ character_thresh = mcut_threshold(character_probs)
539
+ character_thresh = max(0.15, character_thresh)
540
+
541
+ character_res = [x for x in character_names if x[1] > character_thresh]
542
+ character_res = dict(character_res)
543
+ character_list = list(character_res.keys())
544
+
545
+ sorted_general_list = sorted(
546
+ general_res.items(),
547
+ key=lambda x: x[1],
548
+ reverse=True,
549
+ )
550
+ sorted_general_list = [x[0] for x in sorted_general_list]
551
+ #Remove values from character_list that already exist in sorted_general_list
552
+ character_list = [item for item in character_list if item not in sorted_general_list]
553
+ #Remove values from sorted_general_list that already exist in prepend_list or append_list
554
+ if prepend_list:
555
+ sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
556
+ if append_list:
557
+ sorted_general_list = [item for item in sorted_general_list if item not in append_list]
558
+
559
+ sorted_general_list = prepend_list + sorted_general_list + append_list
560
+
561
+ sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
562
+
563
+ classified_tags, unclassified_tags = classify_tags(sorted_general_list)
564
+
565
+ # Create a single string of all categorized tags
566
+ categorized_output_string = ', '.join([', '.join(tags) for tags in classified_tags.values()])
567
+ categorized_output_strings.append(categorized_output_string)
568
+
569
+ current_progress += progressRatio/progressTotal;
570
+ progress(current_progress, desc=f"image{idx:02d}, predict finished")
571
+ timer.checkpoint(f"image{idx:02d}, predict finished")
572
+
573
+ if llama3_reorganize_model_repo:
574
+ print(f"Starting reorganize with llama3...")
575
+ reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
576
+ reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
577
+ reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
578
+ reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
579
+ sorted_general_strings += "," + reorganize_strings
580
+
581
+ current_progress += progressRatio/progressTotal;
582
+ progress(current_progress, desc=f"image{idx:02d}, llama3 reorganize finished")
583
+ timer.checkpoint(f"image{idx:02d}, llama3 reorganize finished")
584
+
585
+ txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
586
+ txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
587
+
588
+ tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
589
+ timer.report()
590
+ except Exception as e:
591
+ print(traceback.format_exc())
592
+ print("Error predict: " + str(e))
593
+ # Result
594
+ download = []
595
+ if txt_infos is not None and len(txt_infos) > 0:
596
+ downloadZipPath = os.path.join(output_dir, "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
597
+ with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
598
+ for info in txt_infos:
599
+ # Get file name from lookup
600
+ taggers_zip.write(info["path"], arcname=info["name"])
601
+ download.append(downloadZipPath)
602
+
603
+ if llama3_reorganize_model_repo:
604
+ llama3_reorganize.release_vram()
605
+ del llama3_reorganize
606
+
607
+ progress(1, desc=f"Predict completed")
608
+ timer.report_all() # Print all recorded times
609
+ print("Predict is complete.")
610
+
611
+ # Collect all categorized output strings into a single string
612
+ final_categorized_output = ', '.join(categorized_output_strings)
613
+
614
+ return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results, final_categorized_output
615
+ # END
616
+
617
+ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
618
+ if not selected_state:
619
+ return selected_state
620
+
621
+ tag_result = { "strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}" }
622
+ if selected_state.value["image"]["path"] in tag_results:
623
+ tag_result = tag_results[selected_state.value["image"]["path"]]
624
+
625
+ return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
626
+
627
+ def append_gallery(gallery: list, image: str):
628
+ if gallery is None:
629
+ gallery = []
630
+ if not image:
631
+ return gallery, None
632
+
633
+ gallery.append(image)
634
+
635
+ return gallery, None
636
+
637
+
638
+ def extend_gallery(gallery: list, images):
639
+ if gallery is None:
640
+ gallery = []
641
+ if not images:
642
+ return gallery
643
+
644
+ # Combine the new images with the existing gallery images
645
+ gallery.extend(images)
646
+
647
+ return gallery
648
+
649
+ def remove_image_from_gallery(gallery: list, selected_image: str):
650
+ if not gallery or not selected_image:
651
+ return gallery
652
+
653
+ selected_image = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
654
+ # Remove the selected image from the gallery
655
+ if selected_image in gallery:
656
+ gallery.remove(selected_image)
657
+ return gallery
658
+
659
+ # END
660
+
661
+ def fig_to_pil(fig):
662
+ buf = io.BytesIO()
663
+ fig.savefig(buf, format='png')
664
+ buf.seek(0)
665
+ return Image.open(buf)
666
+
667
+ @spaces.GPU
668
+ def run_example(task_prompt, image, text_input=None):
669
+ if text_input is None:
670
+ prompt = task_prompt
671
+ else:
672
+ prompt = task_prompt + text_input
673
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
674
+ generated_ids = model.generate(
675
+ input_ids=inputs["input_ids"],
676
+ pixel_values=inputs["pixel_values"],
677
+ max_new_tokens=1024,
678
+ early_stopping=False,
679
+ do_sample=False,
680
+ num_beams=3,
681
+ )
682
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
683
+ parsed_answer = processor.post_process_generation(
684
+ generated_text,
685
+ task=task_prompt,
686
+ image_size=(image.width, image.height)
687
+ )
688
+ return parsed_answer
689
+
690
+ def plot_bbox(image, data):
691
+ fig, ax = plt.subplots()
692
+ ax.imshow(image)
693
+ for bbox, label in zip(data['bboxes'], data['labels']):
694
+ x1, y1, x2, y2 = bbox
695
+ rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
696
+ ax.add_patch(rect)
697
+ plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
698
+ ax.axis('off')
699
+ return fig
700
+
701
+ def draw_polygons(image, prediction, fill_mask=False):
702
+ draw = ImageDraw.Draw(image)
703
+ scale = 1
704
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
705
+ color = random.choice(colormap)
706
+ fill_color = random.choice(colormap) if fill_mask else None
707
+ for _polygon in polygons:
708
+ _polygon = np.array(_polygon).reshape(-1, 2)
709
+ if len(_polygon) < 3:
710
+ print('Invalid polygon:', _polygon)
711
+ continue
712
+ _polygon = (_polygon * scale).reshape(-1).tolist()
713
+ if fill_mask:
714
+ draw.polygon(_polygon, outline=color, fill=fill_color)
715
+ else:
716
+ draw.polygon(_polygon, outline=color)
717
+ draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
718
+ return image
719
+
720
+ def convert_to_od_format(data):
721
+ bboxes = data.get('bboxes', [])
722
+ labels = data.get('bboxes_labels', [])
723
+ od_results = {
724
+ 'bboxes': bboxes,
725
+ 'labels': labels
726
+ }
727
+ return od_results
728
+
729
+ def draw_ocr_bboxes(image, prediction):
730
+ scale = 1
731
+ draw = ImageDraw.Draw(image)
732
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
733
+ for box, label in zip(bboxes, labels):
734
+ color = random.choice(colormap)
735
+ new_box = (np.array(box) * scale).tolist()
736
+ draw.polygon(new_box, width=3, outline=color)
737
+ draw.text((new_box[0]+8, new_box[1]+2),
738
+ "{}".format(label),
739
+ align="right",
740
+ fill=color)
741
+ return image
742
+
743
+ def convert_to_od_format(data):
744
+ bboxes = data.get('bboxes', [])
745
+ labels = data.get('bboxes_labels', [])
746
+ od_results = {
747
+ 'bboxes': bboxes,
748
+ 'labels': labels
749
+ }
750
+ return od_results
751
+
752
+ def draw_ocr_bboxes(image, prediction):
753
+ scale = 1
754
+ draw = ImageDraw.Draw(image)
755
+ bboxes, labels = prediction['quad_boxes'], prediction['labels']
756
+ for box, label in zip(bboxes, labels):
757
+ color = random.choice(colormap)
758
+ new_box = (np.array(box) * scale).tolist()
759
+ draw.polygon(new_box, width=3, outline=color)
760
+ draw.text((new_box[0]+8, new_box[1]+2),
761
+ "{}".format(label),
762
+ align="right",
763
+ fill=color)
764
+ return image
765
+ def process_image(image, task_prompt, text_input=None):
766
+ # Test
767
+ if isinstance(image, str): # If image is a file path
768
+ image = Image.open(image) # Load image from file path
769
+ else: # If image is a NumPy array
770
+ image = Image.fromarray(image) # Convert NumPy array to PIL Image
771
+ if task_prompt == 'Caption':
772
+ task_prompt = '<CAPTION>'
773
+ results = run_example(task_prompt, image)
774
+ return results[task_prompt], None
775
+ elif task_prompt == 'Detailed Caption':
776
+ task_prompt = '<DETAILED_CAPTION>'
777
+ results = run_example(task_prompt, image)
778
+ return results[task_prompt], None
779
+ elif task_prompt == 'More Detailed Caption':
780
+ task_prompt = '<MORE_DETAILED_CAPTION>'
781
+ results = run_example(task_prompt, image)
782
+ return results[task_prompt], plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
783
+ elif task_prompt == 'Caption + Grounding':
784
+ task_prompt = '<CAPTION>'
785
+ results = run_example(task_prompt, image)
786
+ text_input = results[task_prompt]
787
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
788
+ results = run_example(task_prompt, image, text_input)
789
+ results['<CAPTION>'] = text_input
790
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
791
+ return results, fig_to_pil(fig)
792
+ elif task_prompt == 'Detailed Caption + Grounding':
793
+ task_prompt = '<DETAILED_CAPTION>'
794
+ results = run_example(task_prompt, image)
795
+ text_input = results[task_prompt]
796
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
797
+ results = run_example(task_prompt, image, text_input)
798
+ results['<DETAILED_CAPTION>'] = text_input
799
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
800
+ return results, fig_to_pil(fig)
801
+ elif task_prompt == 'More Detailed Caption + Grounding':
802
+ task_prompt = '<MORE_DETAILED_CAPTION>'
803
+ results = run_example(task_prompt, image)
804
+ text_input = results[task_prompt]
805
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
806
+ results = run_example(task_prompt, image, text_input)
807
+ results['<MORE_DETAILED_CAPTION>'] = text_input
808
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
809
+ return results, fig_to_pil(fig)
810
+ elif task_prompt == 'Object Detection':
811
+ task_prompt = '<OD>'
812
+ results = run_example(task_prompt, image)
813
+ fig = plot_bbox(image, results['<OD>'])
814
+ return results, fig_to_pil(fig)
815
+ elif task_prompt == 'Dense Region Caption':
816
+ task_prompt = '<DENSE_REGION_CAPTION>'
817
+ results = run_example(task_prompt, image)
818
+ fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
819
+ return results, fig_to_pil(fig)
820
+ elif task_prompt == 'Region Proposal':
821
+ task_prompt = '<REGION_PROPOSAL>'
822
+ results = run_example(task_prompt, image)
823
+ fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
824
+ return results, fig_to_pil(fig)
825
+ elif task_prompt == 'Caption to Phrase Grounding':
826
+ task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
827
+ results = run_example(task_prompt, image, text_input)
828
+ fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
829
+ return results, fig_to_pil(fig)
830
+ elif task_prompt == 'Referring Expression Segmentation':
831
+ task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
832
+ results = run_example(task_prompt, image, text_input)
833
+ output_image = copy.deepcopy(image)
834
+ output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
835
+ return results, output_image
836
+ elif task_prompt == 'Region to Segmentation':
837
+ task_prompt = '<REGION_TO_SEGMENTATION>'
838
+ results = run_example(task_prompt, image, text_input)
839
+ output_image = copy.deepcopy(image)
840
+ output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
841
+ return results, output_image
842
+ elif task_prompt == 'Open Vocabulary Detection':
843
+ task_prompt = '<OPEN_VOCABULARY_DETECTION>'
844
+ results = run_example(task_prompt, image, text_input)
845
+ bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
846
+ fig = plot_bbox(image, bbox_results)
847
+ return results, fig_to_pil(fig)
848
+ elif task_prompt == 'Region to Category':
849
+ task_prompt = '<REGION_TO_CATEGORY>'
850
+ results = run_example(task_prompt, image, text_input)
851
+ return results, None
852
+ elif task_prompt == 'Region to Description':
853
+ task_prompt = '<REGION_TO_DESCRIPTION>'
854
+ results = run_example(task_prompt, image, text_input)
855
+ return results, None
856
+ elif task_prompt == 'OCR':
857
+ task_prompt = '<OCR>'
858
+ results = run_example(task_prompt, image)
859
+ return results, None
860
+ elif task_prompt == 'OCR with Region':
861
+ task_prompt = '<OCR_WITH_REGION>'
862
+ results = run_example(task_prompt, image)
863
+ output_image = copy.deepcopy(image)
864
+ output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
865
+ return results, output_image
866
+ else:
867
+ return "", None # Return empty string and None for unknown task prompts
868
+ ##############
869
+ # Custom CSS to set the height of the gr.Dropdown menu
870
+ css = """
871
+ div.progress-level div.progress-level-inner {
872
+ text-align: left !important;
873
+ width: 55.5% !important;
874
+ #output {
875
+ height: 500px;
876
+ overflow: auto;
877
+ border: 1px solid #ccc;
878
+ }
879
+ """
880
+ single_task_list =[
881
+ 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
882
+ 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
883
+ 'Referring Expression Segmentation', 'Region to Segmentation',
884
+ 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
885
+ 'OCR', 'OCR with Region'
886
+ ]
887
+ cascaded_task_list =[
888
+ 'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
889
+ ]
890
+ def update_task_dropdown(choice):
891
+ if choice == 'Cascaded task':
892
+ return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
893
+ else:
894
+ return gr.Dropdown(choices=single_task_list, value='Caption')
895
+
896
+ args = parse_args()
897
+
898
+ predictor = Predictor()
899
+
900
+ dropdown_list = [
901
+ EVA02_LARGE_MODEL_DSV3_REPO,
902
+ SWINV2_MODEL_DSV3_REPO,
903
+ CONV_MODEL_DSV3_REPO,
904
+ VIT_MODEL_DSV3_REPO,
905
+ VIT_LARGE_MODEL_DSV3_REPO,
906
+ # ---
907
+ MOAT_MODEL_DSV2_REPO,
908
+ SWIN_MODEL_DSV2_REPO,
909
+ CONV_MODEL_DSV2_REPO,
910
+ CONV2_MODEL_DSV2_REPO,
911
+ VIT_MODEL_DSV2_REPO,
912
+ # ---
913
+ SWINV2_MODEL_IS_DSV1_REPO,
914
+ EVA02_LARGE_MODEL_IS_DSV1_REPO,
915
+ ]
916
+ llama_list = [
917
+ META_LLAMA_3_3B_REPO,
918
+ META_LLAMA_3_8B_REPO,
919
+ ]
920
+
921
+ # This is workaround will make the space restart every 2 days. (for test).
922
+ def _restart_space():
923
+ HF_TOKEN = os.getenv("HF_TOKEN")
924
+ if not HF_TOKEN:
925
+ raise ValueError("HF_TOKEN environment variable is not set.")
926
+ huggingface_hub.HfApi().restart_space(repo_id="Werli/Multi-Tagger", token=HF_TOKEN, factory_reboot=False)
927
+ scheduler = BackgroundScheduler()
928
+ # Add a job to restart the space every 2 days (172800 seconds)
929
+ restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
930
+ # Start the scheduler
931
+ scheduler.start()
932
+ next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
933
+ NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
934
+
935
+ # Using "JohnSmith9982/small_and_pretty" theme
936
+ with gr.Blocks(title=TITLE, css=css, theme="Werli/wd-tagger-images", fill_width=True) as demo:
937
+ gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
938
+ gr.Markdown(value=DESCRIPTION)
939
+ gr.Markdown(NEXT_RESTART)
940
+ with gr.Tab(label="Waifu Diffusion"):
941
+ with gr.Row():
942
+ with gr.Column():
943
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
944
+ with gr.Column(variant="panel"):
945
+ # Create an Image component for uploading images
946
+ image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
947
+ with gr.Row():
948
+ upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
949
+ remove_button = gr.Button("Remove Selected Image", size="sm")
950
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
951
+
952
+ model_repo = gr.Dropdown(
953
+ dropdown_list,
954
+ value=EVA02_LARGE_MODEL_DSV3_REPO,
955
+ label="Model",
956
+ )
957
+ with gr.Row():
958
+ general_thresh = gr.Slider(
959
+ 0,
960
+ 1,
961
+ step=args.score_slider_step,
962
+ value=args.score_general_threshold,
963
+ label="General Tags Threshold",
964
+ scale=3,
965
+ )
966
+ general_mcut_enabled = gr.Checkbox(
967
+ value=False,
968
+ label="Use MCut threshold",
969
+ scale=1,
970
+ )
971
+ with gr.Row():
972
+ character_thresh = gr.Slider(
973
+ 0,
974
+ 1,
975
+ step=args.score_slider_step,
976
+ value=args.score_character_threshold,
977
+ label="Character Tags Threshold",
978
+ scale=3,
979
+ )
980
+ character_mcut_enabled = gr.Checkbox(
981
+ value=False,
982
+ label="Use MCut threshold",
983
+ scale=1,
984
+ )
985
+ with gr.Row():
986
+ characters_merge_enabled = gr.Checkbox(
987
+ value=True,
988
+ label="Merge characters into the string output",
989
+ scale=1,
990
+ )
991
+ with gr.Row():
992
+ llama3_reorganize_model_repo = gr.Dropdown(
993
+ [None] + llama_list,
994
+ value=None,
995
+ label="Llama3 Model",
996
+ info="Use the Llama3 model to reorganize the article (Note: very slow)",
997
+ )
998
+ with gr.Row():
999
+ additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
1000
+ additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
1001
+ with gr.Row():
1002
+ clear = gr.ClearButton(
1003
+ components=[
1004
+ gallery,
1005
+ model_repo,
1006
+ general_thresh,
1007
+ general_mcut_enabled,
1008
+ character_thresh,
1009
+ character_mcut_enabled,
1010
+ characters_merge_enabled,
1011
+ llama3_reorganize_model_repo,
1012
+ additional_tags_prepend,
1013
+ additional_tags_append,
1014
+ ],
1015
+ variant="secondary",
1016
+ size="lg",
1017
+ )
1018
+ with gr.Column(variant="panel"):
1019
+ download_file = gr.File(label="Output (Download)")
1020
+ sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True)
1021
+ categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True)
1022
+ categorized = gr.JSON(label="Categorized (tags)")
1023
+ rating = gr.Label(label="Rating")
1024
+ character_res = gr.Label(label="Output (characters)")
1025
+ general_res = gr.Label(label="Output (tags)")
1026
+ unclassified = gr.JSON(label="Unclassified (tags)")
1027
+ clear.add(
1028
+ [
1029
+ download_file,
1030
+ sorted_general_strings,
1031
+ categorized,
1032
+ rating,
1033
+ character_res,
1034
+ general_res,
1035
+ unclassified,
1036
+ ]
1037
+ )
1038
+ tag_results = gr.State({})
1039
+ # Define the event listener to add the uploaded image to the gallery
1040
+ image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
1041
+ # When the upload button is clicked, add the new images to the gallery
1042
+ upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
1043
+ # Event to update the selected image when an image is clicked in the gallery
1044
+ selected_image = gr.Textbox(label="Selected Image", visible=False)
1045
+ gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
1046
+ # Event to remove a selected image from the gallery
1047
+ remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
1048
+ submit.click(
1049
+ predictor.predict,
1050
+ inputs=[
1051
+ gallery,
1052
+ model_repo,
1053
+ general_thresh,
1054
+ general_mcut_enabled,
1055
+ character_thresh,
1056
+ character_mcut_enabled,
1057
+ characters_merge_enabled,
1058
+ llama3_reorganize_model_repo,
1059
+ additional_tags_prepend,
1060
+ additional_tags_append,
1061
+ tag_results,
1062
+ ],
1063
+ outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results, categorized_output,],
1064
+ )
1065
+ gr.Examples(
1066
+ [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
1067
+ inputs=[
1068
+ image_input,
1069
+ model_repo,
1070
+ general_thresh,
1071
+ general_mcut_enabled,
1072
+ character_thresh,
1073
+ character_mcut_enabled,
1074
+ ],
1075
+ )
1076
+ with gr.Tab(label="Florence 2 Image Captioning"):
1077
+ with gr.Row():
1078
+ with gr.Column(variant="panel"):
1079
+ input_img = gr.Image(label="Input Picture")
1080
+ task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
1081
+ task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
1082
+ task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
1083
+ text_input = gr.Textbox(label="Text Input (optional)")
1084
+ submit_btn = gr.Button(value="Submit")
1085
+ with gr.Column(variant="panel"):
1086
+ #OUTPUT
1087
+ output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8) # Here is the problem!
1088
+ output_img = gr.Image(label="Output Image")
1089
+ gr.Examples(
1090
+ examples=[
1091
+ ["images/image1.png", 'Object Detection'],
1092
+ ["images/image2.png", 'OCR with Region']
1093
+ ],
1094
+ inputs=[input_img, task_prompt],
1095
+ outputs=[output_text, output_img],
1096
+ fn=process_image,
1097
+ cache_examples=False,
1098
+ label='Try examples'
1099
+ )
1100
+ submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
1101
+
1102
+ demo.queue(max_size=2)
1103
  demo.launch(debug=True) # test