Thedatababbler committed on
Commit
0a95385
·
1 Parent(s): a596520
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -18,7 +18,7 @@ def mlm(image, text):
18
  #'def': f'{cls_name} is a . [SEP]',
19
  }
20
  ans = list()
21
- res = defaultdict(list)
22
  device = 'cpu'
23
  for k, v in questions_dict.items():
24
  predicted_tokens = []
@@ -36,15 +36,18 @@ def mlm(image, text):
36
  with torch.no_grad():
37
  predictions = model(tokens_tensor, segments_tensors)
38
 
39
- _, predicted_index = torch.topk(predictions[0][0][masked_index], 2)#.item()
40
  predicted_index = predicted_index.detach().cpu().numpy()
41
  #print(predicted_index)
42
  for idx in predicted_index:
43
  predicted_tokens.append(tokenizer.convert_ids_to_tokens([idx])[0])
44
- for i in range(2):
45
- res[text][k].append(predicted_tokens)
 
 
 
46
 
47
- return image, res
48
 
49
  def to_black(image, text):
50
  output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
18
  #'def': f'{cls_name} is a . [SEP]',
19
  }
20
  ans = list()
21
+ res = defaultdict()
22
  device = 'cpu'
23
  for k, v in questions_dict.items():
24
  predicted_tokens = []
 
36
  with torch.no_grad():
37
  predictions = model(tokens_tensor, segments_tensors)
38
 
39
+ _, predicted_index = torch.topk(predictions[0][0][masked_index], 1)#.item()
40
  predicted_index = predicted_index.detach().cpu().numpy()
41
  #print(predicted_index)
42
  for idx in predicted_index:
43
  predicted_tokens.append(tokenizer.convert_ids_to_tokens([idx])[0])
44
+ # for i in range(1):
45
+ # res[text][k].append(predicted_tokens)
46
+ res[k] = predicted_tokens[0]
47
+ color, shape, loc = res['color'], res['shape'], res['location']
48
+ ans = f'{color} color, {shape} shape, cat at {loc}'
49
 
50
+ return image, ans
51
 
52
  def to_black(image, text):
53
  output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)