Thedatababbler
tested
fcca540
import gradio as gr
import cv2
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from collections import defaultdict
# PubMedBERT tokenizer and masked-language-model head, created once at module
# level so every call to mlm() reuses the same instances.
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
model = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
def mlm(image, text):
    """Describe ``text`` by letting the masked language model fill in prompts.

    Three prompts about ``text`` (location, color, shape) are built, each
    containing one ``[MASK]`` token. For every prompt the top-1 prediction of
    the module-level ``model`` at the mask position is taken as the answer.

    Args:
        image: Returned unchanged alongside the generated description.
        text: Name or phrase to describe; interpolated into each prompt.

    Returns:
        A ``(image, ans)`` tuple, where ``ans`` is a one-line string
        summarising the predicted color, shape and location.
    """
    print(text)
    questions_dict = {
        'location': f'This {text} is at [MASK] place',
        'color': f'This {text} is in [MASK] color',
        'shape': f'This {text} is in [MASK] shape',
    }
    res = {}
    device = 'cpu'
    for k, v in questions_dict.items():
        print(v)
        # Calling the tokenizer (instead of tokenize()) adds the [CLS]/[SEP]
        # special tokens and builds the matching attention mask and
        # segment ids.
        inputs = tokenizer(v, return_tensors='pt').to(device)
        input_ids = inputs['input_ids'][0].tolist()
        # The prompt's own [MASK] always comes after ``text``, so use the
        # last occurrence in case the user typed a mask token as well.
        masked_index = (
            len(input_ids) - 1 - input_ids[::-1].index(tokenizer.mask_token_id)
        )
        with torch.no_grad():
            # Pass tensors by keyword: the model's second positional
            # parameter is attention_mask, not token_type_ids.
            predictions = model(**inputs)
        _, predicted_index = torch.topk(predictions[0][0][masked_index], 1)
        predicted_index = predicted_index.tolist()
        print(predicted_index)
        predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_index)
        print(predicted_tokens)
        res[k] = predicted_tokens[0]
    color, shape, loc = res['color'], res['shape'], res['location']
    ans = f'{color} color, {shape} shape, {text} at {loc}'
    print(ans)
    return image, ans
def to_black(image, text):
    """Convert ``image`` to grayscale and return it together with ``text``.

    Args:
        image: Color image array in RGB channel order, which is the order
            Gradio uses for NumPy image inputs.
        text: Returned unchanged.

    Returns:
        A two-element list ``[gray_image, text]``.
    """
    # Gradio hands images to the callback as RGB, so COLOR_BGR2GRAY would
    # apply the red and blue luminance weights to the wrong channels.
    output = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return [output, text]
# Wire mlm() into a Gradio UI: it receives an image and a text string and
# returns an image and a text string, matching mlm's (image, ans) result.
interface = gr.Interface(fn=mlm, inputs=["image", "text"], outputs=["image", "text"])
# Start the app when this script is executed.
interface.launch()