Spaces:
Runtime error
Runtime error
Commit
·
d5a4886
1
Parent(s):
43a2e18
Init Deploy
Browse files- stCompressService.py +127 -0
stCompressService.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pathlib
|
2 |
+
import torch
|
3 |
+
import torch.hub
|
4 |
+
from torchvision.transforms.functional import convert_image_dtype, to_tensor
|
5 |
+
from torchvision.io.image import ImageReadMode, encode_png, decode_image
|
6 |
+
|
7 |
+
from mcquic import Config
|
8 |
+
from mcquic.modules.compressor import BaseCompressor, Compressor
|
9 |
+
from mcquic.datasets.transforms import AlignedCrop
|
10 |
+
import mcquic
|
11 |
+
from mcquic.utils.specification import File, FileHeader
|
12 |
+
from mcquic.utils.vision import DeTransform
|
13 |
+
|
14 |
+
try:
|
15 |
+
import streamlit as st
|
16 |
+
except:
|
17 |
+
raise ImportError("To run `mcquic service`, please install Streamlit by `pip install streamlit` firstly.")
|
18 |
+
|
19 |
+
|
20 |
+
MODELS_URL = "https://github.com/xiaosu-zhu/McQuic/releases/download/generic/"
|
21 |
+
|
22 |
+
|
23 |
+
@st.experimental_singleton
|
24 |
+
def loadModel(qp: int, local: pathlib.Path, device, mse: bool):
|
25 |
+
suffix = "mse" if mse else "msssim"
|
26 |
+
# ckpt = torch.hub.load_state_dict_from_url(MODELS_URL + f"qp_{qp}_{suffix}.mcquic", map_location=device)
|
27 |
+
|
28 |
+
ckpt = torch.load("./qp_3_msssim_fcc58b73.mcquic", map_location=device)
|
29 |
+
|
30 |
+
config = Config.deserialize(ckpt["config"])
|
31 |
+
model = Compressor(**config.Model.Params).to(device)
|
32 |
+
model.QuantizationParameter = str(local) if local is not None else str(qp)
|
33 |
+
model.load_state_dict(ckpt["model"])
|
34 |
+
return model
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
@st.cache
|
39 |
+
def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
|
40 |
+
image = convert_image_dtype(image)
|
41 |
+
|
42 |
+
if crop:
|
43 |
+
image = AlignedCrop()(image)
|
44 |
+
|
45 |
+
# [c, h, w]
|
46 |
+
image = (image - 0.5) * 2
|
47 |
+
|
48 |
+
with model._quantizer.readyForCoding() as cdfs:
|
49 |
+
codes, binaries, headers = model.compress(image[None, ...], cdfs)
|
50 |
+
|
51 |
+
return File(headers[0], binaries[0])
|
52 |
+
|
53 |
+
|
54 |
+
@st.cache
|
55 |
+
def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
|
56 |
+
binaries = sourceFile.Content
|
57 |
+
|
58 |
+
with model._quantizer.readyForCoding() as cdfs:
|
59 |
+
# [1, c, h, w]
|
60 |
+
restored = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
|
61 |
+
|
62 |
+
# [c, h, w]
|
63 |
+
return DeTransform()(restored[0])
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def main(debug: bool, quiet: bool, disable_gpu: bool):
|
68 |
+
if disable_gpu or not torch.cuda.is_available():
|
69 |
+
device = torch.device("cpu")
|
70 |
+
else:
|
71 |
+
device = torch.device("cuda")
|
72 |
+
|
73 |
+
model = loadModel(3, None, device, False).eval()
|
74 |
+
|
75 |
+
st.sidebar.markdown("""
|
76 |
+
<p align="center">
|
77 |
+
<a href="https://github.com/xiaosu-zhu/McQuic">
|
78 |
+
<img src="https://raw.githubusercontent.com/xiaosu-zhu/McQuic/main/assets/McQuic-light.svg" alt="McQuic" title="McQuic" width="45%"/>
|
79 |
+
</a>
|
80 |
+
<br/>
|
81 |
+
<span>
|
82 |
+
<i>a.k.a.</i> <b><i>M</i></b>ulti-<b><i>c</i></b>odebook <b><i>Qu</i></b>antizers for neural <b><i>i</i></b>mage <b><i>c</i></b>ompression
|
83 |
+
</span>
|
84 |
+
</p>
|
85 |
+
|
86 |
+
<p align="center">
|
87 |
+
Compressing images on-the-fly.
|
88 |
+
</p>
|
89 |
+
|
90 |
+
|
91 |
+
<a href="#">
|
92 |
+
<image src="https://img.shields.io/badge/NOTE-yellow?style=for-the-badge" alt="NOTE"/>
|
93 |
+
</a>
|
94 |
+
|
95 |
+
> Due to resources limitation, I only provide compression service with model `qp = 3`.
|
96 |
+
""", unsafe_allow_html=True)
|
97 |
+
|
98 |
+
|
99 |
+
with st.form("SubmitForm"):
|
100 |
+
uploadedFile = st.file_uploader("Try running McQuic to compress or restore images!", type=["png", "jpg", "jpeg", "mcq"], help="Upload your image or compressed `.mcq` file here.")
|
101 |
+
cropping = st.checkbox("Cropping image to align grids.", help="If checked, the image is cropped to align to feature map grids. This makes output smaller.")
|
102 |
+
submitted = st.form_submit_button("Submit", help="Click to start compress/restore.")
|
103 |
+
if submitted and uploadedFile is not None:
|
104 |
+
if uploadedFile.name.endswith(".mcq"):
|
105 |
+
uploadedFile.flush()
|
106 |
+
|
107 |
+
binaryFile = File.deserialize(uploadedFile.read())
|
108 |
+
|
109 |
+
st.text(str(binaryFile))
|
110 |
+
|
111 |
+
result = decompressImage(binaryFile, model)
|
112 |
+
st.image(result.cpu().permute(1, 2, 0).numpy())
|
113 |
+
st.download_button("Click to download restored image", data=bytes(encode_png(result.cpu()).tolist()), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["png"]), mime="image/png")
|
114 |
+
else:
|
115 |
+
raw = torch.ByteTensor(torch.ByteStorage.from_buffer(uploadedFile.read())) # type: ignore
|
116 |
+
image = decode_image(raw, ImageReadMode.RGB).to(device)
|
117 |
+
st.image(image.cpu().permute(1, 2, 0).numpy())
|
118 |
+
result = compressImage(image, model, cropping)
|
119 |
+
|
120 |
+
st.text(str(result))
|
121 |
+
|
122 |
+
st.download_button("Click to download compressed file", data=result.serialize(), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["mcq"]), mime="image/mcq")
|
123 |
+
|
124 |
+
|
125 |
+
if __name__ == "__main__":
|
126 |
+
with torch.inference_mode():
|
127 |
+
main(False, False, True)
|