Spaces:
Runtime error
Runtime error
Commit
·
2a072c6
1
Parent(s):
ef5107a
Auto deploy
Browse files- stCompressService.py +9 -9
stCompressService.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import os
|
2 |
-
import pathlib
|
3 |
import torch
|
4 |
import torch.hub
|
5 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
@@ -29,11 +28,10 @@ def loadModel(device):
|
|
29 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
30 |
|
31 |
config = Config.deserialize(ckpt["config"])
|
32 |
-
model = Compressor(**config.Model.Params).to(device)
|
33 |
model.QuantizationParameter = "qp_2_msssim"
|
34 |
model.load_state_dict(ckpt["model"])
|
35 |
-
return model
|
36 |
-
|
37 |
|
38 |
|
39 |
@st.cache
|
@@ -46,8 +44,9 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
|
|
46 |
# [c, h, w]
|
47 |
image = (image - 0.5) * 2
|
48 |
|
49 |
-
with model.
|
50 |
-
codes,
|
|
|
51 |
|
52 |
return File(headers[0], binaries[0])
|
53 |
|
@@ -56,9 +55,10 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
|
|
56 |
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
57 |
binaries = sourceFile.Content
|
58 |
|
59 |
-
with model.
|
|
|
60 |
# [1, c, h, w]
|
61 |
-
restored = model.
|
62 |
|
63 |
# [c, h, w]
|
64 |
return DeTransform()(restored[0])
|
@@ -71,7 +71,7 @@ def main():
|
|
71 |
else:
|
72 |
device = torch.device("cuda")
|
73 |
|
74 |
-
model = loadModel(device)
|
75 |
|
76 |
st.sidebar.markdown("""
|
77 |
<p align="center">
|
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
import torch.hub
|
4 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
|
|
28 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
29 |
|
30 |
config = Config.deserialize(ckpt["config"])
|
31 |
+
model = Compressor(**config.Model.Params).to(device).eval()
|
32 |
model.QuantizationParameter = "qp_2_msssim"
|
33 |
model.load_state_dict(ckpt["model"])
|
34 |
+
return torch.jit.script(model)
|
|
|
35 |
|
36 |
|
37 |
@st.cache
|
|
|
44 |
# [c, h, w]
|
45 |
image = (image - 0.5) * 2
|
46 |
|
47 |
+
with model.readyForCoding() as cdfs:
|
48 |
+
codes, size = model.encode(image[None, ...])
|
49 |
+
binaries, headers = model.compress(codes, size, cdfs)
|
50 |
|
51 |
return File(headers[0], binaries[0])
|
52 |
|
|
|
55 |
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
56 |
binaries = sourceFile.Content
|
57 |
|
58 |
+
with model.readyForCoding() as cdfs:
|
59 |
+
codes, imageSize = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
|
60 |
# [1, c, h, w]
|
61 |
+
restored = model.decode(codes, imageSize)
|
62 |
|
63 |
# [c, h, w]
|
64 |
return DeTransform()(restored[0])
|
|
|
71 |
else:
|
72 |
device = torch.device("cuda")
|
73 |
|
74 |
+
model = loadModel(device)
|
75 |
|
76 |
st.sidebar.markdown("""
|
77 |
<p align="center">
|