radames commited on
Commit
4f05465
·
1 Parent(s): 3e73135

performance

Browse files
Files changed (2) hide show
  1. interface/app.py +8 -2
  2. interface/model_loader.py +1 -1
interface/app.py CHANGED
@@ -10,6 +10,9 @@ import io
10
  from huggingface_hub import snapshot_download
11
  import json
12
 
 
 
 
13
  models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
14
 
15
 
@@ -52,7 +55,10 @@ default_dxdysxsy = json.dumps(
52
  )
53
 
54
  def cv_to_pil(img):
55
- return Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
 
 
 
56
 
57
 
58
  def random_sample(model_name: str):
@@ -175,5 +181,5 @@ Double click to add or remove stop points.
175
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
176
  )
177
 
178
- block.queue(api_open=False)
179
  block.launch(show_api=False)
 
10
  from huggingface_hub import snapshot_download
11
  import json
12
 
13
+ # disable if running on another environment
14
+ RESIZE = True
15
+
16
  models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
17
 
18
 
 
55
  )
56
 
57
  def cv_to_pil(img):
58
+ img = Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
59
+ if RESIZE:
60
+ img = img.resize((128, 128))
61
+ return img
62
 
63
 
64
  def random_sample(model_name: str):
 
181
  random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
182
  )
183
 
184
+ # block.queue(api_open=False)
185
  block.launch(show_api=False)
interface/model_loader.py CHANGED
@@ -12,7 +12,7 @@ class Model:
12
  ):
13
  self.truncation = truncation
14
  self.use_average_code_as_input = use_average_code_as_input
15
- ckpt = torch.load(checkpoint_path, map_location="cpu")
16
  opts = ckpt["opts"]
17
  opts["checkpoint_path"] = checkpoint_path
18
  self.opts = Namespace(**ckpt["opts"])
 
12
  ):
13
  self.truncation = truncation
14
  self.use_average_code_as_input = use_average_code_as_input
15
+ ckpt = torch.load(checkpoint_path, map_location="cuda")
16
  opts = ckpt["opts"]
17
  opts["checkpoint_path"] = checkpoint_path
18
  self.opts = Namespace(**ckpt["opts"])