xiaosu-zhu commited on
Commit
ba2bd35
·
1 Parent(s): 627b096

Auto deploy

Browse files
Files changed (1) hide show
  1. stCompressService.py +12 -13
stCompressService.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
  import torch.hub
4
  from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
@@ -12,8 +13,6 @@ from mcquic.datasets.transforms import AlignedCrop
12
  from mcquic.utils.specification import File
13
  from mcquic.utils.vision import DeTransform
14
 
15
- from mcquic.rans import RansEncoder, RansDecoder
16
-
17
  try:
18
  import streamlit as st
19
  except:
@@ -30,14 +29,15 @@ def loadModel(device):
30
  ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
31
 
32
  config = Config.deserialize(ckpt["config"])
33
- model = Compressor(**config.Model.Params).to(device).eval()
34
  model.QuantizationParameter = "qp_2_msssim"
35
  model.load_state_dict(ckpt["model"])
36
- return torch.jit.script(model), RansEncoder(), RansDecoder()
 
37
 
38
 
39
  @st.cache
40
- def compressImage(encoder: RansEncoder, image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
41
  image = convert_image_dtype(image)
42
 
43
  if crop:
@@ -47,20 +47,19 @@ def compressImage(encoder: RansEncoder, image: torch.Tensor, model: BaseCompress
47
  image = (image - 0.5) * 2
48
 
49
  with model.readyForCoding() as cdfs:
50
- codes, size = model.encode(image[None, ...])
51
- binaries, headers = model.compress(encoder, codes, size, cdfs)
52
 
53
  return File(headers[0], binaries[0])
54
 
55
 
56
  @st.cache
57
- def decompressImage(decoder: RansDecoder, sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
58
  binaries = sourceFile.Content
59
 
60
  with model.readyForCoding() as cdfs:
61
- codes, imageSize = model.decompress(decoder, [binaries], cdfs, [sourceFile.FileHeader])
62
  # [1, c, h, w]
63
- restored = model.decode(codes, imageSize)
64
 
65
  # [c, h, w]
66
  return DeTransform()(restored[0])
@@ -73,7 +72,7 @@ def main():
73
  else:
74
  device = torch.device("cuda")
75
 
76
- model, encoder, decoder = loadModel(device)
77
 
78
  st.sidebar.markdown("""
79
  <p align="center">
@@ -135,7 +134,7 @@ def main():
135
 
136
  st.text(str(binaryFile))
137
 
138
- result = decompressImage(encoder, binaryFile, model)
139
  st.image(result.cpu().permute(1, 2, 0).numpy())
140
 
141
  downloadButton = st.empty()
@@ -164,7 +163,7 @@ def main():
164
  return
165
  image = pil_to_tensor(image.convert("RGB")).to(device)
166
  # st.image(image.cpu().permute(1, 2, 0).numpy())
167
- result = compressImage(decoder, image, model, cropping)
168
 
169
  st.text(str(result))
170
 
 
1
  import os
2
+ import pathlib
3
  import torch
4
  import torch.hub
5
  from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
 
13
  from mcquic.utils.specification import File
14
  from mcquic.utils.vision import DeTransform
15
 
 
 
16
  try:
17
  import streamlit as st
18
  except:
 
29
  ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
30
 
31
  config = Config.deserialize(ckpt["config"])
32
+ model = Compressor(**config.Model.Params).to(device)
33
  model.QuantizationParameter = "qp_2_msssim"
34
  model.load_state_dict(ckpt["model"])
35
+ return model
36
+
37
 
38
 
39
  @st.cache
40
+ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> File:
41
  image = convert_image_dtype(image)
42
 
43
  if crop:
 
47
  image = (image - 0.5) * 2
48
 
49
  with model.readyForCoding() as cdfs:
50
+ codes, size = model.encode(images)
51
+ binaries, headers = model.compress(self._encoder, codes, size, cdfs)
52
 
53
  return File(headers[0], binaries[0])
54
 
55
 
56
  @st.cache
57
+ def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
58
  binaries = sourceFile.Content
59
 
60
  with model.readyForCoding() as cdfs:
 
61
  # [1, c, h, w]
62
+ restored = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
63
 
64
  # [c, h, w]
65
  return DeTransform()(restored[0])
 
72
  else:
73
  device = torch.device("cuda")
74
 
75
+ model = loadModel(device).eval()
76
 
77
  st.sidebar.markdown("""
78
  <p align="center">
 
134
 
135
  st.text(str(binaryFile))
136
 
137
+ result = decompressImage(binaryFile, model)
138
  st.image(result.cpu().permute(1, 2, 0).numpy())
139
 
140
  downloadButton = st.empty()
 
163
  return
164
  image = pil_to_tensor(image.convert("RGB")).to(device)
165
  # st.image(image.cpu().permute(1, 2, 0).numpy())
166
+ result = compressImage(image, model, cropping)
167
 
168
  st.text(str(result))
169