Spaces:
Runtime error
Runtime error
Commit
·
277a2e9
1
Parent(s):
3740f48
Auto deploy
Browse files- stCompressService.py +5 -5
stCompressService.py
CHANGED
@@ -25,7 +25,7 @@ HF_SPACE = "HF_SPACE" in os.environ
|
|
25 |
|
26 |
|
27 |
@st.experimental_singleton
|
28 |
-
def loadModel(
|
29 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
30 |
|
31 |
config = Config.deserialize(ckpt["config"])
|
@@ -65,13 +65,13 @@ def decompressImage(sourceFile: File, model: BaseCompressor) -> torch.ByteTensor
|
|
65 |
|
66 |
|
67 |
|
68 |
-
def main(
|
69 |
-
if
|
70 |
device = torch.device("cpu")
|
71 |
else:
|
72 |
device = torch.device("cuda")
|
73 |
|
74 |
-
model = loadModel(
|
75 |
|
76 |
st.sidebar.markdown("""
|
77 |
<p align="center">
|
@@ -215,4 +215,4 @@ def main(debug: bool, quiet: bool, qp: int, disable_gpu: bool):
|
|
215 |
|
216 |
if __name__ == "__main__":
|
217 |
with torch.inference_mode():
|
218 |
-
main(
|
|
|
25 |
|
26 |
|
27 |
@st.experimental_singleton
|
28 |
+
def loadModel(device):
|
29 |
ckpt = torch.hub.load_state_dict_from_url(MODELS_URL, map_location=device, check_hash=True)
|
30 |
|
31 |
config = Config.deserialize(ckpt["config"])
|
|
|
65 |
|
66 |
|
67 |
|
68 |
+
def main():
|
69 |
+
if not torch.cuda.is_available():
|
70 |
device = torch.device("cpu")
|
71 |
else:
|
72 |
device = torch.device("cuda")
|
73 |
|
74 |
+
model = loadModel(device).eval()
|
75 |
|
76 |
st.sidebar.markdown("""
|
77 |
<p align="center">
|
|
|
215 |
|
216 |
if __name__ == "__main__":
|
217 |
with torch.inference_mode():
|
218 |
+
main()
|