File size: 5,313 Bytes
7f56050 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import json
from KMVE_RG.models.SGF_model import SGF
from KMVE_RG.modules.tokenizers import Tokenizer
from KMVE_RG.modules.metrics import compute_scores
import numpy as np
from utils.thyroid_gen_config import config as thyroid_args
from utils.liver_gen_config import config as liver_args
from utils.breast_gen_config import config as breast_args
import gradio as gr
import torch
from PIL import Image
import os
from torchvision import transforms
# Seed NumPy's and PyTorch's global RNGs with a fixed value and ask cuDNN for
# deterministic, non-autotuned kernels, so repeated demo runs aim to produce
# the same outputs.
_SEED = 9233
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(_SEED)
np.random.seed(_SEED)
class Generator(object):
    """Report generator for one organ type, backed by a pretrained SGF model.

    Everything organ-specific (checkpoint path, image directory, annotation
    file, max sequence length, ...) is read from the config object selected
    by ``model_type``.
    """

    def __init__(self, model_type):
        """Load the config, tokenizer, SGF weights and annotation data.

        Args:
            model_type: '甲状腺' (thyroid), '乳腺' (breast) or '肝脏' (liver),
                matching the choices offered by the demo UI.

        Raises:
            ValueError: if ``model_type`` is not one of those three values.
        """
        if model_type == '甲状腺':
            self.args = thyroid_args
        elif model_type == '乳腺':
            self.args = breast_args
        elif model_type == '肝脏':
            self.args = liver_args
        else:
            # Fail fast with a clear message rather than an AttributeError on
            # ``self.args`` a few lines further down.
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                "expected one of '甲状腺', '乳腺', '肝脏'"
            )
        self.tokenizer = Tokenizer(self.args)
        self.model = SGF(self.args, self.tokenizer)
        # Nothing in this class moves the model or its inputs to a GPU, so
        # inference runs on CPU; map_location='cpu' also lets a checkpoint
        # that was saved from a GPU load on a machine without CUDA.
        sd = torch.load(self.args.models, map_location='cpu')['state_dict']
        msg = self.model.load_state_dict(sd)
        print(msg)  # load_state_dict's result (missing / unexpected keys)
        self.model.eval()
        self.metrics = compute_scores
        # Resize to 224x224, convert to a tensor, and normalise with the
        # ImageNet channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        # 'utf-8-sig' tolerates a leading BOM in the annotation file.
        with open(self.args.ann_path, 'r', encoding='utf-8-sig') as f:
            self.data = json.load(f)
        print('模型加载完成')  # "model loaded"

    def _open_image(self, rel_path):
        """Open ``rel_path`` (relative to ``args.image_dir``) as an RGB image.

        The file is opened in a ``with`` block, so its handle is released
        before returning; ``convert`` yields an independent in-memory image.
        """
        with Image.open(os.path.join(self.args.image_dir, rel_path)) as img:
            return img.convert('RGB')

    def image_process(self, img_paths):
        """Load, transform and stack the first two images of a study.

        Args:
            img_paths: Sequence of image paths relative to ``args.image_dir``;
                only the first two entries are used.

        Returns:
            ``torch.stack`` of the two transformed images along a new dim 0;
            with the transform built in ``__init__`` that is (2, 3, 224, 224).
        """
        image_1 = self._open_image(img_paths[0])
        image_2 = self._open_image(img_paths[1])
        if self.transform is not None:
            image_1 = self.transform(image_1)
            image_2 = self.transform(image_2)
        return torch.stack((image_1, image_2), 0)

    def generate(self, uid):
        """Generate a report for study ``uid`` and score it against the reference.

        Args:
            uid: Key into the annotation dict loaded in ``__init__``.

        Returns:
            Tuple ``(pred, gt, scores)``: the decoded generated report, the
            decoded reference report, and the result of ``compute_scores``
            for that single prediction/reference pair.
        """
        sample = self.data[uid]
        img_paths, report = sample['img_paths'], sample['report']
        # Add a leading batch dimension of size 1 to the stacked image pair.
        imgs = self.image_process(img_paths).unsqueeze(0)
        with torch.no_grad():
            output, _ = self.model(imgs, mode='sample')
        pred = self.tokenizer.decode(output[0].cpu().numpy())
        # Round-trip the (truncated) reference through the tokenizer,
        # presumably so it is normalised the same way as the prediction;
        # [1:] drops the first token id (presumably a BOS token - confirm
        # against Tokenizer).
        gt = self.tokenizer.decode(
            self.tokenizer(report[:self.args.max_seq_length])[1:])
        # compute_scores takes {id: [text]} dicts: references first, then
        # predictions.
        scores = self.metrics({0: [gt]}, {0: [pred]})
        return pred, gt, scores

    def visualize_images(self, uid):
        """Return the first two images of study ``uid`` as RGB PIL images."""
        img_paths = self.data[uid]['img_paths']
        return self._open_image(img_paths[0]), self._open_image(img_paths[1])
