File size: 5,313 Bytes
7f56050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json

from KMVE_RG.models.SGF_model import SGF
from KMVE_RG.modules.tokenizers import Tokenizer
from KMVE_RG.modules.metrics import compute_scores
import numpy as np
from utils.thyroid_gen_config import config as thyroid_args
from utils.liver_gen_config import config as liver_args
from utils.breast_gen_config import config as breast_args

import gradio as gr
import torch
from PIL import Image
import os
from torchvision import transforms

# Fix RNG seeds and force deterministic cuDNN kernels so report
# generation is reproducible across runs.
np.random.seed(9233)
torch.manual_seed(9233)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class Generator(object):
    """Wraps a trained SGF model for ultrasound report generation.

    One instance serves a single organ type ('甲状腺' thyroid, '乳腺'
    breast, or '肝脏' liver): it loads that organ's config, tokenizer,
    model checkpoint and annotation file, and generates reports for
    samples identified by uid.
    """

    def __init__(self, model_type):
        """Load config, tokenizer, model weights and annotation data.

        Args:
            model_type: one of '甲状腺', '乳腺', '肝脏'.

        Raises:
            ValueError: if ``model_type`` is not a known organ type
                (previously this silently left ``self.args`` unset and
                failed later with a confusing AttributeError).
        """
        configs = {'甲状腺': thyroid_args, '乳腺': breast_args, '肝脏': liver_args}
        try:
            self.args = configs[model_type]
        except KeyError:
            raise ValueError(f'Unknown model type: {model_type!r}') from None
        self.tokenizer = Tokenizer(self.args)
        self.model = SGF(self.args, self.tokenizer)
        # map_location='cpu' so a checkpoint saved on GPU still loads on a
        # CPU-only host; the model is run on CPU in this demo anyway.
        sd = torch.load(self.args.models, map_location='cpu')['state_dict']
        msg = self.model.load_state_dict(sd)
        print(msg)
        self.model.eval()
        self.metrics = compute_scores
        # Standard ImageNet preprocessing; backbone presumably expects
        # 224x224 normalized input — confirm against SGF's training setup.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        # utf-8-sig tolerates a BOM at the start of the annotation file.
        with open(self.args.ann_path, 'r', encoding='utf-8-sig') as f:
            self.data = json.load(f)
        print('模型加载完成')

    def image_process(self, img_paths):
        """Load the study's two images and stack them into one tensor.

        Args:
            img_paths: sequence of two image paths relative to
                ``self.args.image_dir``.

        Returns:
            A tensor of shape (2, C, H, W) after preprocessing.
        """
        images = []
        for rel_path in img_paths[:2]:
            img = Image.open(os.path.join(self.args.image_dir, rel_path)).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            images.append(img)
        return torch.stack(images, 0)

    def generate(self, uid):
        """Generate a report for sample ``uid``.

        Returns:
            Tuple of (predicted report, ground-truth report, NLP scores
            computed by ``compute_scores``).
        """
        img_paths, report = self.data[uid]['img_paths'], self.data[uid]['report']
        imgs = self.image_process(img_paths)
        imgs = imgs.unsqueeze(0)  # add batch dimension: (1, 2, C, H, W)
        with torch.no_grad():
            output, _ = self.model(imgs, mode='sample')
        pred = self.tokenizer.decode(output[0].cpu().numpy())
        # Re-tokenize/decode the reference so it is truncated and normalized
        # the same way as model output; [1:] drops the leading BOS token.
        gt = self.tokenizer.decode(self.tokenizer(report[:self.args.max_seq_length])[1:])
        scores = self.metrics({0: [gt]}, {0: [pred]})
        return pred, gt, scores

    def visualize_images(self, uid):
        """Return the sample's two images as RGB PIL images (no preprocessing)."""
        first, second = self.data[uid]['img_paths'][:2]
        image_1 = Image.open(os.path.join(self.args.image_dir, first)).convert('RGB')
        image_2 = Image.open(os.path.join(self.args.image_dir, second)).convert('RGB')
        return image_1, image_2

# Main application: builds and wires the Gradio UI.
def demo():
    """Build the Gradio demo app for ultrasound report generation.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as app:
        gr.Markdown("# 超声报告生成Demo")
        gr.Markdown('### SIAT认知与交互技术中心')
        gr.Markdown('### 项目主页:https://lijunrio.github.io/Ultrasound-Report-Generation/')

        # Organ/model selector.
        with gr.Row():
            model_choice = gr.Radio(choices=["甲状腺", "乳腺", "肝脏"], label="请选择模型类型", interactive=True)

        # Holds the loaded Generator instance across callbacks.
        model = gr.State()

        # Sample selector; disabled until a model has been loaded.
        with gr.Row():
            uid_choice = gr.Radio(choices=[f"uid_{idx}" for idx in range(20)],
                                  label="请选择uid", interactive=False)

        # Preview panes for the selected study's two images.
        with gr.Row():
            image1_display = gr.Image(label="图像1", visible=True)
            image2_display = gr.Image(label="图像2", visible=True)

        # Report generation trigger and output boxes.
        generate_button = gr.Button("生成报告", interactive=False)
        generated_report_display = gr.Textbox(label="生成的报告", visible=True)
        ground_truth_display = gr.Textbox(label="Ground Truth报告", visible=True)
        nlp_score_display = gr.Textbox(label="NLP得分", visible=True)

        def load_model_and_uids(model_type):
            # Instantiate the generator for the chosen organ and unlock
            # the uid selector.
            return Generator(model_type), gr.update(interactive=True)

        def on_uid_click(generator, uid):
            # Show the study's two images and unlock the generate button.
            first, second = generator.visualize_images(uid)
            return first, second, gr.update(interactive=True)

        def on_generate_click(generator, uid):
            # Run generation and display prediction, reference and scores.
            pred, gt, scores = generator.generate(uid)
            return pred, gt, f"NLP得分: {scores}"

        # Wire: model selection -> load model, enable uid selection.
        model_choice.change(load_model_and_uids, inputs=model_choice,
                            outputs=[model, uid_choice])
        # Wire: uid selection -> preview images, enable generation.
        uid_choice.change(on_uid_click, inputs=[model, uid_choice],
                          outputs=[image1_display, image2_display, generate_button])
        # Wire: generate click -> show reports and metrics.
        generate_button.click(on_generate_click, inputs=[model, uid_choice],
                              outputs=[generated_report_display, ground_truth_display, nlp_score_display])

    return app

if __name__ == '__main__':
    # Launch the Gradio application.
    demo().launch()