Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -59,7 +59,10 @@ model = instantiate_from_config(config)
|
|
59 |
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
|
60 |
load_state_dict(model, ckpt, strict=True)
|
61 |
model.freeze()
|
62 |
-
|
|
|
|
|
|
|
63 |
|
64 |
@torch.no_grad()
|
65 |
def process(
|
@@ -113,7 +116,7 @@ def process(
|
|
113 |
control_img = np.array(control_img)
|
114 |
|
115 |
# Convert to tensor (NCHW, [0,1])
|
116 |
-
control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=
|
117 |
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
|
118 |
height, width = control.size(-2), control.size(-1)
|
119 |
model.control_scales = [strength] * 13
|
@@ -122,7 +125,7 @@ def process(
|
|
122 |
preds = []
|
123 |
for _ in tqdm(range(num_samples)):
|
124 |
shape = (1, 4, height // 8, width // 8)
|
125 |
-
x_T = torch.randn(shape, device=
|
126 |
|
127 |
if not tile_diffusion and not tile_vae:
|
128 |
samples = sampler.sample_ccsr(
|
|
|
59 |
ckpt = torch.load("weights/real-world_ccsr.ckpt", map_location="cpu")
|
60 |
load_state_dict(model, ckpt, strict=True)
|
61 |
model.freeze()
|
62 |
+
|
63 |
+
# Check if CUDA is available, otherwise use CPU
|
64 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
65 |
+
model.to(device)
|
66 |
|
67 |
@torch.no_grad()
|
68 |
def process(
|
|
|
116 |
control_img = np.array(control_img)
|
117 |
|
118 |
# Convert to tensor (NCHW, [0,1])
|
119 |
+
control = torch.tensor(control_img[None] / 255.0, dtype=torch.float32, device=device).clamp_(0, 1)
|
120 |
control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
|
121 |
height, width = control.size(-2), control.size(-1)
|
122 |
model.control_scales = [strength] * 13
|
|
|
125 |
preds = []
|
126 |
for _ in tqdm(range(num_samples)):
|
127 |
shape = (1, 4, height // 8, width // 8)
|
128 |
+
x_T = torch.randn(shape, device=device, dtype=torch.float32)
|
129 |
|
130 |
if not tile_diffusion and not tile_vae:
|
131 |
samples = sampler.sample_ccsr(
|