# Main application
def demo():
    """Build the Gradio UI of the ultrasound report generation demo.

    Workflow: pick an organ model, pick a uid, inspect that study's two
    images, then generate a report that is shown next to the ground-truth
    report and its NLP metric scores.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("# 超声报告生成Demo")
        gr.Markdown('### SIAT认知与交互技术中心')
        gr.Markdown('### 项目主页:https://lijunrio.github.io/Ultrasound-Report-Generation/')

        # Model selection.
        with gr.Row():
            model_choice = gr.Radio(choices=["甲状腺", "乳腺", "肝脏"], label="请选择模型类型", interactive=True)
        # Holds this session's Generator once a model type has been chosen.
        model = gr.State()

        # UID radio buttons; disabled until a model has been loaded.
        uids = [f"uid_{i}" for i in range(20)]
        with gr.Row():
            uid_choice = gr.Radio(choices=uids, label="请选择uid", interactive=False)

        # Image display components.
        with gr.Row():
            image1_display = gr.Image(label="图像1", visible=True)
            image2_display = gr.Image(label="图像2", visible=True)

        # Generate button and output text boxes.
        generate_button = gr.Button("生成报告", interactive=False)
        generated_report_display = gr.Textbox(label="生成的报告", visible=True)
        ground_truth_display = gr.Textbox(label="Ground Truth报告", visible=True)
        nlp_score_display = gr.Textbox(label="NLP得分", visible=True)

        def load_model_and_uids(model_type):
            """Load the chosen Generator and reset every uid-dependent widget.

            Clearing the uid selection, images and outputs (and disabling the
            generate button) keeps results from the previous model's dataset
            off the screen and forces a fresh uid pick for the new model.
            """
            generator = Generator(model_type)
            return (
                generator,
                gr.update(value=None, interactive=True),
                None,
                None,
                gr.update(interactive=False),
                "",
                "",
                "",
            )

        def on_uid_click(model, uid):
            """Show the two images of ``uid`` and enable the generate button."""
            if model is None or uid is None:
                # E.g. the selection was cleared after a model switch.
                return None, None, gr.update(interactive=False)
            image1, image2 = model.visualize_images(uid)
            return image1, image2, gr.update(interactive=True)

        def on_generate_click(model, uid):
            """Generate a report for ``uid``; return (report, ground truth, score text)."""
            generated_report, ground_truth_report, nlp_score = model.generate(uid)
            return generated_report, ground_truth_report, f"NLP得分: {nlp_score}"

        # Model choice -> load the model, enable and reset the uid selection.
        model_choice.change(
            load_model_and_uids,
            inputs=model_choice,
            outputs=[model, uid_choice, image1_display, image2_display, generate_button,
                     generated_report_display, ground_truth_display, nlp_score_display])
        # uid choice -> show that study's images.
        uid_choice.change(on_uid_click, inputs=[model, uid_choice],
                          outputs=[image1_display, image2_display, generate_button])
        # Generate button -> report, ground truth and scores.
        generate_button.click(on_generate_click, inputs=[model, uid_choice],
                              outputs=[generated_report_display, ground_truth_display, nlp_score_display])
    return app
if __name__ == '__main__':
    # Build the Gradio UI and serve it with Gradio's default launch settings.
    app = demo()
    app.launch()
|