Thedatababbler commited on
Commit
080a3e4
·
1 Parent(s): d0f0ee4
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -19,6 +19,7 @@ def mlm(image, text):
19
  }
20
  ans = list()
21
  res = defaultdict(list)
 
22
  for k, v in questions_dict.items():
23
  predicted_tokens = []
24
  tokenized_text = tokenizer.tokenize(v)
@@ -27,8 +28,9 @@ def mlm(image, text):
27
  segments_ids = [0] * len(tokenized_text)
28
 
29
  # Convert inputs to PyTorch tensors
30
- tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
31
- segments_tensors = torch.tensor([segments_ids]).to('cuda')
 
32
 
33
  masked_index = tokenized_text.index('[MASK]')
34
  with torch.no_grad():
 
19
  }
20
  ans = list()
21
  res = defaultdict(list)
22
+ device = 'cpu'
23
  for k, v in questions_dict.items():
24
  predicted_tokens = []
25
  tokenized_text = tokenizer.tokenize(v)
 
28
  segments_ids = [0] * len(tokenized_text)
29
 
30
  # Convert inputs to PyTorch tensors
31
+
32
+ tokens_tensor = torch.tensor([indexed_tokens]).to(device)
33
+ segments_tensors = torch.tensor([segments_ids]).to(device)
34
 
35
  masked_index = tokenized_text.index('[MASK]')
36
  with torch.no_grad():