bedead commited on
Commit
7c5e77b
·
verified ·
1 Parent(s): 08ada40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -43
app.py CHANGED
@@ -18,6 +18,8 @@ config_file = hf_hub_download(
18
  filename="config_promax.json",
19
  )
20
 
 
 
21
  config = ControlNetModel_Union.load_config(config_file)
22
  controlnet_model = ControlNetModel_Union.from_config(config)
23
  model_file = hf_hub_download(
@@ -28,11 +30,10 @@ state_dict = load_state_dict(model_file)
28
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
29
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
30
  )
31
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
-
33
  model.to(device=device, dtype=torch.float16 if device == 'cuda' else torch.float32)
 
34
  vae = AutoencoderKL.from_pretrained(
35
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 if device == 'cuda' else torch.float32
36
  ).to(device)
37
 
38
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
@@ -40,21 +41,22 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
40
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
41
  vae=vae,
42
  controlnet=model,
43
- variant="fp16" if device == 'cuda' else None,
44
  ).to(device)
45
 
46
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- prompt = "high quality"
49
- (
50
- prompt_embeds,
51
- negative_prompt_embeds,
52
- pooled_prompt_embeds,
53
- negative_pooled_prompt_embeds,
54
- ) = pipe.encode_prompt(prompt, device, True)
55
-
56
-
57
- def fill_image(image, model_selection):
58
  source = image["background"]
59
  mask = image["layers"][0]
60
 
@@ -72,8 +74,14 @@ def fill_image(image, model_selection):
72
  ):
73
  yield image, cnet_image
74
 
75
- image = image.convert("RGBA")
76
- cnet_image.paste(image, (0, 0), binary_mask)
 
 
 
 
 
 
77
 
78
  yield source, cnet_image
79
 
@@ -82,30 +90,33 @@ def clear_result():
82
  return gr.update(value=None)
83
 
84
 
85
- css = """
86
- .gradio-container {
87
- width: 1024px !important;
88
- }
89
- """
90
-
91
-
92
- title = """<h1 align="center">Diffusers Image Fill</h1>
93
- <div align="center">Draw the mask over the subject you want to erase or change.</div>
94
- <div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-image-fill'>Diffusers Image Fill</a>.</div>
95
- """
96
-
97
- with gr.Blocks(css=css) as demo:
98
- gr.HTML(title)
99
-
100
- run_button = gr.Button("Generate")
 
 
 
 
 
 
 
101
 
102
  with gr.Row():
103
  input_image = gr.ImageMask(
104
- type="pil",
105
- label="Input Image",
106
- canvas_size=(1024, 1024),
107
- layers=False,
108
- sources=["upload"],
109
  )
110
 
111
  result = ImageSlider(
@@ -113,21 +124,50 @@ with gr.Blocks(css=css) as demo:
113
  label="Generated Image",
114
  )
115
 
116
- model_selection = gr.Dropdown(
117
- choices=list(MODELS.keys()),
118
- value="RealVisXL V5.0 Lightning",
119
- label="Model",
 
 
 
120
  )
121
 
122
  run_button.click(
123
  fn=clear_result,
124
  inputs=None,
125
  outputs=result,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  ).then(
127
  fn=fill_image,
128
- inputs=[input_image, model_selection],
129
  outputs=result,
 
 
 
 
130
  )
131
 
132
 
133
- demo.launch(share=False)
 
18
  filename="config_promax.json",
19
  )
20
 
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+
23
  config = ControlNetModel_Union.load_config(config_file)
24
  controlnet_model = ControlNetModel_Union.from_config(config)
25
  model_file = hf_hub_download(
 
30
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
31
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
32
  )
 
 
33
  model.to(device=device, dtype=torch.float16 if device == 'cuda' else torch.float32)
34
+
35
  vae = AutoencoderKL.from_pretrained(
36
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
37
  ).to(device)
38
 
39
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
 
