tori29umai committed on
Commit
b2790c5
1 Parent(s): 4f189ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -68
app.py CHANGED
@@ -23,90 +23,78 @@ dl_cn_config(cn_dir)
23
  dl_tagger_model(tagger_dir)
24
  dl_lora_model(lora_dir)
25
 
26
- @spaces.GPU(duration=120)
27
- def predict(lora_model, input_image_path, prompt, negative_prompt, controlnet_scale, load_model_fn):
28
- # LoRAモデルに基づきpipeを取得
29
- pipe = load_model_fn(lora_model)
30
- input_image = Image.open(input_image_path)
31
- base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
32
- resize_image = resize_image_aspect_ratio(input_image)
33
- resize_base_image = resize_image_aspect_ratio(base_image)
34
- generator = torch.manual_seed(0)
35
- last_time = time.time()
36
-
37
- # プロンプト生成
38
- prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
39
- execute_tags = ["realistic", "nose", "asian"]
40
- prompt = execute_prompt(execute_tags, prompt)
41
- prompt = remove_duplicates(prompt)
42
- prompt = remove_color(prompt)
43
- print(prompt)
44
-
45
- # 画像生成
46
- output_image = pipe(
47
- image=resize_base_image,
48
- control_image=resize_image,
49
- strength=1.0,
50
- prompt=prompt,
51
- negative_prompt=negative_prompt,
52
- controlnet_conditioning_scale=float(controlnet_scale),
53
- generator=generator,
54
- num_inference_steps=30,
55
- eta=1.0,
56
- ).images[0]
57
- print(f"Time taken: {time.time() - last_time}")
58
- output_image = output_image.resize(input_image.size, Image.LANCZOS)
59
- return output_image
60
-
61
-
62
  class Img2Img:
63
  def __init__(self):
64
  self.demo = self.layout()
65
  self.tagger_model = None
66
  self.input_image_path = None
67
  self.bg_removed_image = None
68
- self.pipe = None
69
- self.current_lora_model = None
70
 
71
- def process_prompt_analysis(self, input_image_path):
72
- if self.tagger_model is None:
73
- self.tagger_model = modelLoad(tagger_dir)
74
- tags = analysis(input_image_path, tagger_dir, self.tagger_model)
75
- prompt = remove_color(tags)
76
- execute_tags = ["realistic", "nose", "asian"]
77
- prompt = execute_prompt(execute_tags, prompt)
78
- prompt = remove_duplicates(prompt)
79
- return prompt
80
-
81
-
82
  def load_model(self, lora_model):
83
- # 既に正しいpipeがロードされている場合は再利用
84
- if self.pipe and self.current_lora_model == lora_model:
85
- return self.pipe # キャッシュされたpipeを返す
86
-
87
- # 新しいpipeの生成
88
  dtype = torch.float16
89
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
90
  controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
91
 
92
- self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
93
- "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
94
  )
95
- self.pipe.enable_model_cpu_offload()
96
 
97
  # LoRAモデルの設定
98
  if lora_model == "とりにく風":
99
- self.pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
100
  elif lora_model == "少女漫画風":
101
- self.pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
102
  elif lora_model == "劇画調風":
103
- self.pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
104
  elif lora_model == "プレーン":
105
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- # 現在のlora_modelを保存
108
- self.current_lora_model = lora_model
109
- return self.pipe
 
 
 
 
 
 
110
 
111
  def layout(self):
112
  css = """
@@ -119,7 +107,8 @@ class Img2Img:
119
  with gr.Blocks(css=css) as demo:
120
  with gr.Row():
121
  with gr.Column():
122
- self.lora_model = gr.Dropdown(label="Image Style", choices=["プレーン", "とりにく風", "少女漫画風"], value="プレーン")
 
123
  self.input_image_path = gr.Image(label="Input image", type='filepath')
124
  self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
125
 
@@ -146,8 +135,7 @@ class Img2Img:
146
  )
147
 
148
  generate_button.click(
149
- fn=lambda lora_model, input_image_path, prompt, negative_prompt, controlnet_scale:
150
- predict(lora_model, input_image_path, prompt, negative_prompt, controlnet_scale, self.load_model),
151
  inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
152
  outputs=self.output_image
153
  )
 
23
  dl_tagger_model(tagger_dir)
24
  dl_lora_model(lora_dir)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class Img2Img:
27
  def __init__(self):
28
  self.demo = self.layout()
29
  self.tagger_model = None
30
  self.input_image_path = None
31
  self.bg_removed_image = None
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  def load_model(self, lora_model):
 
 
 
 
 
34
  dtype = torch.float16
35
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
36
  controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
37
 
38
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
39
+ "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
40
  )
41
+ pipe.enable_model_cpu_offload()
42
 
43
  # LoRAモデルの設定
44
  if lora_model == "とりにく風":
45
+ pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
46
  elif lora_model == "少女漫画風":
47
+ pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
48
  elif lora_model == "劇画調風":
49
+ pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
50
  elif lora_model == "プレーン":
51
+ pass # プレーンの場合はLoRAを読み込まない
52
+
53
+ return pipe
54
+
55
+ @spaces.GPU(duration=120)
56
+ def predict(self, lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
57
+ pipe = self.load_model(lora_model)
58
+ input_image = Image.open(input_image_path)
59
+ base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
60
+ resize_image = resize_image_aspect_ratio(input_image)
61
+ resize_base_image = resize_image_aspect_ratio(base_image)
62
+ generator = torch.manual_seed(0)
63
+ last_time = time.time()
64
+
65
+ # プロンプト生成
66
+ prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
67
+ execute_tags = ["realistic", "nose", "asian"]
68
+ prompt = execute_prompt(execute_tags, prompt)
69
+ prompt = remove_duplicates(prompt)
70
+ prompt = remove_color(prompt)
71
+ print(prompt)
72
+
73
+ # 画像生成
74
+ output_image = pipe(
75
+ image=resize_base_image,
76
+ control_image=resize_image,
77
+ strength=1.0,
78
+ prompt=prompt,
79
+ negative_prompt=negative_prompt,
80
+ controlnet_conditioning_scale=float(controlnet_scale),
81
+ generator=generator,
82
+ num_inference_steps=30,
83
+ eta=1.0,
84
+ ).images[0]
85
+ print(f"Time taken: {time.time() - last_time}")
86
+ output_image = output_image.resize(input_image.size, Image.LANCZOS)
87
+ return output_image
88
 
89
+ def process_prompt_analysis(self, input_image_path):
90
+ if self.tagger_model is None:
91
+ self.tagger_model = modelLoad(tagger_dir)
92
+ tags = analysis(input_image_path, tagger_dir, self.tagger_model)
93
+ prompt = remove_color(tags)
94
+ execute_tags = ["realistic", "nose", "asian"]
95
+ prompt = execute_prompt(execute_tags, prompt)
96
+ prompt = remove_duplicates(prompt)
97
+ return prompt
98
 
99
  def layout(self):
100
  css = """
 
107
  with gr.Blocks(css=css) as demo:
108
  with gr.Row():
109
  with gr.Column():
110
+ # LoRAモデル選択ドロップダウン
111
+ self.lora_model = gr.Dropdown(label="Image Style", choices=["プレーン", "とりにく風", "少女漫画風", "劇画調風"], value="プレーン")
112
  self.input_image_path = gr.Image(label="Input image", type='filepath')
113
  self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
114
 
 
135
  )
136
 
137
  generate_button.click(
138
+ fn=self.predict,
 
139
  inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
140
  outputs=self.output_image
141
  )