Spaces:
Runtime error
Runtime error
Commit
·
ba2bd35
1
Parent(s):
627b096
Auto deploy
Browse files- stCompressService.py +12 -13
stCompressService.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import os
|
|
|
2 |
import torch
|
3 |
import torch.hub
|
4 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
@@ -12,8 +13,6 @@ from mcquic.datasets.transforms import AlignedCrop
|
|
12 |
from mcquic.utils.specification import File
|
13 |
from mcquic.utils.vision import DeTransform
|
14 |
|
15 |
-
from mcquic.rans import RansEncoder, RansDecoder
|
16 |
-
|
17 |
try:
|
18 |
import streamlit as st
|
19 |
except:
|
@@ -30,14 +29,15 @@ def loadModel(device):
|
|
30 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
31 |
|
32 |
config = Config.deserialize(ckpt["config"])
|
33 |
-
model = Compressor(**config.Model.Params).to(device)
|
34 |
model.QuantizationParameter = "qp_2_msssim"
|
35 |
model.load_state_dict(ckpt["model"])
|
36 |
-
return
|
|
|
37 |
|
38 |
|
39 |
@st.cache
|
40 |
-
def compressImage(
|
41 |
image = convert_image_dtype(image)
|
42 |
|
43 |
if crop:
|
@@ -47,20 +47,19 @@ def compressImage(encoder: RansEncoder, image: torch.Tensor, model: BaseCompress
|
|
47 |
image = (image - 0.5) * 2
|
48 |
|
49 |
with model.readyForCoding() as cdfs:
|
50 |
-
codes, size = model.encode(
|
51 |
-
binaries, headers = model.compress(
|
52 |
|
53 |
return File(headers[0], binaries[0])
|
54 |
|
55 |
|
56 |
@st.cache
|
57 |
-
def decompressImage(
|
58 |
binaries = sourceFile.Content
|
59 |
|
60 |
with model.readyForCoding() as cdfs:
|
61 |
-
codes, imageSize = model.decompress(decoder, [binaries], cdfs, [sourceFile.FileHeader])
|
62 |
# [1, c, h, w]
|
63 |
-
restored = model.
|
64 |
|
65 |
# [c, h, w]
|
66 |
return DeTransform()(restored[0])
|
@@ -73,7 +72,7 @@ def main():
|
|
73 |
else:
|
74 |
device = torch.device("cuda")
|
75 |
|
76 |
-
model
|
77 |
|
78 |
st.sidebar.markdown("""
|
79 |
<p align="center">
|
@@ -135,7 +134,7 @@ def main():
|
|
135 |
|
136 |
st.text(str(binaryFile))
|
137 |
|
138 |
-
result = decompressImage(
|
139 |
st.image(result.cpu().permute(1, 2, 0).numpy())
|
140 |
|
141 |
downloadButton = st.empty()
|
@@ -164,7 +163,7 @@ def main():
|
|
164 |
return
|
165 |
image = pil_to_tensor(image.convert("RGB")).to(device)
|
166 |
# st.image(image.cpu().permute(1, 2, 0).numpy())
|
167 |
-
result = compressImage(
|
168 |
|
169 |
st.text(str(result))
|
170 |
|
|
|
1 |
import os
|
2 |
+
import pathlib
|
3 |
import torch
|
4 |
import torch.hub
|
5 |
from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
|
|
|
13 |
from mcquic.utils.specification import File
|
14 |
from mcquic.utils.vision import DeTransform
|
15 |
|
|
|
|
|
16 |
try:
|
17 |
import streamlit as st
|
18 |
except:
|
|
|
29 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
30 |
|
31 |
config = Config.deserialize(ckpt["config"])
|
32 |
+
model = Compressor(**config.Model.Params).to(device)
|
33 |
model.QuantizationParameter = "qp_2_msssim"
|
34 |
model.load_state_dict(ckpt["model"])
|
35 |
+
return model
|
36 |
+
|
37 |
|
38 |
|
39 |
@st.cache
|
40 |
+
def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
|
41 |
image = convert_image_dtype(image)
|
42 |
|
43 |
if crop:
|
|
|
47 |
image = (image - 0.5) * 2
|
48 |
|
49 |
with model.readyForCoding() as cdfs:
|
50 |
+
codes, size = model.encode(images)
|
51 |
+
binaries, headers = model.compress(self._encoder, codes, size, cdfs)
|
52 |
|
53 |
return File(headers[0], binaries[0])
|
54 |
|
55 |
|
56 |
@st.cache
|
57 |
+
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
58 |
binaries = sourceFile.Content
|
59 |
|
60 |
with model.readyForCoding() as cdfs:
|
|
|
61 |
# [1, c, h, w]
|
62 |
+
restored = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
|
63 |
|
64 |
# [c, h, w]
|
65 |
return DeTransform()(restored[0])
|
|
|
72 |
else:
|
73 |
device = torch.device("cuda")
|
74 |
|
75 |
+
model = loadModel(device).eval()
|
76 |
|
77 |
st.sidebar.markdown("""
|
78 |
<p align="center">
|
|
|
134 |
|
135 |
st.text(str(binaryFile))
|
136 |
|
137 |
+
result = decompressImage(binaryFile, model)
|
138 |
st.image(result.cpu().permute(1, 2, 0).numpy())
|
139 |
|
140 |
downloadButton = st.empty()
|
|
|
163 |
return
|
164 |
image = pil_to_tensor(image.convert("RGB")).to(device)
|
165 |
# st.image(image.cpu().permute(1, 2, 0).numpy())
|
166 |
+
result = compressImage(image, model, cropping)
|
167 |
|
168 |
st.text(str(result))
|
169 |
|