Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4396298
1
Parent(s):
752e7b9
Add cache_dir in from_pretrained.
Browse files
app.py
CHANGED
@@ -19,10 +19,11 @@ import zipfile
|
|
19 |
|
20 |
|
21 |
transformers.utils.move_cache()
|
22 |
-
|
23 |
-
os.environ["
|
24 |
-
os.environ["
|
25 |
-
os.
|
|
|
26 |
|
27 |
|
28 |
torch.set_float32_matmul_precision('high')
|
@@ -96,7 +97,7 @@ usage_to_weights_file = {
|
|
96 |
'General-dynamic': 'BiRefNet_dynamic',
|
97 |
}
|
98 |
|
99 |
-
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
|
100 |
birefnet.to(device)
|
101 |
birefnet.eval(); birefnet.half()
|
102 |
|
@@ -109,7 +110,7 @@ def predict(images, resolution, weights_file):
|
|
109 |
# Load BiRefNet with chosen weights
|
110 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
111 |
print('Using weights: {}.'.format(_weights_file))
|
112 |
-
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
|
113 |
birefnet.to(device)
|
114 |
birefnet.eval(); birefnet.half()
|
115 |
|
|
|
19 |
|
20 |
|
21 |
transformers.utils.move_cache()
|
22 |
+
hf_cache_path = '/tmp/hf_cache'
|
23 |
+
os.environ["HUGGINGFACE_HUB_CACHE"] = hf_cache_path
|
24 |
+
os.environ["HF_HOME"] = hf_cache_path
|
25 |
+
os.environ["TRANSFORMERS_CACHE"] = hf_cache_path
|
26 |
+
os.makedirs(hf_cache_path, exist_ok=True)
|
27 |
|
28 |
|
29 |
torch.set_float32_matmul_precision('high')
|
|
|
97 |
'General-dynamic': 'BiRefNet_dynamic',
|
98 |
}
|
99 |
|
100 |
+
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True, cache_dir=hf_cache_path)
|
101 |
birefnet.to(device)
|
102 |
birefnet.eval(); birefnet.half()
|
103 |
|
|
|
110 |
# Load BiRefNet with chosen weights
|
111 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
112 |
print('Using weights: {}.'.format(_weights_file))
|
113 |
+
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True, cache_dir=hf_cache_path)
|
114 |
birefnet.to(device)
|
115 |
birefnet.eval(); birefnet.half()
|
116 |
|