SIAT-RZJS commited on
Commit
7f56050
·
verified ·
1 Parent(s): 2953191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -189
app.py CHANGED
@@ -1,189 +1,130 @@
1
- import json
2
-
3
- from KMVE_RG.models.SGF_model import SGF
4
- from KMVE_RG.modules.tokenizers import Tokenizer
5
- from KMVE_RG.modules.metrics import compute_scores
6
- import numpy as np
7
- from Demo.utils.thyroid_gen_config import config as thyroid_args
8
- from Demo.utils.liver_gen_config import config as liver_args
9
- from Demo.utils.breast_gen_config import config as breast_args
10
-
11
- import gradio as gr
12
- import torch
13
- from PIL import Image
14
- import os
15
- from torchvision import transforms
16
-
17
- np.random.seed(9233)
18
- torch.manual_seed(9233)
19
- torch.backends.cudnn.deterministic = True
20
- torch.backends.cudnn.benchmark = False
21
-
22
- class Generator(object):
23
- def __init__(self, model_type):
24
- if model_type == '甲状腺':
25
- self.args = thyroid_args
26
- elif model_type == '乳腺':
27
- self.args = breast_args
28
- elif model_type == '肝脏':
29
- self.args = liver_args
30
- self.tokenizer = Tokenizer(self.args)
31
- self.model = SGF(self.args, self.tokenizer)
32
- sd = torch.load(self.args.models)['state_dict']
33
- msg = self.model.load_state_dict(sd)
34
- print(msg)
35
- self.model.eval()
36
- self.metrics = compute_scores
37
- self.transform = transforms.Compose([
38
- transforms.Resize((224, 224)),
39
- transforms.ToTensor(),
40
- transforms.Normalize((0.485, 0.456, 0.406),
41
- (0.229, 0.224, 0.225))])
42
- with open(self.args.ann_path, 'r', encoding='utf-8-sig') as f:
43
- self.data = json.load(f)
44
- print('模型加载完成')
45
-
46
- def image_process(self, img_paths):
47
- image_1 = Image.open(os.path.join(self.args.image_dir, img_paths[0])).convert('RGB')
48
- image_2 = Image.open(os.path.join(self.args.image_dir, img_paths[1])).convert('RGB')
49
- if self.transform is not None:
50
- image_1 = self.transform(image_1)
51
- image_2 = self.transform(image_2)
52
- image = torch.stack((image_1, image_2), 0)
53
- return image
54
-
55
- def generate(self, uid):
56
- img_paths, report = self.data[uid]['img_paths'], self.data[uid]['report']
57
- imgs = self.image_process(img_paths)
58
- imgs = imgs.unsqueeze(0)
59
- with torch.no_grad():
60
- output, _ = self.model(imgs, mode='sample')
61
- pred = self.tokenizer.decode(output[0].cpu().numpy())
62
- gt = self.tokenizer.decode(self.tokenizer(report[:self.args.max_seq_length])[1:])
63
- scores = self.metrics({0: [gt]}, {0: [pred]})
64
- return pred, gt, scores
65
-
66
- def visualize_images(self, uid):
67
- image_1 = Image.open(os.path.join(self.args.image_dir, self.data[uid]['img_paths'][0])).convert('RGB')
68
- image_2 = Image.open(os.path.join(self.args.image_dir, self.data[uid]['img_paths'][1])).convert('RGB')
69
- return image_1, image_2
70
-
71
- # 主应用程序
72
- def demo():
73
- with gr.Blocks() as app:
74
- gr.Markdown("# 超声报告生成Demo")
75
- gr.Markdown('### SIAT认知与交互技术中心')
76
- gr.Markdown('### 项目主页:https://lijunrio.github.io/Ultrasound-Report-Generation/')
77
-
78
- # 选择模型
79
- with gr.Row():
80
- model_choice = gr.Radio(choices=["甲状腺", "乳腺", "肝脏"], label="请选择模型类型", interactive=True)
81
-
82
- model = gr.State()
83
-
84
- # 展示UID按钮
85
- uids = [f"uid_{i}" for i in range(20)]
86
- with gr.Row():
87
- uid_choice = gr.Radio(choices=[f"{uid}" for uid in uids], label="请选择uid", interactive=False)
88
-
89
- # 定义展示图片的组件
90
- with gr.Row():
91
- image1_display = gr.Image(label="图像1", visible=True)
92
- image2_display = gr.Image(label="图像2", visible=True)
93
-
94
- # 定义生成报告的按钮和文本框
95
- generate_button = gr.Button("生成报告", interactive=False)
96
- generated_report_display = gr.Textbox(label="生成的报告", visible=True)
97
- ground_truth_display = gr.Textbox(label="Ground Truth报告", visible=True)
98
- nlp_score_display = gr.Textbox(label="NLP得分", visible=True)
99
-
100
- # 加载模型的回调函数
101
- def load_model_and_uids(model_type):
102
- model = Generator(model_type)
103
- return model, gr.update(interactive=True)
104
-
105
- # 点击UID按钮后加载对应的图片
106
- def on_uid_click(model, uid):
107
- image1, image2 = model.visualize_images(uid)
108
- # 显示图片和生成按钮
109
- return image1, image2, gr.update(interactive=True)
110
-
111
- # 点击生成按钮生成报告
112
- def on_generate_click(model, uid):
113
- generated_report, ground_truth_report, nlp_score = model.generate(uid)
114
- # 展示生成的报告、Ground Truth 和 NLP 得分
115
- return generated_report, ground_truth_report, f"NLP得分: {nlp_score}"
116
-
117
- # 链接模型选择与UID按钮显示
118
- model_choice.change(load_model_and_uids, inputs=model_choice, outputs=[model, uid_choice])
119
-
120
- # 链接UID按钮点击与图片显示
121
- # for uid_button in uid_buttons:
122
- # uid_button.click(on_uid_click, inputs=[model, uid_button], outputs=[image1_display, image2_display, uid_choice])
123
- uid_choice.change(on_uid_click, inputs=[model, uid_choice], outputs=[image1_display, image2_display, generate_button])
124
-
125
- generate_button.click(on_generate_click, inputs=[model, uid_choice], outputs=[generated_report_display, ground_truth_display, nlp_score_display])
126
-
127
- return app
128
-
129
- # 主应用程序
130
- # def demo():
131
- # with gr.Blocks() as app:
132
- # gr.Markdown("# 医学报告生成 Demo")
133
- #
134
- # # 选择模型类型
135
- # model_choice = gr.Radio(choices=["甲状腺", "乳腺", "肝脏"], label="请选择模型类型", interactive=True)
136
- #
137
- # # 创建空的 Generator 实例(将稍后初始化)
138
- # generator_instance = gr.State()
139
- #
140
- # # 展示 UID 按钮
141
- # uids = [f"uid_{i}" for i in range(20)]
142
- # selected_uid = gr.State() # 用于存储当前选择的 UID
143
- # uid_buttons = [gr.Button(f"{uid}") for uid in uids]
144
- #
145
- # # 定义展示图片的组件
146
- # with gr.Row():
147
- # image1_display = gr.Image(label="图像 1", visible=False)
148
- # image2_display = gr.Image(label="图像 2", visible=False)
149
- #
150
- # # 定义生成报告的按钮和文本框
151
- # generate_button = gr.Button("生成报告", visible=False)
152
- # generated_report_display = gr.Textbox(label="生成的报告", visible=False)
153
- # ground_truth_display = gr.Textbox(label="Ground Truth 报告", visible=False)
154
- # nlp_score_display = gr.Textbox(label="NLP 得分", visible=False)
155
- #
156
- # # 模型选择后初始化 Generator 类
157
- # def initialize_generator(model_type):
158
- # generator = Generator(model_type) # 初始化 Generator
159
- # return generator, True # 返回生成器实例,显示 UID 按钮
160
- #
161
- # # 点击 UID 按钮后可视化对应图片
162
- # def on_uid_click(uid, generator):
163
- # image1, image2 = generator.visual_images(uid)
164
- # return image1, image2, uid, True # 返回图片、UID,显示生成按钮
165
- #
166
- # # 点击生成按钮生成 Ground Truth 报告、预测结果和 NLP 分数
167
- # def on_generate_click(generator, uid):
168
- # ground_truth, predict, nlp_score = generator.generate(uid)
169
- # return ground_truth, predict, nlp_score
170
- #
171
- # # 链接模型选择与生成器初始化
172
- # model_choice.change(initialize_generator, inputs=model_choice, outputs=[generator_instance, uid_buttons[0]])
173
- #
174
- # # 链接 UID 按钮点击与图片显示
175
- # for i, uid_button in enumerate(uid_buttons):
176
- # uid_button.click(on_uid_click, inputs=[selected_uid, generator_instance],
177
- # outputs=[image1_display, image2_display, selected_uid, generate_button],
178
- # fn=lambda uid=uids[i]: uid)
179
- #
180
- # # 点击生成按钮时生成报告
181
- # generate_button.click(on_generate_click, inputs=[generator_instance, selected_uid],
182
- # outputs=[ground_truth_display, generated_report_display, nlp_score_display])
183
- #
184
- # return app
185
-
186
- if __name__ == '__main__':
187
- # 启动应用程序
188
- demo().launch()
189
-
 
