Spaces:
Runtime error
Runtime error
Better support for keep_shape
Browse files
app.py
CHANGED
@@ -25,18 +25,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
25 |
IMG_BITS = 13
|
26 |
|
27 |
|
28 |
-
class ToBinary(torch.autograd.Function):
|
29 |
-
|
30 |
-
@staticmethod
|
31 |
-
def forward(ctx, x):
|
32 |
-
return torch.floor(
|
33 |
-
x + 0.5) # no need for noise when we have plenty of data
|
34 |
-
|
35 |
-
@staticmethod
|
36 |
-
def backward(ctx, grad_output):
|
37 |
-
return grad_output.clone() # pass-through
|
38 |
-
|
39 |
-
|
40 |
class ResBlock(nn.Module):
|
41 |
|
42 |
def __init__(self, c_x, c_hidden):
|
@@ -242,26 +230,46 @@ def prepare_model(model_prefix):
|
|
242 |
return encoder, decoder
|
243 |
|
244 |
|
245 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
encoder, _ = prepare_model(model_prefix)
|
247 |
-
img_transform = transforms.Compose(
|
248 |
-
[transforms.PILToTensor(),
|
249 |
-
transforms.ConvertImageDtype(torch.float)] +
|
250 |
-
([transforms.Resize((224, 224))] if not keep_dims else []))
|
251 |
|
252 |
with torch.no_grad():
|
253 |
-
img =
|
254 |
-
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
with io.BytesIO() as buffer:
|
258 |
np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
|
259 |
z_b64 = base64.b64encode(buffer.getvalue()).decode()
|
260 |
|
261 |
-
return json.dumps({
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
|
264 |
def decode(model_prefix, z_str):
|
|
|
265 |
_, decoder = prepare_model(model_prefix)
|
266 |
|
267 |
z_json = json.loads(z_str)
|
@@ -269,10 +277,23 @@ def decode(model_prefix, z_str):
|
|
269 |
buffer.write(base64.b64decode(z_json["data"]))
|
270 |
buffer.seek(0)
|
271 |
z = np.load(buffer)
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
|
274 |
-
|
275 |
-
return VF.to_pil_image(
|
276 |
|
277 |
|
278 |
st.title("Clip Guided Binary Autoencoder")
|
@@ -288,12 +309,13 @@ encoder_tab, decoder_tab = st.tabs(["Encode", "Decode"])
|
|
288 |
|
289 |
with encoder_tab:
|
290 |
col_in, col_out = st.columns(2)
|
291 |
-
|
|
|
292 |
uploaded_file = col_in.file_uploader('Choose an Image')
|
293 |
if uploaded_file is not None:
|
294 |
image = Image.open(uploaded_file)
|
295 |
col_in.image(image, 'Input Image')
|
296 |
-
z_str = encode(model_prefix, image,
|
297 |
col_out.write("Encoded to:")
|
298 |
col_out.code(z_str, language=None)
|
299 |
col_out.image(decode(model_prefix, z_str), 'Output Image preview')
|
|
|
25 |
IMG_BITS = 13
|
26 |
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
class ResBlock(nn.Module):
|
29 |
|
30 |
def __init__(self, c_x, c_hidden):
|
|
|
230 |
return encoder, decoder
|
231 |
|
232 |
|
233 |
+
def compute_padding(img_shape):
    """Return the (left, top, right, bottom) padding that rounds an image
    up to the next multiple of 8 in each spatial dimension.

    Args:
        img_shape: (height, width) pair; index 0 is vertical, index 1
            horizontal (torch ``.shape[2:]`` order).

    Returns:
        Tuple ``(left, top, right, bottom)`` of non-negative ints. The pad
        is split as evenly as possible per axis; when the total pad is odd,
        the extra pixel goes to the right/bottom side.
    """
    # Round each dimension up to the nearest multiple of 8 (the encoder's
    # downsampling factor, so conv strides divide evenly).
    hsize, vsize = (img_shape[1] + 7) // 8 * 8, (img_shape[0] + 7) // 8 * 8
    hpad, vpad = hsize - img_shape[1], vsize - img_shape[0]
    left, top = hpad // 2, vpad // 2
    right, bottom = hpad - left, vpad - top
    return left, top, right, bottom
|
239 |
+
|
240 |
+
|
241 |
+
def encode(model_prefix, img, keep_shape):
    """Encode a PIL image into a JSON string holding the binary latent.

    Args:
        model_prefix: key passed through to ``prepare_model`` to select the
            encoder/decoder pair.
        img: input ``PIL.Image`` (any mode; converted to RGB here).
        keep_shape: if True, keep the original resolution and edge-pad it to
            a multiple of 8; otherwise resize to the 224x224 training size.

    Returns:
        JSON string with keys ``img_shape`` (pre-pad H, W), ``z_shape``
        (latent H, W), ``keep_shape``, and ``data`` (base64 of the
        bit-packed latent, stored in .npy format).
    """
    gc.collect()  # free cached models/tensors from previous requests
    encoder, _ = prepare_model(model_prefix)

    with torch.no_grad():
        img = VF.pil_to_tensor(img.convert("RGB"))
        img = VF.convert_image_dtype(img)      # uint8 -> float in [0, 1]
        img = img.unsqueeze(0).to(device)      # add batch dim
        img_shape = img.shape[2:]              # (H, W) before padding

        if keep_shape:
            # Edge-replicate pad to a multiple of 8 so the encoder's
            # downsampling divides evenly; decode() crops it back off.
            left, top, right, bottom = compute_padding(img_shape)
            img = VF.pad(img, [left, top, right, bottom], padding_mode='edge')
        else:
            img = VF.resize(img, [224, 224])

        # Hard-binarize the latent: floor(x + 0.5) rounds to {0, 1}.
        z = torch.floor(encoder(img) + 0.5)

    with io.BytesIO() as buffer:
        # Pack the 0/1 latent into bits and wrap it in .npy framing.
        np.save(buffer, np.packbits(z.cpu().numpy().astype('bool')))
        z_b64 = base64.b64encode(buffer.getvalue()).decode()

    # img_shape / z.shape[2:] are torch.Size; serialize them as plain lists
    # so the JSON payload is explicit about what decode() reads back.
    return json.dumps({
        "img_shape": list(img_shape),
        "z_shape": list(z.shape[2:]),
        "keep_shape": keep_shape,
        "data": z_b64,
    })
|
269 |
|
270 |
|
271 |
def decode(model_prefix, z_str):
    """Inverse of encode(): rebuild a PIL image from the JSON latent string.

    Args:
        model_prefix: key passed through to ``prepare_model`` to select the
            encoder/decoder pair.
        z_str: JSON string produced by ``encode`` (keys ``img_shape``,
            ``z_shape``, ``keep_shape``, ``data``).

    Returns:
        ``PIL.Image`` reconstruction. When the latent was encoded with
        ``keep_shape``, the padding added by encode() is cropped back off so
        the output matches the original resolution.
    """
    gc.collect()  # free cached models/tensors from previous requests
    _, decoder = prepare_model(model_prefix)

    z_json = json.loads(z_str)
    # NOTE(review): the buffer opener sits in unchanged diff context not
    # shown in this view; reconstructed to mirror encode() — confirm.
    with io.BytesIO() as buffer:
        buffer.write(base64.b64decode(z_json["data"]))
        buffer.seek(0)
        z = np.load(buffer)
    img_shape = z_json["img_shape"]
    z_shape = z_json["z_shape"]
    keep_shape = z_json["keep_shape"]

    # packbits() pads the bitstream to a byte boundary; trim back to the
    # true latent size before reshaping to (1, IMG_BITS, H, W).
    z = np.unpackbits(z)[:IMG_BITS * z_shape[0] * z_shape[1]].astype('float')
    z = z.reshape([1, IMG_BITS] + z_shape)

    # no_grad for consistency with encode(): inference only, no autograd
    # graph should be built for the decoder pass.
    with torch.no_grad():
        img = decoder(torch.Tensor(z).to(device))

    if keep_shape:
        # Crop away the edge padding encode() added around the original.
        left, top, right, bottom = compute_padding(img_shape)
        img = img[0, :, top:img.shape[2] - bottom, left:img.shape[3] - right]
    else:
        img = img[0]

    # Decoder output is unconstrained; clamp to the valid [0, 1] float-image
    # range before PIL conversion. (Dropped leftover debug st.write of the
    # tensor shape that was leaking into the UI.)
    return VF.to_pil_image(img.clamp(0, 1))
|
297 |
|
298 |
|
299 |
st.title("Clip Guided Binary Autoencoder")
|
|
|
309 |
|
310 |
with encoder_tab:
    # Two-column layout: input controls on the left, encoded result and a
    # round-trip preview on the right.
    col_in, col_out = st.columns(2)
    keep_shape = col_in.checkbox(
        'Use original size of input image instead of rescaling (Experimental)')
    uploaded_file = col_in.file_uploader('Choose an Image')
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        col_in.image(image, 'Input Image')
        # Encode, show the compact latent string, then immediately decode it
        # back so the user can judge the reconstruction quality.
        z_str = encode(model_prefix, image, keep_shape)
        col_out.write("Encoded to:")
        col_out.code(z_str, language=None)
        col_out.image(decode(model_prefix, z_str), 'Output Image preview')
|