tori29umai commited on
Commit
a9deb63
1 Parent(s): 9c9bde8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -16
app.py CHANGED
@@ -23,34 +23,27 @@ dl_cn_config(cn_dir)
23
  dl_tagger_model(tagger_dir)
24
  dl_lora_model(lora_dir)
25
 
26
- def load_model(lora_dir, cn_dir):
27
- dtype = torch.float16
28
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
29
- controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
30
-
31
- pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
32
- "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
33
- )
34
- pipe.enable_model_cpu_offload()
35
- pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
36
- return pipe
37
 
38
  @spaces.GPU(duration=120)
39
- def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
40
- pipe = load_model(lora_dir, cn_dir)
 
41
  input_image = Image.open(input_image_path)
42
  base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
43
  resize_image = resize_image_aspect_ratio(input_image)
44
  resize_base_image = resize_image_aspect_ratio(base_image)
45
  generator = torch.manual_seed(0)
46
  last_time = time.time()
 
 
47
  prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
48
  execute_tags = ["realistic", "nose", "asian"]
49
  prompt = execute_prompt(execute_tags, prompt)
50
- prompt = remove_duplicates(prompt)
51
  prompt = remove_color(prompt)
52
  print(prompt)
53
 
 
54
  output_image = pipe(
55
  image=resize_base_image,
56
  control_image=resize_image,
@@ -66,13 +59,43 @@ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
66
  output_image = output_image.resize(input_image.size, Image.LANCZOS)
67
  return output_image
68
 
 
69
  class Img2Img:
70
  def __init__(self):
71
  self.demo = self.layout()
72
  self.tagger_model = None
73
  self.input_image_path = None
74
  self.bg_removed_image = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
76
  def process_prompt_analysis(self, input_image_path):
77
  if self.tagger_model is None:
78
  self.tagger_model = modelLoad(tagger_dir)
@@ -94,6 +117,7 @@ class Img2Img:
94
  with gr.Blocks(css=css) as demo:
95
  with gr.Row():
96
  with gr.Column():
 
97
  self.input_image_path = gr.Image(label="Input image", type='filepath')
98
  self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
99
 
@@ -120,8 +144,8 @@ class Img2Img:
120
  )
121
 
122
  generate_button.click(
123
- fn=predict,
124
- inputs=[self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
125
  outputs=self.output_image
126
  )
127
  return demo
 
23
  dl_tagger_model(tagger_dir)
24
  dl_lora_model(lora_dir)
25
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  @spaces.GPU(duration=120)
28
+ def predict(self, lora_model, input_image_path, prompt, negative_prompt, controlnet_scale):
29
+ # LoRAモデルに基づきpipeを取得
30
+ pipe = self.load_model(lora_model)
31
  input_image = Image.open(input_image_path)
32
  base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
33
  resize_image = resize_image_aspect_ratio(input_image)
34
  resize_base_image = resize_image_aspect_ratio(base_image)
35
  generator = torch.manual_seed(0)
36
  last_time = time.time()
37
+
38
+ # プロンプト生成
39
  prompt = "masterpiece, best quality, monochrome, greyscale, lineart, white background, star-shaped pupils, " + prompt
40
  execute_tags = ["realistic", "nose", "asian"]
41
  prompt = execute_prompt(execute_tags, prompt)
42
+ prompt = remove_duplicates(prompt)
43
  prompt = remove_color(prompt)
44
  print(prompt)
45
 
46
+ # 画像生成
47
  output_image = pipe(
48
  image=resize_base_image,
49
  control_image=resize_image,
 
59
  output_image = output_image.resize(input_image.size, Image.LANCZOS)
60
  return output_image
61
 
62
+
63
  class Img2Img:
64
  def __init__(self):
65
  self.demo = self.layout()
66
  self.tagger_model = None
67
  self.input_image_path = None
68
  self.bg_removed_image = None
69
+ self.pipe = None
70
+ self.current_lora_model = None
71
+
72
+ def load_model(self, lora_model):
73
+ # 既に正しいpipeがロードされている場合は再利用
74
+ if self.pipe and self.current_lora_model == lora_model:
75
+ return self.pipe # キャッシュされたpipeを返す
76
+
77
+ # 新しいpipeの生成
78
+ dtype = torch.float16
79
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
80
+ controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
81
+
82
+ self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
83
+ "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
84
+ )
85
+ self.pipe.enable_model_cpu_offload()
86
+
87
+ # LoRAモデルの設定
88
+ if lora_model == "とりにく風":
89
+ self.pipe.load_lora_weights(lora_dir, weight_name="tori29umai_line.safetensors")
90
+ elif lora_model == "少女漫画風":
91
+ self.pipe.load_lora_weights(lora_dir, weight_name="syoujomannga_line.safetensors")
92
+ elif lora_model == "劇画調風":
93
+ self.pipe.load_lora_weights(lora_dir, weight_name="gekiga_line.safetensors")
94
 
95
+ # 現在のlora_modelを保存
96
+ self.current_lora_model = lora_model
97
+ return self.pipe
98
+
99
  def process_prompt_analysis(self, input_image_path):
100
  if self.tagger_model is None:
101
  self.tagger_model = modelLoad(tagger_dir)
 
117
  with gr.Blocks(css=css) as demo:
118
  with gr.Row():
119
  with gr.Column():
120
+ self.lora_model = gr.Dropdown(label="Image Style", choices=["プレーン", "とりにく風", "少女漫画風"], value="プレーン")
121
  self.input_image_path = gr.Image(label="Input image", type='filepath')
122
  self.bg_removed_image_path = gr.Image(label="Background Removed Image", type='filepath')
123
 
 
144
  )
145
 
146
  generate_button.click(
147
+ fn=self.predict,
148
+ inputs=[self.lora_model, self.bg_removed_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
149
  outputs=self.output_image
150
  )
151
  return demo