Nguyen Thai Thao Uyen commited on
Commit
59dad06
·
1 Parent(s): ebbf608

run.py update device

Browse files
Files changed (1) hide show
  1. run.py +3 -0
run.py CHANGED
@@ -9,6 +9,9 @@ import PIL
9
  def pred(src):
10
  # os.environ['HUGGINGFACE_HUB_HOME'] = './.cache'
11
  # Load the model configuration
 
 
 
12
  cache_dir = "/code/cache"
13
  model_config = SamConfig.from_pretrained("facebook/sam-vit-base",
14
  cache_dir=cache_dir)
 
9
  def pred(src):
10
  # os.environ['HUGGINGFACE_HUB_HOME'] = './.cache'
11
  # Load the model configuration
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ model.to(device)
14
+
15
  cache_dir = "/code/cache"
16
  model_config = SamConfig.from_pretrained("facebook/sam-vit-base",
17
  cache_dir=cache_dir)