41
  torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
42
  vae=vae,
43
  controlnet=model,
44
+ #variant="fp16",
45
  ).to(device)
46
 
47
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
48
+ if device == 'cuda':
49
+ pipe.enable_model_cpu_offload()
50
+ pipe.enable_xformers_memory_efficient_attention()
51
+
52
+ def fill_image(prompt, image, model_selection, paste_back):
53
+ (
54
+ prompt_embeds,
55
+ negative_prompt_embeds,
56
+ pooled_prompt_embeds,
57
+ negative_pooled_prompt_embeds,
58
+ ) = pipe.encode_prompt(prompt, "cuda", True)
59
 
 
 
 
 
 
 
 
 
 
 
60
  source = image["background"]
61
  mask = image["layers"][0]
62
 
 
74
  ):
75
  yield image, cnet_image
76
 
77
+ print(f"{model_selection=}")
78
+ print(f"{paste_back=}")
79
+
80
+ if paste_back:
81
+ image = image.convert("RGBA")
82
+ cnet_image.paste(image, (0, 0), binary_mask)
83
+ else:
84
+ cnet_image = image
85
 
86
  yield source, cnet_image
87
 
 
90
  return gr.update(value=None)
91
 
92
 
93
+ with gr.Blocks() as demo:
94
+ gr.Markdown('# Ai object removal')
95
+ with gr.Row():
96
+ with gr.Column():
97
+ prompt = gr.Textbox(
98
+ label="Prompt",
99
+ info="Describe what to inpaint the mask with",
100
+ lines=3,
101
+ )
102
+ with gr.Column():
103
+ model_selection = gr.Dropdown(
104
+ choices=list(MODELS.keys()),
105
+ value="RealVisXL V5.0 Lightning",
106
+ label="Model",
107
+ visible = False,
108
+ )
109
+
110
+ with gr.Row():
111
+ with gr.Column():
112
+ run_button = gr.Button("Generate")
113
+
114
+ with gr.Column():
115
+ paste_back = gr.Checkbox(False, label="Paste back original", visible = False)
116
 
117
  with gr.Row():
118
  input_image = gr.ImageMask(
119
+ type="pil", label="Input Image", layers=False
 
 
 
 
120
  )
121
 
122
  result = ImageSlider(
 
124
  label="Generated Image",
125
  )
126
 
127
+ use_as_input_button = gr.Button("Use as Input Image", visible=False)
128
+
129
+ def use_output_as_input(output_image):
130
+ return gr.update(value=output_image[1])
131
+
132
+ use_as_input_button.click(
133
+ fn=use_output_as_input, inputs=[result], outputs=[input_image]
134
  )
135
 
136
  run_button.click(
137
  fn=clear_result,
138
  inputs=None,
139
  outputs=result,
140
+ ).then(
141
+ fn=lambda: gr.update(visible=False),
142
+ inputs=None,
143
+ outputs=use_as_input_button,
144
+ ).then(
145
+ fn=fill_image,
146
+ inputs=[prompt, input_image, model_selection, paste_back],
147
+ outputs=result,
148
+ ).then(
149
+ fn=lambda: gr.update(visible=True),
150
+ inputs=None,
151
+ outputs=use_as_input_button,
152
+ )
153
+
154
+ prompt.submit(
155
+ fn=clear_result,
156
+ inputs=None,
157
+ outputs=result,
158
+ ).then(
159
+ fn=lambda: gr.update(visible=False),
160
+ inputs=None,
161
+ outputs=use_as_input_button,
162
  ).then(
163
  fn=fill_image,
164
+ inputs=[prompt, input_image, model_selection, paste_back],
165
  outputs=result,
166
+ ).then(
167
+ fn=lambda: gr.update(visible=True),
168
+ inputs=None,
169
+ outputs=use_as_input_button,
170
  )
171
 
172
 
173
+ demo.launch(share=True,debug = True)