nuwandaa commited on
Commit
39a774c
·
1 Parent(s): 7f487ff

Get device automatically

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -52,7 +52,7 @@ is_attention_slicing_enabled = True
52
 
53
  # Load model
54
  dtype = torch.float16
55
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
56
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
57
 
58
  model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -63,7 +63,7 @@ pipeline = DiffusionPipeline.from_pretrained(
63
  variant="fp16",
64
  use_safetensors=True,
65
  torch_dtype=dtype,
66
- ).to('cuda')
67
 
68
  if is_attention_slicing_enabled:
69
  pipeline.enable_attention_slicing()
@@ -75,13 +75,13 @@ if is_cpu_offload_enabled:
75
  @spaces.GPU
76
  def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
77
  try:
78
- generator = torch.Generator('cuda').manual_seed(seed)
79
  prompt = "" # Set prompt to null
80
 
81
  source_image_pure = gradio_image["background"]
82
  mask_image_pure = gradio_image["layers"][0]
83
- source_image = preprocess_image(source_image_pure.convert('RGB'), 'cuda')
84
- mask = preprocess_mask(mask_image_pure, 'cuda')
85
 
86
  START_STEP = 0 # AAS start step
87
  END_STEP = int(strength * num_inference_steps) # AAS end step
 
52
 
53
  # Load model
54
  dtype = torch.float16
55
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
56
  scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
57
 
58
  model_path = "stabilityai/stable-diffusion-xl-base-1.0"
 
63
  variant="fp16",
64
  use_safetensors=True,
65
  torch_dtype=dtype,
66
+ ).to(device)
67
 
68
  if is_attention_slicing_enabled:
69
  pipeline.enable_attention_slicing()
 
75
  @spaces.GPU
76
  def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
77
  try:
78
+ generator = torch.Generator(device).manual_seed(seed)
79
  prompt = "" # Set prompt to null
80
 
81
  source_image_pure = gradio_image["background"]
82
  mask_image_pure = gradio_image["layers"][0]
83
+ source_image = preprocess_image(source_image_pure.convert('RGB'), device)
84
+ mask = preprocess_mask(mask_image_pure, device)
85
 
86
  START_STEP = 0 # AAS start step
87
  END_STEP = int(strength * num_inference_steps) # AAS end step