Spaces:
Running
on
Zero
Running
on
Zero
Assuming GPUZero is always available.
Browse files- app.py +7 -8
- utils/predict.py +1 -1
app.py
CHANGED
@@ -10,29 +10,28 @@ from pathlib import Path
|
|
10 |
from PIL import Image
|
11 |
|
12 |
from plots import get_pre_define_colors
|
13 |
-
|
14 |
from utils.predict import xclip_pred
|
15 |
|
16 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
-
# def initialize_model():
|
18 |
-
# global XCLIP, OWLVIT_PRECESSOR
|
19 |
-
# if XCLIP is None or OWLVIT_PRECESSOR is None:
|
20 |
-
# XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
|
21 |
|
22 |
#! Huggingface does not allow load model to main process, so we need to load the model when needed, it may not help in improve the speed of the app.
|
23 |
try:
|
24 |
import spaces
|
25 |
XCLIP, OWLVIT_PRECESSOR = None, None
|
|
|
26 |
except:
|
|
|
27 |
print(f"Not at Huggingface demo, load model to main process.")
|
28 |
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
|
29 |
-
|
|
|
|
|
30 |
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
|
31 |
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
|
32 |
IMAGES_FOLDER = "data/images"
|
33 |
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
|
34 |
IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
|
35 |
-
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
|
36 |
CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
|
37 |
CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
|
38 |
|
|
|
10 |
from PIL import Image
|
11 |
|
12 |
from plots import get_pre_define_colors
|
13 |
+
from utils.load_model import load_xclip
|
14 |
from utils.predict import xclip_pred
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
#! Huggingface does not allow load model to main process, so we need to load the model when needed, it may not help in improve the speed of the app.
|
18 |
try:
|
19 |
import spaces
|
20 |
XCLIP, OWLVIT_PRECESSOR = None, None
|
21 |
+
DEVICE = 'cuda'
|
22 |
except:
|
23 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
24 |
print(f"Not at Huggingface demo, load model to main process.")
|
25 |
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
|
26 |
+
|
27 |
+
print(f"Device: {DEVICE}")
|
28 |
+
|
29 |
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
|
30 |
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
|
31 |
IMAGES_FOLDER = "data/images"
|
32 |
# XCLIP_RESULTS = json.load(open("data/jsons/xclip_org.json", "r"))
|
33 |
IMAGE2GT = json.load(open("data/jsons/image2gt.json", 'r'))
|
34 |
+
CUB_DESC_EMBEDS = torch.load('data/text_embeddings/cub_200_desc.pt')
|
35 |
CUB_IDX2NAME = json.load(open('data/jsons/cub_desc_idx2name.json', 'r'))
|
36 |
CUB_IDX2NAME = {int(k): v for k, v in CUB_IDX2NAME.items()}
|
37 |
|
utils/predict.py
CHANGED
@@ -91,7 +91,7 @@ def xclip_pred(new_desc: dict,
|
|
91 |
modified_class_idx = 200
|
92 |
else:
|
93 |
n_classes = 200
|
94 |
-
query_embeds = cub_embeds
|
95 |
idx2name = cub_idx2name
|
96 |
modified_class_idx = None
|
97 |
|
|
|
91 |
modified_class_idx = 200
|
92 |
else:
|
93 |
n_classes = 200
|
94 |
+
query_embeds = cub_embeds.to(device)
|
95 |
idx2name = cub_idx2name
|
96 |
modified_class_idx = None
|
97 |
|