tori29umai commited on
Commit
27419c1
1 Parent(s): 5a40f49
Files changed (1) hide show
  1. app.py +34 -7
app.py CHANGED
@@ -7,7 +7,8 @@ import os
7
  import time
8
 
9
  from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
10
- from utils.prompt_analysis import PromptAnalysis
 
11
 
12
 
13
  def load_model(lora_dir, cn_dir):
@@ -28,14 +29,14 @@ def load_model(lora_dir, cn_dir):
28
  return pipe
29
 
30
 
31
-
32
-
33
-
34
  class Img2Img:
35
  def __init__(self):
36
  self.setup_paths()
37
  self.setup_models()
38
  self.demo = self.layout()
 
 
 
39
 
40
  def setup_paths(self):
41
  self.path = os.getcwd()
@@ -52,6 +53,33 @@ class Img2Img:
52
  load_tagger_model(self.tagger_dir)
53
  load_lora_model(self.lora_dir)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def layout(self):
56
  css = """
57
  #intro{
@@ -64,8 +92,7 @@ class Img2Img:
64
  with gr.Row():
65
  with gr.Column():
66
  self.input_image_path = gr.Image(label="input_image", type='filepath')
67
- self.prompt_analysis = PromptAnalysis(self.tagger_dir)
68
- self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
69
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
70
  generate_button = gr.Button("生成")
71
  with gr.Column():
@@ -111,4 +138,4 @@ class Img2Img:
111
  return output_image
112
 
113
  img2img = Img2Img()
114
- img2img.demo.launch(share=True, server_name="none", server_port=7890)
 
7
  import time
8
 
9
  from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
10
+ from utils.prompt_utils import remove_color
11
+ from utils.tagger import modelLoad, analysis
12
 
13
 
14
  def load_model(lora_dir, cn_dir):
 
29
  return pipe
30
 
31
 
 
 
 
32
  class Img2Img:
33
  def __init__(self):
34
  self.setup_paths()
35
  self.setup_models()
36
  self.demo = self.layout()
37
+ self.default_nagative_prompt = "lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry"
38
+ self.post_filter = True
39
+ self.tagger_model = None
40
 
41
  def setup_paths(self):
42
  self.path = os.getcwd()
 
53
  load_tagger_model(self.tagger_dir)
54
  load_lora_model(self.lora_dir)
55
 
56
+
57
+ def prompt_layout(self, input_image_path):
58
+ with gr.Column():
59
+ with gr.Row():
60
+ self.prompt = gr.Textbox(label="prompt", lines=3)
61
+ with gr.Row():
62
+ self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value=self.default_nagative_prompt)
63
+ with gr.Row():
64
+ self.prompt_analysis_button = gr.Button("prompt解析")
65
+
66
+ self.prompt_analysis_button.click(
67
+ self.process_prompt_analysis,
68
+ inputs=[input_image_path],
69
+ outputs=self.prompt
70
+ )
71
+ return [self.prompt, self.negative_prompt]
72
+
73
+ def process_prompt_analysis(self, input_image_path):
74
+ if self.tagger_model is None:
75
+ self.tagger_model = modelLoad(self.tagger_dir)
76
+ tags = analysis(input_image_path, self.tagger_dir, self.tagger_model)
77
+ tags_list = tags
78
+ if self.post_filter:
79
+ tags_list = remove_color(tags)
80
+ return tags_list
81
+
82
+
83
  def layout(self):
84
  css = """
85
  #intro{
 
92
  with gr.Row():
93
  with gr.Column():
94
  self.input_image_path = gr.Image(label="input_image", type='filepath')
95
+ self.prompt, self.negative_prompt = self.prompt_layout(self.input_image_path)
 
96
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
97
  generate_button = gr.Button("生成")
98
  with gr.Column():
 
138
  return output_image
139
 
140
  img2img = Img2Img()
141
+ img2img.demo.launch(share=True)