hanriv committed
Commit 2bd1102 · verified · 1 Parent(s): 5879110

Update app.py

Files changed (1)
  1. app.py +303 -4
app.py CHANGED
@@ -1,7 +1,306 @@
- import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
-
-
- # dhaahha
+ # The code below was brought in from Illusion Diffusion for reference.
+
+ import spaces
+ import torch
+ import gradio as gr
+ from gradio import processing_utils, utils
+ from PIL import Image
+ import random
+
+ from diffusers import (
+     DiffusionPipeline,
+     AutoencoderKL,
+     StableDiffusionControlNetPipeline,
+     ControlNetModel,
+     StableDiffusionLatentUpscalePipeline,
+     StableDiffusionImg2ImgPipeline,
+     StableDiffusionControlNetImg2ImgPipeline,
+     DPMSolverMultistepScheduler,
+     EulerDiscreteScheduler
+ )
+ import tempfile
+ import time
+ from share_btn import community_icon_html, loading_icon_html, share_js
+ import user_history
+ from illusion_style import css
+ import os
+ # from transformers import CLIPImageProcessor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+ BASE_MODEL = ""
+ # BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
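+ # NOTE: BASE_MODEL is left empty in this commit; the from_pretrained call below needs a
+ # valid Hub model id here (for example, the Realistic_Vision id commented out above).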
+
+ # Initialize both pipelines
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
+
+ # Initialize the safety checker conditionally
+ # Security-related: the safety checker is only loaded when SAFETY_CHECKER=1 is set in the environment.
+ SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
+ safety_checker = None
+ # feature_extractor = None
+ if SAFETY_CHECKER_ENABLED:
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
+     # feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     BASE_MODEL,
+     controlnet=controlnet,
+     vae=vae,
+     safety_checker=safety_checker,
+     # feature_extractor=feature_extractor,
+     torch_dtype=torch.float16,
+ ).to("cuda")
+
+ # Function to check NSFW images
+ #def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
+ #    if SAFETY_CHECKER_ENABLED:
+ #        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+ #        has_nsfw_concepts = safety_checker(
+ #            images=[images],
+ #            clip_input=safety_checker_input.pixel_values.to("cuda")
+ #        )
+ #        return images, has_nsfw_concepts
+ #    else:
+ #        return images, [False] * len(images)
+
+ #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
+ #main_pipe.unet.to(memory_format=torch.channels_last)
+ #main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
+ #model_id = "stabilityai/sd-x2-latent-upscaler"
+ image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
+
+
+ #image_pipe.unet = torch.compile(image_pipe.unet, mode="reduce-overhead", fullgraph=True)
+ #upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+ #upscaler.to("cuda")
+
+
+ # Sampler map
+ SAMPLER_MAP = {
+     "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
+     "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
+ }
+
+ # Square crop sized to the image's shorter side, taken around the center; returns the cropped result.
+ def center_crop_resize(img, output_size=(512, 512)):
+     width, height = img.size
+
+     # Calculate dimensions to crop to the center
+     new_dimension = min(width, height)
+     left = (width - new_dimension)/2
+     top = (height - new_dimension)/2
+     right = (width + new_dimension)/2
+     bottom = (height + new_dimension)/2
+     # Crop and resize
+     img = img.crop((left, top, right, bottom))
+     img = img.resize(output_size)
+
+     return img
+
+ # When upscaling, fill the enlarged image with interpolated pixels so it stays natural-looking.
+ def common_upscale(samples, width, height, upscale_method, crop=False):
+     if crop == "center":
+         old_width = samples.shape[3]
+         old_height = samples.shape[2]
+         old_aspect = old_width / old_height
+         new_aspect = width / height
+         x = 0
+         y = 0
+         if old_aspect > new_aspect:
+             x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+         elif old_aspect < new_aspect:
+             y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+         s = samples[:,:,y:old_height-y,x:old_width-x]
+     else:
+         s = samples
+
+     return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+ # Upscaling, as above.
+ def upscale(samples, upscale_method, scale_by):
+     #s = samples.copy()
+     width = round(samples["images"].shape[3] * scale_by)
+     height = round(samples["images"].shape[2] * scale_by)
+     s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
+     return (s)
+
+ # Check whether the user left the prompt or the input image empty.
+ def check_inputs(prompt: str, control_image: Image.Image):
+     if control_image is None:
+         raise gr.Error("Please select or upload an Input Illusion")
+     if prompt is None or prompt == "":
+         raise gr.Error("Prompt is required")
+
+ # Base64 -> PIL
+ def convert_to_pil(base64_image):
+     pil_image = Image.open(base64_image)
+     return pil_image
+
+
+ # PIL -> Base64
+ def convert_to_base64(pil_image):
+     with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
+         pil_image.save(temp_file.name)
+     return temp_file.name
+
+ # Inference function
+ @spaces.GPU
+ def inference(
+     control_image: Image.Image,
+     prompt: str,
+     negative_prompt: str,
+     guidance_scale: float = 8.0,
+     controlnet_conditioning_scale: float = 1,
+     control_guidance_start: float = 1,
+     control_guidance_end: float = 1,
+     upscaler_strength: float = 0.5,
+     seed: int = -1,
+     sampler = "DPM++ Karras SDE",
+     progress = gr.Progress(track_tqdm=True),
+     profile: gr.OAuthProfile | None = None,
+ ):
+     start_time = time.time()
+     start_time_struct = time.localtime(start_time)
+     start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
+     print(f"Inference started at {start_time_formatted}")
+
+     # Generate the initial image
+     #init_image = init_pipe(prompt).images[0]
+
+     # Rest of your existing code
+     control_image_small = center_crop_resize(control_image)
+     control_image_large = center_crop_resize(control_image, (1024, 1024))
+
+     main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
+     my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
+     generator = torch.Generator(device="cuda").manual_seed(my_seed)
+
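+     # Two-stage generation: the ControlNet pipeline first produces 512x512 latents (15 steps),
+     # which are then upscaled 2x and refined at 1024x1024 by the img2img ControlNet pipeline.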
+     out = main_pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         image=control_image_small,
+         guidance_scale=float(guidance_scale),
+         controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+         generator=generator,
+         control_guidance_start=float(control_guidance_start),
+         control_guidance_end=float(control_guidance_end),
+         num_inference_steps=15,
+         output_type="latent"
+     )
+     upscaled_latents = upscale(out, "nearest-exact", 2)
+     out_image = image_pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         control_image=control_image_large,
+         image=upscaled_latents,
+         guidance_scale=float(guidance_scale),
+         generator=generator,
+         num_inference_steps=20,
+         strength=upscaler_strength,
+         control_guidance_start=float(control_guidance_start),
+         control_guidance_end=float(control_guidance_end),
+         controlnet_conditioning_scale=float(controlnet_conditioning_scale)
+     )
+     end_time = time.time()
+     end_time_struct = time.localtime(end_time)
+     end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
+     print(f"Inference ended at {end_time_formatted}, taking {end_time-start_time}s")
+
+     # Save image + metadata
+     # Metadata is stored as well, which should make later analysis easier.
+     user_history.save_image(
+         label=prompt,
+         image=out_image["images"][0],
+         profile=profile,
+         metadata={
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "guidance_scale": guidance_scale,
+             "controlnet_conditioning_scale": controlnet_conditioning_scale,
+             "control_guidance_start": control_guidance_start,
+             "control_guidance_end": control_guidance_end,
+             "upscaler_strength": upscaler_strength,
+             "seed": seed,
+             "sampler": sampler,
+         },
+     )
+
+     return out_image["images"][0], gr.update(visible=True), gr.update(visible=True), my_seed
+
+ with gr.Blocks() as app:
+     gr.Markdown(
+         '''
+         <div style="text-align: center;">
+         <h1>Destroy Deepfake, Protect Image 🌀</h1>
+         <p style="font-size:16px;">Generate your image with a protective shield. Try it now!</p>
+         <p>When you upload an image, a protective filter is applied and the modified image is output. Even if malicious users try to use the protected photo for deepfake synthesis, the protective filter will ensure that the results are distorted.</p>
+         <p>If you have any questions, please contact us at this email address: <a href="[email protected]">[email protected]</a></p>
+         <p>Please send your feedback to <a href="https://여기에 피드백할 구글폼 만들어 올린다거나..">this address</a>; it will greatly help us improve our service. Given a prompt and your pattern, we use a QR code conditioned controlnet to create a stunning illusion! Credit to: <a href="https://twitter.com/MrUgleh">MrUgleh</a> for discovering the workflow :)</p>
+         </div>
+         '''
+     )
+
+     # Used when saving the input/output images
+     # state_img_input = gr.State()
+     # state_img_output = gr.State()
+
+
+     with gr.Row():
+         with gr.Column():
+             control_image = gr.Image(label="Input your image", type="pil", elem_id="control_image")
+             controlnet_conditioning_scale = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.8, label="protecting strength", elem_id="illusion_strength", info="ControlNet conditioning scale")
+             # gr.Examples(examples=["checkers.png", "checkers_mid.jpg", "pattern.png", "ultra_checkers.png", "spiral.jpeg", "funky.jpeg" ], inputs=control_image)
+             prompt = gr.Textbox(label="Prompt", elem_id="prompt", info="Type what you want to generate", placeholder="Medieval village scene with busy streets and castle in the distance")
+             negative_prompt = gr.Textbox(label="Negative Prompt", info="Type what you don't want to see", value="low quality", elem_id="negative_prompt")
+             with gr.Accordion(label="Advanced Options", open=False):
+                 guidance_scale = gr.Slider(minimum=0.0, maximum=50.0, step=0.25, value=7.5, label="Guidance Scale")
+                 sampler = gr.Dropdown(choices=list(SAMPLER_MAP.keys()), value="Euler")
+                 control_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0, label="Start of ControlNet")
+                 control_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="End of ControlNet")
+                 strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Strength of the upscaler")
+                 seed = gr.Slider(minimum=-1, maximum=9999999999, step=1, value=-1, label="Seed", info="-1 means random seed")
+                 used_seed = gr.Number(label="Last seed used", interactive=False)
+             run_btn = gr.Button("Run")
+         with gr.Column():
+             result_image = gr.Image(label="Illusion Diffusion Output", interactive=False, elem_id="output")
+             with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
+                 community_icon = gr.HTML(community_icon_html)
+                 loading_icon = gr.HTML(loading_icon_html)
+                 share_button = gr.Button("Share to community", elem_id="share-btn")
+
+     prompt.submit(
+         check_inputs,
+         inputs=[prompt, control_image],
+         queue=False
+     ).success(
+         inference,
+         inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
+         outputs=[result_image, result_image, share_group, used_seed])
+
+     run_btn.click(
+         check_inputs,
+         inputs=[prompt, control_image],
+         queue=False
+     ).success(
+         inference,
+         inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
+         outputs=[result_image, result_image, share_group, used_seed])
+
+     share_button.click(None, [], [], js=share_js)
+
+ with gr.Blocks(css=css) as app_with_history:
+     with gr.Tab("Demo"):
+         app.render()
+     with gr.Tab("Past generations"):
+         user_history.render()
+
+ app_with_history.queue(max_size=20, api_open=False)
+
+ if __name__ == "__main__":
+     app_with_history.launch(max_threads=400)