1
+ import json
2
+
3
+ from KMVE_RG.models.SGF_model import SGF
4
+ from KMVE_RG.modules.tokenizers import Tokenizer
5
+ from KMVE_RG.modules.metrics import compute_scores
6
+ import numpy as np
7
+ from utils.thyroid_gen_config import config as thyroid_args
8
+ from utils.liver_gen_config import config as liver_args
9
+ from utils.breast_gen_config import config as breast_args
10
+
11
+ import gradio as gr
12
+ import torch
13
+ from PIL import Image
14
+ import os
15
+ from torchvision import transforms
16
+
17
+ np.random.seed(9233)
18
+ torch.manual_seed(9233)
19
+ torch.backends.cudnn.deterministic = True
20
+ torch.backends.cudnn.benchmark = False
21
+
22
+ class Generator(object):
23
+ def __init__(self, model_type):
24
+ if model_type == '甲状腺':
25
+ self.args = thyroid_args
26
+ elif model_type == '乳腺':
27
+ self.args = breast_args
28
+ elif model_type == '肝脏':
29
+ self.args = liver_args
30
+ self.tokenizer = Tokenizer(self.args)
31
+ self.model = SGF(self.args, self.tokenizer)
32
+ sd = torch.load(self.args.models)['state_dict']
33
+ msg = self.model.load_state_dict(sd)
34
+ print(msg)
35
+ self.model.eval()
36
+ self.metrics = compute_scores
37
+ self.transform = transforms.Compose([
38
+ transforms.Resize((224, 224)),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize((0.485, 0.456, 0.406),
41
+ (0.229, 0.224, 0.225))])
42
+ with open(self.args.ann_path, 'r', encoding='utf-8-sig') as f:
43
+ self.data = json.load(f)
44
+ print('模型加载完成')
45
+
46
+ def image_process(self, img_paths):
47
+ image_1 = Image.open(os.path.join(self.args.image_dir, img_paths[0])).convert('RGB')
48
+ image_2 = Image.open(os.path.join(self.args.image_dir, img_paths[1])).convert('RGB')
49
+ if self.transform is not None:
50
+ image_1 = self.transform(image_1)
51
+ image_2 = self.transform(image_2)
52
+ image = torch.stack((image_1, image_2), 0)
53
+ return image
54
+
55
+ def generate(self, uid):
56
+ img_paths, report = self.data[uid]['img_paths'], self.data[uid]['report']
57
+ imgs = self.image_process(img_paths)
58
+ imgs = imgs.unsqueeze(0)
59
+ with torch.no_grad():
60
+ output, _ = self.model(imgs, mode='sample')
61
+ pred = self.tokenizer.decode(output[0].cpu().numpy())
62
+ gt = self.tokenizer.decode(self.tokenizer(report[:self.args.max_seq_length])[1:])
63
+ scores = self.metrics({0: [gt]}, {0: [pred]})
64
+ return pred, gt, scores
65
+
66
+ def visualize_images(self, uid):
67
+ image_1 = Image.open(os.path.join(self.args.image_dir, self.data[uid]['img_paths'][0])).convert('RGB')
68
+ image_2 = Image.open(os.path.join(self.args.image_dir, self.data[uid]['img_paths'][1])).convert('RGB')
69
+ return image_1, image_2
70
+
71
+ # 主应用程序
72
+ def demo():
73
+ with gr.Blocks() as app:
74
+ gr.Markdown("# 超声报告生成Demo")
75
+ gr.Markdown('### SIAT认知与交互技术中心')
76
+ gr.Markdown('### 项目主页:https://lijunrio.github.io/Ultrasound-Report-Generation/')
77
+
78
+ # 选择模型
79
+ with gr.Row():
80
+ model_choice = gr.Radio(choices=["甲状腺", "乳腺", "肝脏"], label="请选择模型类型", interactive=True)
81
+
82
+ model = gr.State()
83
+
84
+ # 展示UID按钮
85
+ uids = [f"uid_{i}" for i in range(20)]
86
+ with gr.Row():
87
+ uid_choice = gr.Radio(choices=[f"{uid}" for uid in uids], label="请选择uid", interactive=False)
88
+
89
+ # 定义展示图片的组件
90
+ with gr.Row():
91
+ image1_display = gr.Image(label="图像1", visible=True)
92
+ image2_display = gr.Image(label="图像2", visible=True)
93
+
94
+ # 定义生成报告的按钮和文本框
95
+ generate_button = gr.Button("生成报告", interactive=False)
96
+ generated_report_display = gr.Textbox(label="生成的报告", visible=True)
97
+ ground_truth_display = gr.Textbox(label="Ground Truth报告", visible=True)
98
+ nlp_score_display = gr.Textbox(label="NLP得分", visible=True)
99
+
100
+ # 加载模型的回调函数
101
+ def load_model_and_uids(model_type):
102
+ model = Generator(model_type)
103
+ return model, gr.update(interactive=True)
104
+
105
+ # 点击UID按钮后加载对应的图片
106
+ def on_uid_click(model, uid):
107
+ image1, image2 = model.visualize_images(uid)
108
+ # 显示图片和生成按钮
109
+ return image1, image2, gr.update(interactive=True)
110
+
111
+ # 点击生成按钮生成报告
112
+ def on_generate_click(model, uid):
113
+ generated_report, ground_truth_report, nlp_score = model.generate(uid)
114
+ # 展示生成的报告、Ground Truth 和 NLP 得分
115
+ return generated_report, ground_truth_report, f"NLP得分: {nlp_score}"
116
+
117
+ # 链接模型选择与UID按钮显示
118
+ model_choice.change(load_model_and_uids, inputs=model_choice, outputs=[model, uid_choice])
119
+
120
+ # 链接UID按钮点击与图片显示
121
+ uid_choice.change(on_uid_click, inputs=[model, uid_choice], outputs=[image1_display, image2_display, generate_button])
122
+
123
+ generate_button.click(on_generate_click, inputs=[model, uid_choice], outputs=[generated_report_display, ground_truth_display, nlp_score_display])
124
+
125
+ return app
126
+
127
+ if __name__ == '__main__':
128
+ # 启动应用程序
129
+ demo().launch()
130
+