Spaces:
Runtime error
Runtime error
performance
Browse files- interface/app.py +8 -2
- interface/model_loader.py +1 -1
interface/app.py
CHANGED
@@ -10,6 +10,9 @@ import io
|
|
10 |
from huggingface_hub import snapshot_download
|
11 |
import json
|
12 |
|
|
|
|
|
|
|
13 |
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
|
14 |
|
15 |
|
@@ -52,7 +55,10 @@ default_dxdysxsy = json.dumps(
|
|
52 |
)
|
53 |
|
54 |
def cv_to_pil(img):
|
55 |
-
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
def random_sample(model_name: str):
|
@@ -175,5 +181,5 @@ Double click to add or remove stop points.
|
|
175 |
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
|
176 |
)
|
177 |
|
178 |
-
block.queue(api_open=False)
|
179 |
block.launch(show_api=False)
|
|
|
10 |
from huggingface_hub import snapshot_download
|
11 |
import json
|
12 |
|
13 |
+
# disable if running on another environment
|
14 |
+
RESIZE = True
|
15 |
+
|
16 |
models_path = snapshot_download(repo_id="radames/UserControllableLT", repo_type="model")
|
17 |
|
18 |
|
|
|
55 |
)
|
56 |
|
57 |
def cv_to_pil(img):
|
58 |
+
img = Image.fromarray(cv2.cvtColor(img.astype("uint8"), cv2.COLOR_BGR2RGB))
|
59 |
+
if RESIZE:
|
60 |
+
img = img.resize((128, 128))
|
61 |
+
return img
|
62 |
|
63 |
|
64 |
def random_sample(model_name: str):
|
|
|
181 |
random_sample, inputs=[model_name], outputs=[image, model_state, latents_state]
|
182 |
)
|
183 |
|
184 |
+
# block.queue(api_open=False)
|
185 |
block.launch(show_api=False)
|
interface/model_loader.py
CHANGED
@@ -12,7 +12,7 @@ class Model:
|
|
12 |
):
|
13 |
self.truncation = truncation
|
14 |
self.use_average_code_as_input = use_average_code_as_input
|
15 |
-
ckpt = torch.load(checkpoint_path, map_location="
|
16 |
opts = ckpt["opts"]
|
17 |
opts["checkpoint_path"] = checkpoint_path
|
18 |
self.opts = Namespace(**ckpt["opts"])
|
|
|
12 |
):
|
13 |
self.truncation = truncation
|
14 |
self.use_average_code_as_input = use_average_code_as_input
|
15 |
+
ckpt = torch.load(checkpoint_path, map_location="cuda")
|
16 |
opts = ckpt["opts"]
|
17 |
opts["checkpoint_path"] = checkpoint_path
|
18 |
self.opts = Namespace(**ckpt["opts"])
|