tori29umai committed
Commit f68a8d0
1 Parent(s): d49ade0

Update app.py

Files changed (1)
  1. app.py +72 -69
app.py CHANGED
@@ -23,80 +23,83 @@ dl_cn_config(cn_dir)
 dl_tagger_model(tagger_dir)
 dl_lora_model(lora_dir)
 
+# Manage the pipe through module-level globals
+pipe = None
+current_lora_model = None
+
+def load_model(lora_model):
+    global pipe, current_lora_model
+    # Reuse the pipe if one is already loaded with the same LoRA model
+    if pipe is not None and current_lora_model == lora_model:
+        return pipe  # return the cached pipe
+
+    # Build a new pipe
+    dtype = torch.float16
+    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
+    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
+
+    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
+    )
+    pipe.enable_model_cpu_offload()
+
+    # Configure the LoRA model
+    if lora_model == "とりにく風":
+        pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
+    elif lora_model == "少女漫画風":
+        pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
+    elif lora_model == "劇画調風":
+        pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
+    elif lora_model == "プレーン":
+        pass  # "plain" style: no LoRA is loaded
+
+    # Remember the current LoRA model
+    current_lora_model = lora_model
+    return pipe
+
+@spaces.GPU(duration=120)
+def predict(lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
+    # Get the pipe from the global cache
+    pipe = load_model(lora_model)
+
+    # Load and resize the images
+    input_image = Image.open(input_image_path)
+    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
+    resize_image = resize_image_aspect_ratio(input_image)
+    resize_base_image = resize_image_aspect_ratio(base_image)
+    generator = torch.manual_seed(0)
+    last_time = time.time()
+
+    # Build the prompt
+    prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
+    execute_tags = ["realistic", "nose", "asian"]
+    prompt = execute_prompt(execute_tags, prompt)
+    prompt = remove_duplicates(prompt)
+    prompt = remove_color(prompt)
+    print(prompt)
+
+    # Generate the image
+    output_image = pipe(
+        image=resize_base_image,
+        control_image=resize_image,
+        strength=1.0,
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        controlnet_conditioning_scale=float(controlnet_scale),
+        generator=generator,
+        num_inference_steps=30,
+        eta=1.0,
+    ).images[0]
+    print(f"Time taken: {time.time() - last_time}")
+    output_image = output_image.resize(input_image.size, Image.LANCZOS)
+    return output_image
+
 class Img2Img:
     def __init__(self):
         self.demo = self.layout()
         self.tagger_model = None
         self.input_image_path = None
         self.bg_removed_image = None
-        self.pipe = None
-        self.current_lora_model = None
-
-    def load_model(self, lora_model):
-        # Reuse the pipe if one is already loaded with the same LoRA model
-        if self.pipe and self.current_lora_model == lora_model:
-            return self.pipe  # return the cached pipe
-
-        # Build a new pipe
-        dtype = torch.float16
-        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
-        controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
-
-        self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-            "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
-        )
-        self.pipe.enable_model_cpu_offload()
-
-        # Configure the LoRA model
-        if lora_model == "とりにく風":
-            self.pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
-        elif lora_model == "少女漫画風":
-            self.pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
-        elif lora_model == "劇画調風":
-            self.pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
-        elif lora_model == "プレーン":
-            pass  # "plain" style: no LoRA is loaded
-
-        # Remember the current LoRA model
-        self.current_lora_model = lora_model
-        return self.pipe
-
-    @spaces.GPU(duration=120)
-    def predict(self, lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
-        # Get the cached pipe instead of creating a new one here
-        pipe = self.load_model(lora_model)
-
-        # Load and resize the images
-        input_image = Image.open(input_image_path)
-        base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
-        resize_image = resize_image_aspect_ratio(input_image)
-        resize_base_image = resize_image_aspect_ratio(base_image)
-        generator = torch.manual_seed(0)
-        last_time = time.time()
-
-        # Build the prompt
-        prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
-        execute_tags = ["realistic", "nose", "asian"]
-        prompt = execute_prompt(execute_tags, prompt)
-        prompt = remove_duplicates(prompt)
-        prompt = remove_color(prompt)
-        print(prompt)
-
-        # Generate the image
-        output_image = pipe(
-            image=resize_base_image,
-            control_image=resize_image,
-            strength=1.0,
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            controlnet_conditioning_scale=float(controlnet_scale),
-            generator=generator,
-            num_inference_steps=30,
-            eta=1.0,
-        ).images[0]
-        print(f"Time taken: {time.time() - last_time}")
-        output_image = output_image.resize(input_image.size, Image.LANCZOS)
-        return output_image
 
     def process_prompt_analysis(self, input_image_path):
         if self.tagger_model is None:
@@ -147,7 +150,7 @@ class Img2Img:
         )
 
         generate_button.click(
-            fn=self.predict,
+            fn=predict,
             inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
             outputs=self.output_image
        )
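
Taken together, the change moves pipeline caching out of `Img2Img` instance attributes and into module-level globals, leaving `predict` as a plain function that `generate_button.click` can reference directly. Below is a minimal, self-contained sketch of that caching pattern; `build_pipeline` is a hypothetical placeholder standing in for the app's actual diffusers and LoRA setup, not code from this commit.

# Minimal sketch of the module-level pipe cache used in the new app.py.
pipe = None
current_lora_model = None

def build_pipeline(lora_model):
    # Hypothetical placeholder: the real app constructs a
    # StableDiffusionXLControlNetImg2ImgPipeline and loads LoRA weights here.
    return ("pipeline", lora_model)

def load_model(lora_model):
    global pipe, current_lora_model
    # Cache hit: the same LoRA model is already loaded, so reuse the pipe.
    if pipe is not None and current_lora_model == lora_model:
        return pipe
    # Cache miss: rebuild the pipe and remember which LoRA model it carries.
    pipe = build_pipeline(lora_model)
    current_lora_model = lora_model
    return pipe

# A second call with the same style reuses the cached object.
assert load_model("プレーン") is load_model("プレーン")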