developerskyebrowse committed
Commit 12646d5 · 1 Parent(s): 0abe9ec

local and empty cache on hf

Files changed (2):
  1. app.py +1 -2
  2. local_app.py +106 -90
app.py CHANGED
@@ -364,7 +364,6 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
     # image processing
     @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
     def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
-        preprocessor.load("NormalBae")
         return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
 
     # AI image processing
@@ -434,7 +433,7 @@ def process_image(
     ).images[0]
     print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
     # torch.cuda.synchronize()
-    # torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
     return results
 
 if prod:
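Context for this change: torch.cuda.empty_cache() releases the caching allocator's unused blocks back to the CUDA driver; it does not free live tensors. Calling it after each request keeps the Space's reserved VRAM footprint low between generations, at the cost of some re-allocation overhead on the next call. A minimal, repo-independent sketch of how to observe the effect (the helper name report is hypothetical):

import torch

def report(tag: str) -> None:
    # allocated = memory held by live tensors; reserved = blocks kept by the caching allocator
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"{tag}: allocated={alloc:.2f} GB, reserved={reserved:.2f} GB")

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    del x                      # tensor is freed, but its block stays reserved
    report("before empty_cache")
    torch.cuda.empty_cache()   # hand the cached blocks back to the CUDA driver
    report("after empty_cache")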
local_app.py CHANGED
@@ -10,6 +10,8 @@ import random
 import time
 import gradio as gr
 import numpy as np
+# import spaces
+# import imageio
 import gc
 import torch
 from PIL import Image
@@ -17,85 +19,105 @@ from diffusers import (
     ControlNetModel,
     DPMSolverMultistepScheduler,
     StableDiffusionControlNetPipeline,
-    AutoencoderKL,
+    # AutoencoderKL,
 )
-from diffusers.models.attention_processor import AttnProcessor2_0
+from controlnet_aux_local import NormalBaeDetector
+
 MAX_SEED = np.iinfo(np.int32).max
 API_KEY = os.environ.get("API_KEY", None)
+# os.environ['HF_HOME'] = '/data/.huggingface'
 
 print("CUDA version:", torch.version.cuda)
 print("loading everything")
 compiled = False
 
-if gr.NO_RELOAD:
-    torch.cuda.max_memory_allocated(device="cuda")
-    # Controlnet Normal
-    model_id = "lllyasviel/control_v11p_sd15_normalbae"
-    print("initializing controlnet")
-    controlnet = ControlNetModel.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        attn_implementation="flash_attention_2",
-    ).to("cuda")
-
-    # Scheduler
-    scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        "runwayml/stable-diffusion-v1-5",
-        solver_order=2,
-        subfolder="scheduler",
-        use_karras_sigmas=True,
-        final_sigmas_type="sigma_min",
-        algorithm_type="sde-dpmsolver++",
-        prediction_type="epsilon",
-        thresholding=False,
-        denoise_final=True,
-        device_map="cuda",
-        torch_dtype=torch.float16,
-    )
-
-    # Stable Diffusion Pipeline URL
-    # base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
-    base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
-    vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
-
-    # print('loading vae')
-    # vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
-    # vae.to(memory_format=torch.channels_last)
-
-    print('loading pipe')
-    pipe = StableDiffusionControlNetPipeline.from_single_file(
-        base_model_url,
-        safety_checker=None,
-        # load_safety_checker=True,
-        controlnet=controlnet,
-        scheduler=scheduler,
-        # vae=vae,
-        torch_dtype=torch.float16,
-    )
-
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
-    pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
-    pipe.to("cuda")
-
-    print("loading preprocessor")
-    from preprocess import Preprocessor
-    preprocessor = Preprocessor()
-    preprocessor.load("NormalBae")
-
-    print("---------------Loaded controlnet pipeline---------------")
-    pipe.unet.set_attn_processor(AttnProcessor2_0())
-    torch.cuda.empty_cache()
-    gc.collect()
-    print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
-    print("Model Compiled!")
+class Preprocessor:
+    MODEL_ID = "lllyasviel/Annotators"
+
+    def __init__(self):
+        self.model = None
+        self.name = ""
+
+    def load(self, name: str) -> None:
+        if name == self.name:
+            return
+        elif name == "NormalBae":
+            print("Loading NormalBae")
+            self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
+            torch.cuda.empty_cache()
+            self.name = name
+        else:
+            raise ValueError
+        return
+
+    def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
+        return self.model(image, **kwargs)
+
+# torch.cuda.max_memory_allocated(device="cuda")
+
+# Controlnet Normal
+model_id = "lllyasviel/control_v11p_sd15_normalbae"
+print("initializing controlnet")
+controlnet = ControlNetModel.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
+).to("cuda")
+
+# Scheduler
+scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    solver_order=2,
+    subfolder="scheduler",
+    use_karras_sigmas=True,
+    final_sigmas_type="sigma_min",
+    algorithm_type="sde-dpmsolver++",
+    prediction_type="epsilon",
+    thresholding=False,
+    denoise_final=True,
+    device_map="cuda",
+    torch_dtype=torch.float16,
+)
+
+# Stable Diffusion Pipeline URL
+# base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
+base_model_url = "https://huggingface.co/Lykon/AbsoluteReality/blob/main/AbsoluteReality_1.8.1_pruned.safetensors"
+# vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
+
+# print('loading vae')
+# vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
+# vae.to(memory_format=torch.channels_last)
+
+print('loading pipe')
+pipe = StableDiffusionControlNetPipeline.from_single_file(
+    base_model_url,
+    safety_checker=None,
+    controlnet=controlnet,
+    scheduler=scheduler,
+    # vae=vae,
+    torch_dtype=torch.float16,
+).to("cuda")
+
+print("loading preprocessor")
+preprocessor = Preprocessor()
+preprocessor.load("NormalBae")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="fcNeg-neg.pt", token="fcNeg-neg")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_pet_play.pt", token="HDA_pet_play")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_unconventional maid.pt", token="HDA_unconventional_maid")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NakedHoodie.pt", token="HDA_NakedHoodie")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
+# pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
+pipe.to("cuda")
+
+print("---------------Loaded controlnet pipeline---------------")
+torch.cuda.empty_cache()
+gc.collect()
+print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
+print("Model Compiled!")
 
 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
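This hunk inlines the previously imported Preprocessor as a thin lazy-loading wrapper around NormalBaeDetector from the vendored controlnet_aux_local package: load() is a no-op when the requested annotator is already resident, so repeated per-request calls stay cheap. Assuming the vendored detector keeps the public controlnet_aux call signature (image_resolution, detect_resolution), a usage sketch with placeholder paths and prompts:

from PIL import Image

# Hypothetical one-off run, reusing the module-level objects defined above.
preprocessor.load("NormalBae")                   # no-op if NormalBae is already loaded
source = Image.open("room.jpg").convert("RGB")   # placeholder input path
control_image = preprocessor(
    image=source,
    image_resolution=512,    # size of the control image handed to the ControlNet
    detect_resolution=512,   # size the normal-map estimator runs at
)
result = pipe(
    prompt="a cozy living room",   # placeholder prompt
    negative_prompt="low quality",
    guidance_scale=9,
    num_inference_steps=25,
    image=control_image,
).images[0]
result.save("out.png")

The same preprocessor(...) call shape appears later in process_image, so this sketch mirrors the app's own request path.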
@@ -227,7 +249,7 @@ footer {
     visibility: hidden;
 }
 .gradio-container {
-    max-width: 900px !important;
+    max-width: 1100px !important;
 }
 .gr-image {
     display: flex;
@@ -255,14 +277,14 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
     label="Image resolution",
     minimum=256,
     maximum=1024,
-    value=768,
+    value=512,
     step=256,
 )
 preprocess_resolution = gr.Slider(
     label="Preprocess resolution",
     minimum=128,
     maximum=1024,
-    value=768,
+    value=512,
     step=1,
 )
 num_steps = gr.Slider(
@@ -282,11 +304,13 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
     value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
 )
 #############################################################################
+# input text
 with gr.Column():
     prompt = gr.Textbox(
-        label="Custom Prompt",
-        placeholder="boho chic",
+        label="Custom Design",
+        placeholder="Enter a description (optional)",
     )
+# design options
 with gr.Row(visible=True):
     style_selection = gr.Radio(
         show_label=True,
@@ -297,24 +321,24 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
     label="Design Styles",
 )
 # input image
-with gr.Row():
-    with gr.Column():
+with gr.Row(equal_height=True):
+    with gr.Column(scale=1, min_width=300):
         image = gr.Image(
             label="Input",
             sources=["upload"],
             show_label=True,
             mirror_webcam=True,
-            format="webp",
+            type="pil",
         )
         # run button
         with gr.Column():
             run_button = gr.Button(value="Use this one", size="lg", visible=False)
     # output image
-    with gr.Column():
+    with gr.Column(scale=1, min_width=300):
         result = gr.Image(
             label="Output",
             interactive=False,
-            format="webp",
+            type="pil",
             show_share_button=False,
         )
     # Use this image button
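Side note on the format="webp" → type="pil" swap (general Gradio behavior, not specific to this repo): type controls which Python object the event handler receives and returns ("pil" hands it a PIL.Image, which is what process_image and the preprocessor expect), whereas format only affected how the temporary file was encoded. A standalone sketch with a hypothetical invert handler:

import gradio as gr
from PIL import Image, ImageOps

def invert(img: Image.Image) -> Image.Image:
    # The handler sees a PIL image because of type="pil"; returning one works the same way.
    return ImageOps.invert(img.convert("RGB"))

with gr.Blocks() as demo:
    inp = gr.Image(type="pil", label="Input")
    out = gr.Image(type="pil", label="Output")
    inp.upload(invert, inputs=inp, outputs=out)

# demo.launch()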
@@ -333,28 +357,22 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
     guidance_scale,
     seed,
 ]
-
+
 with gr.Row():
     helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)
-
+
 # image processing
 @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
 def auto_process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
     return process_image(image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
-
-# # AI Image Processing
-# @gr.on(triggers=[use_ai_button.click], inputs=config, outputs=result, show_progress="minimal")
-# def submit(result, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
-#     return process_image(result, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
-
+
+# AI image processing
 @gr.on(triggers=[use_ai_button.click], inputs=[result] + config, outputs=[image, result], show_progress="minimal")
 def submit(previous_result, image, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
     # First, yield the previous result to update the input image immediately
     yield previous_result, gr.update()
-
     # Then, process the new input image
     new_result = process_image(previous_result, style_selection, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
-
     # Finally, yield the new result
     yield previous_result, new_result
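The rewritten submit handler relies on a standard Gradio pattern: a generator event function may yield several output tuples, and each yield pushes an immediate UI update, with gr.update() leaving a component as-is. Here that moves the previous result into the input slot before the slow diffusion call runs. A self-contained sketch of the pattern (hypothetical components, time.sleep standing in for inference):

import time
import gradio as gr

def chain(prev_output: str):
    # First yield: show the previous output in the input slot right away;
    # gr.update() with no arguments leaves the output component unchanged.
    yield prev_output, gr.update()
    time.sleep(2)  # stand-in for the slow diffusion call
    new_output = prev_output + " -> processed"
    # Final yield: deliver the real result.
    yield prev_output, new_output

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output", value="seed text")
    btn = gr.Button("Use AI result as next input")
    btn.click(chain, inputs=out, outputs=[inp, out])

# demo.launch()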
 
@@ -362,12 +380,13 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
 @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
 def turn_buttons_off():
     return gr.update(visible=False), gr.update(visible=False)
-
+
 # Turn on buttons when processing is complete
 @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
 def turn_buttons_on():
     return gr.update(visible=True), gr.update(visible=True)
 
+# @spaces.GPU(duration=12)
 @torch.inference_mode()
 def process_image(
     image,
@@ -386,11 +405,9 @@ def process_image(
     preprocess_start = time.time()
     print("processing image")
 
-    # global preprocessor
-    # preprocessor.load("NormalBae")
-
     seed = random.randint(0, MAX_SEED)
     generator = torch.cuda.manual_seed(seed)
+    preprocessor.load("NormalBae")
     control_image = preprocessor(
         image=image,
         image_resolution=image_resolution,
@@ -415,9 +432,8 @@ def process_image(
         image=control_image,
     ).images[0]
     print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
-    # results.save(os.path.join("/data", "temp_image.jpg"))
     # torch.cuda.synchronize()
-    # torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
     return results
 
 if prod:
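Two idioms in process_image are worth unpacking (generic PyTorch, not code from this commit). @torch.inference_mode() disables autograd tracking for the whole call, saving memory and time during pure inference. On seeding: torch.cuda.manual_seed(seed) seeds the global CUDA RNG but returns None, so the generator variable above is effectively unused; the usual diffusers idiom for a reproducible, pipeline-scoped generator= argument is a dedicated torch.Generator, sketched below:

import random
import torch

MAX_SEED = 2**31 - 1

@torch.inference_mode()  # no autograd graph is built inside this call
def generate(seed=None):
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Dedicated Generator object; diffusers pipelines accept it via `generator=`
    # for reproducible sampling independent of the global RNG state.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    noise = torch.randn(1, 4, 64, 64, device="cuda", generator=generator)
    print(f"seed={seed}")
    return noise

if torch.cuda.is_available():
    a = generate(1234)
    b = generate(1234)
    print(torch.equal(a, b))  # True: same seed reproduces the same latents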
 