Spaces:
Runtime error
Runtime error
Commit
·
627b096
1
Parent(s):
0c6725e
Auto deploy
Browse files- stCompressService.py +10 -8
stCompressService.py
CHANGED
@@ -12,6 +12,8 @@ from mcquic.datasets.transforms import AlignedCrop
|
|
12 |
from mcquic.utils.specification import File
|
13 |
from mcquic.utils.vision import DeTransform
|
14 |
|
|
|
|
|
15 |
try:
|
16 |
import streamlit as st
|
17 |
except:
|
@@ -31,11 +33,11 @@ def loadModel(device):
|
|
31 |
model = Compressor(**config.Model.Params).to(device).eval()
|
32 |
model.QuantizationParameter = "qp_2_msssim"
|
33 |
model.load_state_dict(ckpt["model"])
|
34 |
-
return model
|
35 |
|
36 |
|
37 |
@st.cache
|
38 |
-
def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
|
39 |
image = convert_image_dtype(image)
|
40 |
|
41 |
if crop:
|
@@ -46,17 +48,17 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
|
|
46 |
|
47 |
with model.readyForCoding() as cdfs:
|
48 |
codes, size = model.encode(image[None, ...])
|
49 |
-
binaries, headers = model.compress(codes, size, cdfs)
|
50 |
|
51 |
return File(headers[0], binaries[0])
|
52 |
|
53 |
|
54 |
@st.cache
|
55 |
-
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
56 |
binaries = sourceFile.Content
|
57 |
|
58 |
with model.readyForCoding() as cdfs:
|
59 |
-
codes, imageSize = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
|
60 |
# [1, c, h, w]
|
61 |
restored = model.decode(codes, imageSize)
|
62 |
|
@@ -71,7 +73,7 @@ def main():
|
|
71 |
else:
|
72 |
device = torch.device("cuda")
|
73 |
|
74 |
-
model = loadModel(device)
|
75 |
|
76 |
st.sidebar.markdown("""
|
77 |
<p align="center">
|
@@ -133,7 +135,7 @@ def main():
|
|
133 |
|
134 |
st.text(str(binaryFile))
|
135 |
|
136 |
-
result = decompressImage(binaryFile, model)
|
137 |
st.image(result.cpu().permute(1, 2, 0).numpy())
|
138 |
|
139 |
downloadButton = st.empty()
|
@@ -162,7 +164,7 @@ def main():
|
|
162 |
return
|
163 |
image = pil_to_tensor(image.convert("RGB")).to(device)
|
164 |
# st.image(image.cpu().permute(1, 2, 0).numpy())
|
165 |
-
result = compressImage(image, model, cropping)
|
166 |
|
167 |
st.text(str(result))
|
168 |
|
|
|
12 |
from mcquic.utils.specification import File
|
13 |
from mcquic.utils.vision import DeTransform
|
14 |
|
15 |
+
from mcquic.rans import RansEncoder, RansDecoder
|
16 |
+
|
17 |
try:
|
18 |
import streamlit as st
|
19 |
except:
|
|
|
33 |
model = Compressor(**config.Model.Params).to(device).eval()
|
34 |
model.QuantizationParameter = "qp_2_msssim"
|
35 |
model.load_state_dict(ckpt["model"])
|
36 |
+
return torch.jit.script(model), RansEncoder(), RansDecoder()
|
37 |
|
38 |
|
39 |
@st.cache
|
40 |
+
def compressImage(encoder: RansEncoder, image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
|
41 |
image = convert_image_dtype(image)
|
42 |
|
43 |
if crop:
|
|
|
48 |
|
49 |
with model.readyForCoding() as cdfs:
|
50 |
codes, size = model.encode(image[None, ...])
|
51 |
+
binaries, headers = model.compress(encoder, codes, size, cdfs)
|
52 |
|
53 |
return File(headers[0], binaries[0])
|
54 |
|
55 |
|
56 |
@st.cache
|
57 |
+
def decompressImage(decoder: RansDecoder, sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
58 |
binaries = sourceFile.Content
|
59 |
|
60 |
with model.readyForCoding() as cdfs:
|
61 |
+
codes, imageSize = model.decompress(decoder, [binaries], cdfs, [sourceFile.FileHeader])
|
62 |
# [1, c, h, w]
|
63 |
restored = model.decode(codes, imageSize)
|
64 |
|
|
|
73 |
else:
|
74 |
device = torch.device("cuda")
|
75 |
|
76 |
+
model, encoder, decoder = loadModel(device)
|
77 |
|
78 |
st.sidebar.markdown("""
|
79 |
<p align="center">
|
|
|
135 |
|
136 |
st.text(str(binaryFile))
|
137 |
|
138 |
+
result = decompressImage(encoder, binaryFile, model)
|
139 |
st.image(result.cpu().permute(1, 2, 0).numpy())
|
140 |
|
141 |
downloadButton = st.empty()
|
|
|
164 |
return
|
165 |
image = pil_to_tensor(image.convert("RGB")).to(device)
|
166 |
# st.image(image.cpu().permute(1, 2, 0).numpy())
|
167 |
+
result = compressImage(decoder, image, model, cropping)
|
168 |
|
169 |
st.text(str(result))
|
170 |
|