xiaosu-zhu commited on
Commit
627b096
·
1 Parent(s): 0c6725e

Auto deploy

Browse files
Files changed (1) hide show
  1. stCompressService.py +10 -8
stCompressService.py CHANGED
@@ -12,6 +12,8 @@ from mcquic.datasets.transforms import AlignedCrop
12
  from mcquic.utils.specification import File
13
  from mcquic.utils.vision import DeTransform
14
 
 
 
15
  try:
16
  import streamlit as st
17
  except:
@@ -31,11 +33,11 @@ def loadModel(device):
31
  model = Compressor(**config.Model.Params).to(device).eval()
32
  model.QuantizationParameter = "qp_2_msssim"
33
  model.load_state_dict(ckpt["model"])
34
- return model
35
 
36
 
37
  @st.cache
38
- def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
39
  image = convert_image_dtype(image)
40
 
41
  if crop:
@@ -46,17 +48,17 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
46
 
47
  with model.readyForCoding() as cdfs:
48
  codes, size = model.encode(image[None, ...])
49
- binaries, headers = model.compress(codes, size, cdfs)
50
 
51
  return File(headers[0], binaries[0])
52
 
53
 
54
  @st.cache
55
- def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
56
  binaries = sourceFile.Content
57
 
58
  with model.readyForCoding() as cdfs:
59
- codes, imageSize = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
60
  # [1, c, h, w]
61
  restored = model.decode(codes, imageSize)
62
 
@@ -71,7 +73,7 @@ def main():
71
  else:
72
  device = torch.device("cuda")
73
 
74
- model = loadModel(device)
75
 
76
  st.sidebar.markdown("""
77
  <p align="center">
@@ -133,7 +135,7 @@ def main():
133
 
134
  st.text(str(binaryFile))
135
 
136
- result = decompressImage(binaryFile, model)
137
  st.image(result.cpu().permute(1, 2, 0).numpy())
138
 
139
  downloadButton = st.empty()
@@ -162,7 +164,7 @@ def main():
162
  return
163
  image = pil_to_tensor(image.convert("RGB")).to(device)
164
  # st.image(image.cpu().permute(1, 2, 0).numpy())
165
- result = compressImage(image, model, cropping)
166
 
167
  st.text(str(result))
168
 
 
12
  from mcquic.utils.specification import File
13
  from mcquic.utils.vision import DeTransform
14
 
15
+ from mcquic.rans import RansEncoder, RansDecoder
16
+
17
  try:
18
  import streamlit as st
19
  except:
 
33
  model = Compressor(**config.Model.Params).to(device).eval()
34
  model.QuantizationParameter = "qp_2_msssim"
35
  model.load_state_dict(ckpt["model"])
36
+ return torch.jit.script(model), RansEncoder(), RansDecoder()
37
 
38
 
39
  @st.cache
40
+ def compressImage(encoder: RansEncoder, image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
41
  image = convert_image_dtype(image)
42
 
43
  if crop:
 
48
 
49
  with model.readyForCoding() as cdfs:
50
  codes, size = model.encode(image[None, ...])
51
+ binaries, headers = model.compress(encoder, codes, size, cdfs)
52
 
53
  return File(headers[0], binaries[0])
54
 
55
 
56
  @st.cache
57
+ def decompressImage(decoder: RansDecoder, sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
58
  binaries = sourceFile.Content
59
 
60
  with model.readyForCoding() as cdfs:
61
+ codes, imageSize = model.decompress(decoder, [binaries], cdfs, [sourceFile.FileHeader])
62
  # [1, c, h, w]
63
  restored = model.decode(codes, imageSize)
64
 
 
73
  else:
74
  device = torch.device("cuda")
75
 
76
+ model, encoder, decoder = loadModel(device)
77
 
78
  st.sidebar.markdown("""
79
  <p align="center">
 
135
 
136
  st.text(str(binaryFile))
137
 
138
+ result = decompressImage(encoder, binaryFile, model)
139
  st.image(result.cpu().permute(1, 2, 0).numpy())
140
 
141
  downloadButton = st.empty()
 
164
  return
165
  image = pil_to_tensor(image.convert("RGB")).to(device)
166
  # st.image(image.cpu().permute(1, 2, 0).numpy())
167
+ result = compressImage(decoder, image, model, cropping)
168
 
169
  st.text(str(result))
170