jerukperas commited on
Commit
10d19d1
·
1 Parent(s): 13e5846

update app.py

Browse files
Files changed (2) hide show
  1. app.py +35 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,12 +2,14 @@ import torch
2
  import gradio as gr
3
  from optimum.onnxruntime import ORTModelForCausalLM
4
  from transformers import AutoTokenizer
 
5
 
6
  # https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
7
  model_id = "p1atdev/dart-v2-sft"
8
  model = ORTModelForCausalLM.from_pretrained(model_id)
9
  tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)
 
11
 
12
 
13
  # https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
@@ -20,14 +22,14 @@ def get_tokens_as_list(word_list):
20
  return tokens_list
21
 
22
 
23
- def generate_tags(general_tags: str):
24
  # https://huggingface.co/p1atdev/dart-v2-sft#prompt-format
25
  general_tags = ",".join(tag.strip() for tag in general_tags.split(",") if tag)
26
  prompt = (
27
  "<|bos|>"
28
  # "<copyright></copyright>"
29
  # "<character></character>"
30
- "<|rating:general|><|aspect_ratio:tall|><|length:long|>"
31
  f"<general>{general_tags}<|identity:none|><|input_end|>"
32
  )
33
 
@@ -46,17 +48,44 @@ def generate_tags(general_tags: str):
46
  # bad_words_ids=bad_words_ids,
47
  )
48
 
49
- return ", ".join(
50
  [tag for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True) if tag.strip() != ""]
51
  )
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  demo = gr.Interface(
55
  fn=generate_tags,
56
- inputs=gr.TextArea("1girl, black hair", lines=4),
57
- outputs=gr.Textbox(show_copy_button=True),
 
 
 
 
 
 
 
 
 
 
58
  clear_btn=None,
59
  analytics_enabled=False,
 
60
  )
61
 
62
- demo.launch()
 
2
  import gradio as gr
3
  from optimum.onnxruntime import ORTModelForCausalLM
4
  from transformers import AutoTokenizer
5
+ from huggingface_hub import InferenceClient
6
 
7
  # https://huggingface.co/collections/p1atdev/dart-v2-danbooru-tags-transformer-v2-66291115701b6fe773399b0a
8
  model_id = "p1atdev/dart-v2-sft"
9
  model = ORTModelForCausalLM.from_pretrained(model_id)
10
  tokenizer = AutoTokenizer.from_pretrained(model_id)
11
  tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)
12
+ txt2imgclient = InferenceClient()
13
 
14
 
15
  # https://huggingface.co/docs/transformers/v4.44.2/en/internal/generation_utils#transformers.NoBadWordsLogitsProcessor
 
22
  return tokens_list
23
 
24
 
25
+ def generate_tags(general_tags: str, generate_image: bool = False):
26
  # https://huggingface.co/p1atdev/dart-v2-sft#prompt-format
27
  general_tags = ",".join(tag.strip() for tag in general_tags.split(",") if tag)
28
  prompt = (
29
  "<|bos|>"
30
  # "<copyright></copyright>"
31
  # "<character></character>"
32
+ "<|rating:general|><|aspect_ratio:tall|><|length:medium|>"
33
  f"<general>{general_tags}<|identity:none|><|input_end|>"
34
  )
35
 
 
48
  # bad_words_ids=bad_words_ids,
49
  )
50
 
51
+ output_tags = ", ".join(
52
  [tag for tag in tokenizer.batch_decode(outputs[0], skip_special_tokens=True) if tag.strip() != ""]
53
  )
54
 
55
+ yield (output_tags, None)
56
+
57
+ if generate_image:
58
+ txt2img_prompt = f"score_9, score_8_up, score_7_up, {output_tags}"
59
+ img = txt2imgclient.text_to_image(
60
+ prompt=txt2img_prompt,
61
+ negative_prompt="score_6, score_5, score_4, rating_explicit, child, loli, shota",
62
+ num_inference_steps=25,
63
+ height=1152,
64
+ width=896,
65
+ model="John6666/wai-real-mix-v8-sdxl",
66
+ scheduler="EulerAncestralDiscreteScheduler",
67
+ )
68
+
69
+ yield (output_tags, img)
70
+
71
 
72
  demo = gr.Interface(
73
  fn=generate_tags,
74
+ inputs=[
75
+ gr.TextArea("1girl, black hair", lines=4),
76
+ gr.Checkbox(
77
+ False,
78
+ label="Generate Image",
79
+ info="Generating image using InferenceClient (really slow) with output_tags as prompt",
80
+ ),
81
+ ],
82
+ outputs=[
83
+ gr.Textbox(label="output_tags", show_copy_button=True),
84
+ gr.Image(label="generated_image", format="jpeg", type="pil"),
85
+ ],
86
  clear_btn=None,
87
  analytics_enabled=False,
88
+ concurrency_limit=64,
89
  )
90
 
91
+ demo.queue().launch()
requirements.txt CHANGED
@@ -4,4 +4,5 @@
4
  gradio==4.42.0
5
  torch
6
  transformers
7
- optimum[onnxruntime] # or optimum[onnxruntime-gpu]
 
 
4
  gradio==4.42.0
5
  torch
6
  transformers
7
+ optimum[onnxruntime] # or optimum[onnxruntime-gpu]
8
+ Pillow