xiaosu-zhu commited on
Commit
2a072c6
·
1 Parent(s): ef5107a

Auto deploy

Browse files
Files changed (1) hide show
  1. stCompressService.py +9 -9
stCompressService.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import pathlib
3
  import torch
4
  import torch.hub
5
  from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
@@ -29,11 +28,10 @@ def loadModel(device):
29
  ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
30
 
31
  config = Config.deserialize(ckpt["config"])
32
- model = Compressor(**config.Model.Params).to(device)
33
  model.QuantizationParameter = "qp_2_msssim"
34
  model.load_state_dict(ckpt["model"])
35
- return model
36
-
37
 
38
 
39
  @st.cache
@@ -46,8 +44,9 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
46
  # [c, h, w]
47
  image = (image - 0.5) * 2
48
 
49
- with model._quantizer.readyForCoding() as cdfs:
50
- codes, binaries, headers = model.compress(image[None, ...], cdfs)
 
51
 
52
  return File(headers[0], binaries[0])
53
 
@@ -56,9 +55,10 @@ def compressImage(image: torch.Tensor, model: BaseCompressor, crop: bool) -> Fil
56
  def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
57
  binaries = sourceFile.Content
58
 
59
- with model._quantizer.readyForCoding() as cdfs:
 
60
  # [1, c, h, w]
61
- restored = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
62
 
63
  # [c, h, w]
64
  return DeTransform()(restored[0])
@@ -71,7 +71,7 @@ def main():
71
  else:
72
  device = torch.device("cuda")
73
 
74
- model = loadModel(device).eval()
75
 
76
  st.sidebar.markdown("""
77
  <p align="center">
 
1
  import os
 
2
  import torch
3
  import torch.hub
4
  from torchvision.transforms.functional import convert_image_dtype, pil_to_tensor
 
28
  ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
29
 
30
  config = Config.deserialize(ckpt["config"])
31
+ model = Compressor(**config.Model.Params).to(device).eval()
32
  model.QuantizationParameter = "qp_2_msssim"
33
  model.load_state_dict(ckpt["model"])
34
+ return torch.jit.script(model)
 
35
 
36
 
37
  @st.cache
 
44
  # [c, h, w]
45
  image = (image - 0.5) * 2
46
 
47
+ with model.readyForCoding() as cdfs:
48
+ codes, size = model.encode(image[None, ...])
49
+ binaries, headers = model.compress(codes, size, cdfs)
50
 
51
  return File(headers[0], binaries[0])
52
 
 
55
  def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor:
56
  binaries = sourceFile.Content
57
 
58
+ with model.readyForCoding() as cdfs:
59
+ codes, imageSize = model.decompress([binaries], cdfs, [sourceFile.FileHeader])
60
  # [1, c, h, w]
61
+ restored = model.decode(codes, imageSize)
62
 
63
  # [c, h, w]
64
  return DeTransform()(restored[0])
 
71
  else:
72
  device = torch.device("cuda")
73
 
74
+ model = loadModel(device)
75
 
76
  st.sidebar.markdown("""
77
  <p align="center">