Spaces:
Sleeping
Sleeping
Zongsheng
commited on
Commit
•
4cd2c6a
1
Parent(s):
f857ecf
add resize for arbitraty size
Browse files- sampler.py +5 -4
sampler.py
CHANGED
@@ -166,6 +166,11 @@ class DifIRSampler(BaseSampler):
|
|
166 |
# basical image restoration
|
167 |
device = next(self.model.parameters()).device
|
168 |
y0 = y0.to(device=device, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
169 |
if need_restoration:
|
170 |
with torch.no_grad():
|
171 |
if model_kwargs_ir is None:
|
@@ -176,10 +181,6 @@ class DifIRSampler(BaseSampler):
|
|
176 |
im_hq = y0
|
177 |
im_hq.clamp_(0.0, 1.0)
|
178 |
|
179 |
-
h_old, w_old = im_hq.shape[2:4]
|
180 |
-
if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
|
181 |
-
im_hq = resize(im_hq, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
|
182 |
-
|
183 |
# diffuse for im_hq
|
184 |
yt = self.diffusion.q_sample(
|
185 |
x_start=post_fun(im_hq),
|
|
|
166 |
# basical image restoration
|
167 |
device = next(self.model.parameters()).device
|
168 |
y0 = y0.to(device=device, dtype=torch.float32)
|
169 |
+
|
170 |
+
h_old, w_old = y0.shape[2:4]
|
171 |
+
if not (h_old == self.configs.im_size and w_old == self.configs.im_size):
|
172 |
+
y0 = resize(y0, out_shape=(self.configs.im_size,) * 2).to(torch.float32)
|
173 |
+
|
174 |
if need_restoration:
|
175 |
with torch.no_grad():
|
176 |
if model_kwargs_ir is None:
|
|
|
181 |
im_hq = y0
|
182 |
im_hq.clamp_(0.0, 1.0)
|
183 |
|
|
|
|
|
|
|
|
|
184 |
# diffuse for im_hq
|
185 |
yt = self.diffusion.q_sample(
|
186 |
x_start=post_fun(im_hq),
|