PeterL1n commited on
Commit
59f3984
·
verified ·
1 Parent(s): d20698c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -20
app.py CHANGED
@@ -7,7 +7,6 @@ from diffusers.image_processor import VaeImageProcessor
7
  from transformers import CLIPImageProcessor
8
  from huggingface_hub import hf_hub_download
9
  from safetensors.torch import load_file
10
- from PIL import Image
11
 
12
  device = "cuda"
13
  dtype = torch.float16
@@ -21,30 +20,24 @@ opts = {
21
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
22
  }
23
 
24
- # Default to load 4-step model.
25
- step_loaded = 4
26
- unet = UNet2DConditionModel.from_config(base, subfolder="unet")
27
- unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
28
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
29
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
30
-
31
- # Safety checker.
32
- safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
33
- feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
34
- image_processor = VaeImageProcessor(vae_scale_factor=8)
35
-
36
  # Inference function.
37
- @spaces.GPU(enable_queue=True)
38
  def generate(prompt, option, progress=gr.Progress()):
39
- global step_loaded
40
  print(prompt, option)
41
  ckpt, step = opts[option]
 
 
 
 
 
 
 
 
 
 
 
 
42
  progress((0, step))
43
-
44
- if pipe.device.type != device:
45
- pipe.to(device, dtype)
46
- if safety_checker.device.type != device:
47
- safety_checker.to(device, dtype)
48
 
49
  if step != step_loaded:
50
  print(f"Switching checkpoint from {step_loaded} to {step}")
 
7
  from transformers import CLIPImageProcessor
8
  from huggingface_hub import hf_hub_download
9
  from safetensors.torch import load_file
 
10
 
11
  device = "cuda"
12
  dtype = torch.float16
 
20
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
21
  }
22
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Inference function.
24
+ @spaces.GPU()
25
  def generate(prompt, option, progress=gr.Progress()):
 
26
  print(prompt, option)
27
  ckpt, step = opts[option]
28
+
29
+ # Main pipeline.
30
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet")
31
+ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]))).to(device, dtype)
32
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
33
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing").to(device, dtype)
34
+
35
+ # Safety checker.
36
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
37
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
38
+ image_processor = VaeImageProcessor(vae_scale_factor=8)
39
+
40
  progress((0, step))
 
 
 
 
 
41
 
42
  if step != step_loaded:
43
  print(f"Switching checkpoint from {step_loaded} to {step}")