Alexander Becker commited on
Commit
22a2526
·
1 Parent(s): 9077960

Use HF Hub download

Browse files
Files changed (2) hide show
  1. app.py +12 -10
  2. checkpoints/thera-edsr-plus.pkl +0 -3
app.py CHANGED
@@ -3,15 +3,24 @@ import json
3
  import os
4
 
5
  import gradio as gr
6
- from gradio_dualvision import DualVisionApp
7
- from gradio_dualvision.gradio_patches.radio import Radio
8
  from PIL import Image
9
  import numpy as np
10
 
 
 
 
11
  from model import build_thera
12
  from super_resolve import process
13
 
14
- CHECKPOINT = "checkpoints/thera-edsr-plus.pkl"
 
 
 
 
 
 
 
 
15
 
16
 
17
  class TheraApp(DualVisionApp):
@@ -73,13 +82,6 @@ class TheraApp(DualVisionApp):
73
  round(source.shape[1] * scale),
74
  )
75
 
76
- # load model
77
- with open(CHECKPOINT, 'rb') as fh:
78
- check = pickle.load(fh)
79
- params, backbone, size = check['model'], check['backbone'], check['size']
80
-
81
- model = build_thera(3, backbone, size)
82
-
83
  out = process(source, model, params, target_shape, do_ensemble=do_ensemble)
84
  out = Image.fromarray(np.asarray(out))
85
 
 
3
  import os
4
 
5
  import gradio as gr
 
 
6
  from PIL import Image
7
  import numpy as np
8
 
9
+ from gradio_dualvision import DualVisionApp
10
+ from gradio_dualvision.gradio_patches.radio import Radio
11
+ from huggingface_hub import hf_hub_download
12
  from model import build_thera
13
  from super_resolve import process
14
 
15
+ REPO_ID = "prs-eth/thera-edsr-plus"
16
+
17
+ # load model
18
+ model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pkl")
19
+ with open(model_path, 'rb') as fh:
20
+ check = pickle.load(fh)
21
+ params, backbone, size = check['model'], check['backbone'], check['size']
22
+
23
+ model = build_thera(3, backbone, size)
24
 
25
 
26
  class TheraApp(DualVisionApp):
 
82
  round(source.shape[1] * scale),
83
  )
84
 
 
 
 
 
 
 
 
85
  out = process(source, model, params, target_shape, do_ensemble=do_ensemble)
86
  out = Image.fromarray(np.asarray(out))
87
 
checkpoints/thera-edsr-plus.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a805ca6f0486d9eba8f228200340a0e6aedde16529e11fc7b98dc26d830d9aa8
3
- size 31632862