multimodalart HF staff commited on
Commit
de26694
·
verified ·
1 Parent(s): af04902

Update app_ctrlx.py

Browse files
Files changed (1) hide show
  1. app_ctrlx.py +6 -4
app_ctrlx.py CHANGED
@@ -9,6 +9,8 @@ from ctrl_x.pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
9
  from ctrl_x.utils import *
10
  from ctrl_x.utils.sdxl import *
11
 
 
 
12
 
13
  parser = ArgumentParser()
14
  parser.add_argument("-m", "--model", type=str, default=None) # Optionally, load model checkpoint from single file
@@ -21,19 +23,19 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
  model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
22
  refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
- variant = "fp16" if device == "cuda" else "fp32"
25
 
26
  scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") # TODO: Support other schedulers
27
  if args.model is None:
28
  pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
29
- model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, variant=variant, use_safetensors=True
30
  )
31
  else:
32
  print(f"Using weights {args.model} for SDXL base model.")
33
  pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
34
  refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
35
  refiner_id_or_path, scheduler=scheduler, text_encoder_2=pipe.text_encoder_2, vae=pipe.vae,
36
- torch_dtype=torch_dtype, variant=variant, use_safetensors=True,
37
  )
38
 
39
  if torch.cuda.is_available():
@@ -155,7 +157,7 @@ title = """
155
  </div>
156
  """
157
 
158
-
159
  def inference(
160
  structure_image, appearance_image,
161
  prompt, structure_prompt, appearance_prompt,
 
9
  from ctrl_x.utils import *
10
  from ctrl_x.utils.sdxl import *
11
 
12
+ import spaces
13
+
14
 
15
  parser = ArgumentParser()
16
  parser.add_argument("-m", "--model", type=str, default=None) # Optionally, load model checkpoint from single file
 
23
  model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
24
  refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ #variant = "fp16" if device == "cuda" else "fp32"
27
 
28
  scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") # TODO: Support other schedulers
29
  if args.model is None:
30
  pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
31
+ model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, use_safetensors=True
32
  )
33
  else:
34
  print(f"Using weights {args.model} for SDXL base model.")
35
  pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
36
  refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
37
  refiner_id_or_path, scheduler=scheduler, text_encoder_2=pipe.text_encoder_2, vae=pipe.vae,
38
+ torch_dtype=torch_dtype, use_safetensors=True,
39
  )
40
 
41
  if torch.cuda.is_available():
 
157
  </div>
158
  """
159
 
160
+ @spaces.GPU
161
  def inference(
162
  structure_image, appearance_image,
163
  prompt, structure_prompt, appearance_prompt,