Update app.py
app.py
CHANGED
@@ -1,7 +1,306 @@
# The original code was brought over from Illusion Diffusion for reference.

import spaces
import torch
import gradio as gr
from gradio import processing_utils, utils
from PIL import Image
import random

from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    StableDiffusionLatentUpscalePipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler
)
import tempfile
import time
from share_btn import community_icon_html, loading_icon_html, share_js
import user_history
from illusion_style import css
import os
# from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

BASE_MODEL = ""
# BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# Initialize both pipelines
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)

# Initialize the safety checker conditionally
# Safety-related.
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
safety_checker = None
# feature_extractor = None
if SAFETY_CHECKER_ENABLED:
    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
    # feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    safety_checker=safety_checker,
    # feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
).to("cuda")

# Function to check NSFW images
#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
#    if SAFETY_CHECKER_ENABLED:
#        safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
#        has_nsfw_concepts = safety_checker(
#            images=[images],
#            clip_input=safety_checker_input.pixel_values.to("cuda")
#        )
#        return images, has_nsfw_concepts
#    else:
#        return images, [False] * len(images)

#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
#main_pipe.unet.to(memory_format=torch.channels_last)
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
#model_id = "stabilityai/sd-x2-latent-upscaler"
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)

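# Note: building the img2img pipeline from main_pipe.components reuses the already loaded
# UNet, VAE, text encoder and ControlNet, so the second pipeline does not allocate another
# copy of the weights on the GPU.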
#image_pipe.unet = torch.compile(image_pipe.unet, mode="reduce-overhead", fullgraph=True)
#upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
#upscaler.to("cuda")

# Sampler map
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}

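# The Gradio dropdown below passes one of these keys to inference(), which rebuilds
# main_pipe's scheduler from its current config via the matching lambda.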
# Crop the image to a centered square whose side is the shorter edge, resize it, and return the result.
def center_crop_resize(img, output_size=(512, 512)):
    width, height = img.size

    # Calculate dimensions to crop to the center
    new_dimension = min(width, height)
    left = (width - new_dimension)/2
    top = (height - new_dimension)/2
    right = (width + new_dimension)/2
    bottom = (height + new_dimension)/2

    # Crop and resize
    img = img.crop((left, top, right, bottom))
    img = img.resize(output_size)

    return img

# Fills the empty space with interpolated pixels when enlarging, so the upscaled result looks natural.
def common_upscale(samples, width, height, upscale_method, crop=False):
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:,:,y:old_height-y,x:old_width-x]
    else:
        s = samples

    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

# Upscaling, same as above (applied to the latents from the first pass).
def upscale(samples, upscale_method, scale_by):
    #s = samples.copy()
    width = round(samples["images"].shape[3] * scale_by)
    height = round(samples["images"].shape[2] * scale_by)
    s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
    return s

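# For reference: with the 512x512 first pass below, the Stable Diffusion VAE produces
# 64x64 latents (a factor-of-8 downsample), so scale_by=2 interpolates them to 128x128
# before the img2img refinement pass decodes them at roughly 1024x1024.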
# Check whether the user left the prompt or the input image empty.
def check_inputs(prompt: str, control_image: Image.Image):
    if control_image is None:
        raise gr.Error("Please select or upload an Input Illusion")
    if prompt is None or prompt == "":
        raise gr.Error("Prompt is required")

# Base64 -> PIL
def convert_to_pil(base64_image):
    pil_image = Image.open(base64_image)
    return pil_image


# PIL -> Base64
def convert_to_base64(pil_image):
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
        pil_image.save(temp_file.name)
        return temp_file.name

# Inference function
@spaces.GPU
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler = "DPM++ Karras SDE",
    progress = gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
):
    start_time = time.time()
    start_time_struct = time.localtime(start_time)
    start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
    print(f"Inference started at {start_time_formatted}")

    # Generate the initial image
    #init_image = init_pipe(prompt).images[0]

    # Rest of your existing code
    control_image_small = center_crop_resize(control_image)
    control_image_large = center_crop_resize(control_image, (1024, 1024))

    main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
    my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
    generator = torch.Generator(device="cuda").manual_seed(my_seed)

    out = main_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_image_small,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        num_inference_steps=15,
        output_type="latent"
    )
    upscaled_latents = upscale(out, "nearest-exact", 2)
    out_image = image_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=control_image_large,
        image=upscaled_latents,
        guidance_scale=float(guidance_scale),
        generator=generator,
        num_inference_steps=20,
        strength=upscaler_strength,
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale)
    )
    end_time = time.time()
    end_time_struct = time.localtime(end_time)
    end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
    print(f"Inference ended at {end_time_formatted}, taking {end_time-start_time}s")

    # Save image + metadata
    # Keeping the metadata should make later analysis easier.
    user_history.save_image(
        label=prompt,
        image=out_image["images"][0],
        profile=profile,
        metadata={
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "controlnet_conditioning_scale": controlnet_conditioning_scale,
            "control_guidance_start": control_guidance_start,
            "control_guidance_end": control_guidance_end,
            "upscaler_strength": upscaler_strength,
            "seed": seed,
            "sampler": sampler,
        },
    )

    return out_image["images"][0], gr.update(visible=True), gr.update(visible=True), my_seed

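# Example of calling inference() directly (hypothetical file name and prompt; in the Space
# it is normally triggered through the Gradio event wiring below):
#   protected, _, _, seed_used = inference(Image.open("input.png"), "a medieval village", "low quality", seed=42)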
with gr.Blocks() as app:
    gr.Markdown(
        '''
        <div style="text-align: center;">
            <h1>Destroy Deepfake, Protect Image</h1>
            <p style="font-size:16px;">Generate your image with a protective shield. Try it now!</p>
            <p>When you upload an image, a protective filter is applied and the modified image is returned. Even if malicious users try to use the protected photo for deepfake synthesis, the protective filter will ensure that the results are distorted.</p>
            <p>If you have any questions, please contact us at <a href="mailto:[email protected]">[email protected]</a>.</p>
            <p>Please send your feedback to <a href="https://(add the feedback Google Form link here)">this form</a>; it will greatly help us improve our service. Given a prompt and your pattern, we use a QR-code-conditioned ControlNet to create a stunning illusion! Credit to <a href="https://twitter.com/MrUgleh">MrUgleh</a> for discovering the workflow :)</p>
        </div>
        '''
    )

    # Used when saving the input/output images.
    # state_img_input = gr.State()
    # state_img_output = gr.State()

    with gr.Row():
        with gr.Column():
            control_image = gr.Image(label="Input your image", type="pil", elem_id="control_image")
            controlnet_conditioning_scale = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.8, label="protecting strength", elem_id="illusion_strength", info="ControlNet conditioning scale")
            # gr.Examples(examples=["checkers.png", "checkers_mid.jpg", "pattern.png", "ultra_checkers.png", "spiral.jpeg", "funky.jpeg"], inputs=control_image)
            prompt = gr.Textbox(label="Prompt", elem_id="prompt", info="Type what you want to generate", placeholder="Medieval village scene with busy streets and castle in the distance")
            negative_prompt = gr.Textbox(label="Negative Prompt", info="Type what you don't want to see", value="low quality", elem_id="negative_prompt")
            with gr.Accordion(label="Advanced Options", open=False):
                guidance_scale = gr.Slider(minimum=0.0, maximum=50.0, step=0.25, value=7.5, label="Guidance Scale")
                sampler = gr.Dropdown(choices=list(SAMPLER_MAP.keys()), value="Euler")
                control_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0, label="Start of ControlNet")
                control_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="End of ControlNet")
                strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Strength of the upscaler")
                seed = gr.Slider(minimum=-1, maximum=9999999999, step=1, value=-1, label="Seed", info="-1 means random seed")
                used_seed = gr.Number(label="Last seed used", interactive=False)
            run_btn = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Illusion Diffusion Output", interactive=False, elem_id="output")
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

    prompt.submit(
        check_inputs,
        inputs=[prompt, control_image],
        queue=False
    ).success(
        inference,
        inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
        outputs=[result_image, result_image, share_group, used_seed])

    run_btn.click(
        check_inputs,
        inputs=[prompt, control_image],
        queue=False
    ).success(
        inference,
        inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
        outputs=[result_image, result_image, share_group, used_seed])

    share_button.click(None, [], [], js=share_js)

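# Wiring note: on both triggers (pressing Enter in the prompt box and clicking Run),
# check_inputs validates first and inference only runs on success. result_image is listed
# twice in the outputs because inference returns both the image and a visibility update
# for it, and the final output writes the seed that was actually used into "Last seed used".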
with gr.Blocks(css=css) as app_with_history:
    with gr.Tab("Demo"):
        app.render()
    with gr.Tab("Past generations"):
        user_history.render()

app_with_history.queue(max_size=20, api_open=False)

if __name__ == "__main__":
    app_with_history.launch(max_threads=400)