Spaces:
Sleeping
Sleeping
update
Browse files- model_factory.py +0 -3
- models/ram.py +3 -1
model_factory.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import torch
|
2 |
from models.ram import RAM
|
3 |
|
4 |
-
from huggingface_hub import hf_hub_download
|
5 |
-
|
6 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
7 |
|
8 |
def get_model():
|
@@ -16,5 +14,4 @@ def get_model():
|
|
16 |
:return: model
|
17 |
"""
|
18 |
model = RAM()
|
19 |
-
model.load_state_dict(torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))
|
20 |
return model
|
|
|
1 |
import torch
|
2 |
from models.ram import RAM
|
3 |
|
|
|
|
|
4 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
5 |
|
6 |
def get_model():
|
|
|
14 |
:return: model
|
15 |
"""
|
16 |
model = RAM()
|
|
|
17 |
return model
|
models/ram.py
CHANGED
@@ -7,7 +7,9 @@ from deepinv.physics import Physics, LinearPhysics, Downsampling
|
|
7 |
from deepinv.utils import TensorList
|
8 |
from deepinv.utils.tensorlist import TensorList
|
9 |
|
10 |
-
|
|
|
|
|
11 |
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
|
12 |
|
13 |
class RAM(nn.Module):
|
|
|
7 |
from deepinv.utils import TensorList
|
8 |
from deepinv.utils.tensorlist import TensorList
|
9 |
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
|
12 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
|
14 |
|
15 |
class RAM(nn.Module):
|