Thedatababbler commited on
Commit
d5b900f
·
1 Parent(s): bc22792
Files changed (1) hide show
  1. app.py +42 -1
app.py CHANGED
@@ -1,10 +1,51 @@
1
  import gradio as gr
2
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  def to_black(image, text):
5
  output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
6
  outputs = [output, text]
7
  return outputs
8
 
9
- interface = gr.Interface(fn=to_black, inputs=["image", "text"], outputs=["image", "text"])
10
  interface.launch()
 
1
  import gradio as gr
2
  import cv2
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
5
+ from collections import defaultdict
6
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
7
+ model = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
8
+
9
+ def mlm(image, text):
10
+ questions_dict = {
11
+ #'location': f'[CLS] Only [MASK] cells have a {cls_name}. [SEP]', #num of mask?
12
+ # 'location': f'[CLS] The {cls_name} normally appears at or near the [MASK] of a cell. [SEP]',
13
+ # 'color': f'[CLS] When a cell is histologically stained, the {cls_name} are in [MASK] color. [SEP]',
14
+ # 'shape': f'[CLS] Mostly the shape of {cls_name} is [MASK]. [SEP]',
15
+ 'location': f'[CLS] The location of {text} is at [MASK]. [SEP]',
16
+ 'color': f'[CLS] The typical color of {text} is [MASK]. [SEP]',
17
+ 'shape': f'[CLS] The typical shape of {text} is [MASK]. [SEP]',
18
+ #'def': f'{cls_name} is a . [SEP]',
19
+ }
20
+ ans = list()
21
+ res = defaultdict(list)
22
+ for k, v in questions_dict.items():
23
+ predicted_tokens = []
24
+ tokenized_text = tokenizer.tokenize(v)
25
+ indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
26
+ # Create the segments tensors.
27
+ segments_ids = [0] * len(tokenized_text)
28
+
29
+ # Convert inputs to PyTorch tensors
30
+ tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
31
+ segments_tensors = torch.tensor([segments_ids]).to('cuda')
32
+
33
+ masked_index = tokenized_text.index('[MASK]')
34
+ with torch.no_grad():
35
+ predictions = model(tokens_tensor, segments_tensors)
36
+
37
+ _, predicted_index = torch.topk(predictions[0][0][masked_index], topk)#.item()
38
+ predicted_index = predicted_index.detach().cpu().numpy()
39
+ #print(predicted_index)
40
+ for idx in predicted_index:
41
+ predicted_tokens.append(tokenizer.convert_ids_to_tokens([idx])[0])
42
+
43
+ return image, res
44
 
45
  def to_black(image, text):
46
  output = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
47
  outputs = [output, text]
48
  return outputs
49
 
50
+ interface = gr.Interface(fn=mlm, inputs=["image", "text"], outputs=["image", "text"])
51
  interface.launch()