xiaosu-zhu committed
Commit d5a4886 · 1 Parent(s): 43a2e18

Init Deploy

Files changed (1)
  1. stCompressService.py +127 -0
stCompressService.py ADDED
@@ -0,0 +1,127 @@
+ import pathlib
+ import torch
+ import torch.hub
+ from torchvision.transforms.functional import convert_image_dtype, to_tensor
+ from torchvision.io.image import ImageReadMode, encode_png, decode_image
+
+ from mcquic import Config
+ from mcquic.modules.compressor import BaseCompressor, Compressor
+ from mcquic.datasets.transforms import AlignedCrop
+ import mcquic
+ from mcquic.utils.specification import File, FileHeader
+ from mcquic.utils.vision import DeTransform
+
+ try:
+     import streamlit as st
+ except ImportError:
+     raise ImportError("To run `mcquic service`, please install Streamlit first via `pip install streamlit`.")
+
+
+ MODELS_URL = "https://github.com/xiaosu-zhu/McQuic/releases/download/generic/"
+
+
+ @st.experimental_singleton
+ def loadModel(qp: int, local: pathlib.Path, device, mse: bool):
+     suffix = "mse" if mse else "msssim"
+     # ckpt = torch.hub.load_state_dict_from_url(MODELS_URL + f"qp_{qp}_{suffix}.mcquic", map_location=device)
+
+     # Load the checkpoint from a local file instead of fetching it from MODELS_URL.
+     ckpt = torch.load("./qp_3_msssim_fcc58b73.mcquic", map_location=device)
+
+     config = Config.deserialize(ckpt["config"])
+     model = Compressor(**config.Model.Params).to(device)
+     model.QuantizationParameter = str(local) if local is not None else str(qp)
+     model.load_state_dict(ckpt["model"])
+     return model
+
+
+
+ @st.cache
+ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
+     image = convert_image_dtype(image)
+
+     if crop:
+         image = AlignedCrop()(image)
+
+     # [c, h, w], map [0, 1] to [-1, 1]
+     image = (image - 0.5) * 2
+
+     with model._quantizer.readyForCoding() as cdfs:
+         codes, binaries, headers = model.compress(image[None, ...], cdfs)
+
+     return File(headers[0], binaries[0])
+
+
+ @st.cache
+ def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
+     binaries = sourceFile.Content
+
+     with model._quantizer.readyForCoding() as cdfs:
+         # [1, c, h, w]
+         restored = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
+
+     # [c, h, w]
+     return DeTransform()(restored[0])
+
+
+
+ def main(debug: bool, quiet: bool, disable_gpu: bool):
+     if disable_gpu or not torch.cuda.is_available():
+         device = torch.device("cpu")
+     else:
+         device = torch.device("cuda")
+
+     model = loadModel(3, None, device, False).eval()
+
+     st.sidebar.markdown("""
+ <p align="center">
+ <a href="https://github.com/xiaosu-zhu/McQuic">
+ <img src="https://raw.githubusercontent.com/xiaosu-zhu/McQuic/main/assets/McQuic-light.svg" alt="McQuic" title="McQuic" width="45%"/>
+ </a>
+ <br/>
+ <span>
+ <i>a.k.a.</i> <b><i>M</i></b>ulti-<b><i>c</i></b>odebook <b><i>Qu</i></b>antizers for neural <b><i>i</i></b>mage <b><i>c</i></b>ompression
+ </span>
+ </p>
+
+ <p align="center">
+ Compressing images on-the-fly.
+ </p>
+
+
+ <a href="#">
+ <img src="https://img.shields.io/badge/NOTE-yellow?style=for-the-badge" alt="NOTE"/>
+ </a>
+
+ > Due to resource limitations, I only provide the compression service with the `qp = 3` model.
+ """, unsafe_allow_html=True)
+
+
+     with st.form("SubmitForm"):
+         uploadedFile = st.file_uploader("Try running McQuic to compress or restore images!", type=["png", "jpg", "jpeg", "mcq"], help="Upload your image or compressed `.mcq` file here.")
+         cropping = st.checkbox("Crop image to align grids.", help="If checked, the image is cropped to align to feature map grids. This makes the output smaller.")
+         submitted = st.form_submit_button("Submit", help="Click to start compress/restore.")
+     if submitted and uploadedFile is not None:
+         if uploadedFile.name.endswith(".mcq"):
+             uploadedFile.flush()
+
+             binaryFile = File.deserialize(uploadedFile.read())
+
+             st.text(str(binaryFile))
+
+             result = decompressImage(binaryFile, model)
+             st.image(result.cpu().permute(1, 2, 0).numpy())
+             st.download_button("Click to download restored image", data=bytes(encode_png(result.cpu()).tolist()), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["png"]), mime="image/png")
+         else:
+             raw = torch.ByteTensor(torch.ByteStorage.from_buffer(uploadedFile.read()))  # type: ignore
+             image = decode_image(raw, ImageReadMode.RGB).to(device)
+             st.image(image.cpu().permute(1, 2, 0).numpy())
+             result = compressImage(image, model, cropping)
+
+             st.text(str(result))
+
+             st.download_button("Click to download compressed file", data=result.serialize(), file_name=".".join(uploadedFile.name.split(".")[:-1] + ["mcq"]), mime="image/mcq")
+
+
+ if __name__ == "__main__":
+     with torch.inference_mode():
+         main(False, False, True)
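
For local testing outside Streamlit, a minimal round-trip sketch of the same compress/restore path could look like the following. This is only a sketch, assuming the script is importable as `stCompressService`, that the checkpoint `qp_3_msssim_fcc58b73.mcquic` referenced in `loadModel` sits in the working directory, that Streamlit's caching accepts these argument types outside a running app, and that `input.png`, `input.mcq`, and `restored.png` are placeholder paths:

```python
# roundtrip_sketch.py -- hypothetical smoke test for the compress/restore path above.
import torch
from torchvision.io import read_file, write_png
from torchvision.io.image import ImageReadMode, decode_image

from mcquic.utils.specification import File
from stCompressService import loadModel, compressImage, decompressImage

with torch.inference_mode():
    device = torch.device("cpu")
    model = loadModel(3, None, device, False).eval()

    # Image -> .mcq: read an RGB image, compress it, and write the serialized container.
    image = decode_image(read_file("input.png"), ImageReadMode.RGB).to(device)
    compressed = compressImage(image, model, True)
    with open("input.mcq", "wb") as fp:
        fp.write(compressed.serialize())

    # .mcq -> image: deserialize the container, decompress, and save a PNG.
    with open("input.mcq", "rb") as fp:
        restored = decompressImage(File.deserialize(fp.read()), model)
    write_png(restored.cpu(), "restored.png")
```

The service itself would be launched as a regular Streamlit app, e.g. `streamlit run stCompressService.py`, which enters `main(False, False, True)` under `torch.inference_mode()` as in the guard above.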