Spaces:
Runtime error
Runtime error
Commit
·
72b8226
1
Parent(s):
35222b5
update
Browse files
app.py
CHANGED
@@ -7,6 +7,13 @@ import math
|
|
7 |
import spaces
|
8 |
import torch
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
|
11 |
normal_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl.safetensors")
|
12 |
|
@@ -30,11 +37,11 @@ pipe_edit = CosStableDiffusionXLInstructPix2PixPipeline.from_single_file(
|
|
30 |
edit_file, num_in_channels=8
|
31 |
)
|
32 |
pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
|
33 |
-
pipe_edit.to(
|
34 |
|
35 |
pipe_normal = StableDiffusionXLPipeline.from_single_file(normal_file, torch_dtype=torch.float16)
|
36 |
pipe_normal.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
|
37 |
-
pipe_normal.to(
|
38 |
|
39 |
@spaces.GPU
|
40 |
def run_normal(prompt, negative_prompt="", guidance_scale=7, progress=gr.Progress(track_tqdm=True)):
|
|
|
7 |
import spaces
|
8 |
import torch
|
9 |
|
10 |
+
if torch.backends.mps.is_available():
|
11 |
+
DEVICE = "mps"
|
12 |
+
elif torch.cuda.is_available():
|
13 |
+
DEVICE = "cuda"
|
14 |
+
else:
|
15 |
+
DEVICE = "cpu"
|
16 |
+
|
17 |
edit_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors")
|
18 |
normal_file = hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl.safetensors")
|
19 |
|
|
|
37 |
edit_file, num_in_channels=8
|
38 |
)
|
39 |
pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
|
40 |
+
pipe_edit.to(DEVICE)
|
41 |
|
42 |
pipe_normal = StableDiffusionXLPipeline.from_single_file(normal_file, torch_dtype=torch.float16)
|
43 |
pipe_normal.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
|
44 |
+
pipe_normal.to(DEVICE)
|
45 |
|
46 |
@spaces.GPU
|
47 |
def run_normal(prompt, negative_prompt="", guidance_scale=7, progress=gr.Progress(track_tqdm=True)):
|