mterris commited on
Commit
9368b8d
·
1 Parent(s): a1a82a6
Files changed (2) hide show
  1. model_factory.py +0 -3
  2. models/ram.py +3 -1
model_factory.py CHANGED
@@ -1,8 +1,6 @@
1
  import torch
2
  from models.ram import RAM
3
 
4
- from huggingface_hub import hf_hub_download
5
-
6
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
7
 
8
  def get_model():
@@ -16,5 +14,4 @@ def get_model():
16
  :return: model
17
  """
18
  model = RAM()
19
- model.load_state_dict(torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))
20
  return model
 
1
  import torch
2
  from models.ram import RAM
3
 
 
 
4
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
5
 
6
  def get_model():
 
14
  :return: model
15
  """
16
  model = RAM()
 
17
  return model
models/ram.py CHANGED
@@ -7,7 +7,9 @@ from deepinv.physics import Physics, LinearPhysics, Downsampling
7
  from deepinv.utils import TensorList
8
  from deepinv.utils.tensorlist import TensorList
9
 
10
- cuda = True if torch.cuda.is_available() else False
 
 
11
  Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
12
 
13
  class RAM(nn.Module):
 
7
  from deepinv.utils import TensorList
8
  from deepinv.utils.tensorlist import TensorList
9
 
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
  Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
14
 
15
  class RAM(nn.Module):