cocktailpeanut commited on
Commit
72b8226
·
1 Parent(s): 35222b5
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -7,6 +7,13 @@ import math
7
  import spaces
8
  import torch
9
 
 
 
 
 
 
 
 
10
  edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
11
  normal_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl.safetensors")
12
 
@@ -30,11 +37,11 @@ pipe_edit = CosStableDiffusionXLInstructPix2PixPipeline.from_single_file(
30
  edit_file, num_in_channels=8
31
  )
32
  pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
33
- pipe_edit.to("cuda")
34
 
35
  pipe_normal = StableDiffusionXLPipeline.from_single_file(normal_file, torch_dtype=torch.float16)
36
  pipe_normal.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
37
- pipe_normal.to("cuda")
38
 
39
  @spaces.GPU
40
  def run_normal(prompt, negative_prompt="", guidance_scale=7, progress=gr.Progress(track_tqdm=True)):
 
7
  import spaces
8
  import torch
9
 
10
+ if torch.backends.mps.is_available():
11
+ DEVICE = "mps"
12
+ elif torch.cuda.is_available():
13
+ DEVICE = "cuda"
14
+ else:
15
+ DEVICE = "cpu"
16
+
17
  edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
18
  normal_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl.safetensors")
19
 
 
37
  edit_file, num_in_channels=8
38
  )
39
  pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
40
+ pipe_edit.to(DEVICE)
41
 
42
  pipe_normal = StableDiffusionXLPipeline.from_single_file(normal_file, torch_dtype=torch.float16)
43
  pipe_normal.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
44
+ pipe_normal.to(DEVICE)
45
 
46
  @spaces.GPU
47
  def run_normal(prompt, negative_prompt="", guidance_scale=7, progress=gr.Progress(track_tqdm=True)):