X-iZhang commited on
Commit
edc96fb
·
verified ·
1 Parent(s): e1b2b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -82
app.py CHANGED
@@ -3,124 +3,258 @@ import torch
3
  import gradio as gr
4
  import os
5
  import requests
6
- import base64
7
-
8
  from libra.eval import libra_eval
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def generate_radiology_description(
 
 
 
 
11
  prompt: str,
12
- uploaded_current: str,
13
- uploaded_prior: str,
14
  temperature: float,
15
  top_p: float,
16
  num_beams: int,
17
- max_new_tokens: int
 
18
  ) -> str:
 
 
 
 
 
 
 
19
 
20
-
21
- if not uploaded_current or not uploaded_prior:
22
- return "Please upload both current and prior images."
23
 
24
-
25
- model_path = "X-iZhang/libra-v1.0-7b"
26
- conv_mode = "libra_v1"
 
 
 
 
27
 
28
- try:
 
 
 
 
 
 
29
 
30
- print("Before calling libra_eval")
31
  output = libra_eval(
32
- model_path=model_path,
33
- model_base=None,
34
- image_file=[uploaded_current, uploaded_prior],
35
  query=prompt,
36
  temperature=temperature,
37
  top_p=top_p,
38
  num_beams=num_beams,
39
  length_penalty=1.0,
40
  num_return_sequences=1,
41
- conv_mode=conv_mode,
42
  max_new_tokens=max_new_tokens
43
  )
44
- print("After calling libra_eval, result:", output)
45
  return output
46
  except Exception as e:
47
- return f"An error occurred: {str(e)}"
48
 
 
 
 
49
 
50
- with gr.Blocks() as demo:
51
-
52
- gr.Markdown("# Libra Radiology Report Generator (Local Upload Only)")
53
- gr.Markdown("Upload **Current** and **Prior** images below to generate a radiology description using the Libra model.")
54
 
55
-
56
- prompt_input = gr.Textbox(
57
- label="Prompt",
58
- value="Describe the key findings in these two images."
59
- )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- with gr.Row():
63
- uploaded_current = gr.Image(
64
- label="Upload Current Image",
65
- type="filepath"
66
- )
67
- uploaded_prior = gr.Image(
68
- label="Upload Prior Image",
69
- type="filepath"
70
- )
71
 
 
 
 
72
 
73
- with gr.Row():
74
- temperature_slider = gr.Slider(
75
- label="Temperature",
76
- minimum=0.1,
77
- maximum=1.0,
78
- step=0.1,
79
- value=0.7
80
- )
81
- top_p_slider = gr.Slider(
82
- label="Top P",
83
- minimum=0.1,
84
- maximum=1.0,
85
- step=0.1,
86
- value=0.8
 
87
  )
88
- num_beams_slider = gr.Slider(
89
- label="Number of Beams",
90
- minimum=1,
91
- maximum=20,
92
- step=1,
93
- value=2
 
 
 
 
 
94
  )
95
- max_tokens_slider = gr.Slider(
96
- label="Max New Tokens",
97
- minimum=10,
98
- maximum=4096,
99
- step=10,
100
- value=128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  )
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- output_text = gr.Textbox(
105
- label="Generated Description",
106
- lines=10
107
- )
108
 
 
 
 
 
109
 
110
- generate_button = gr.Button("Generate Description")
111
- generate_button.click(
112
- fn=generate_radiology_description,
113
- inputs=[
114
- prompt_input,
115
- uploaded_current,
116
- uploaded_prior,
117
- temperature_slider,
118
- top_p_slider,
119
- num_beams_slider,
120
- max_tokens_slider
121
- ],
122
- outputs=output_text
123
- )
124
 
125
  if __name__ == "__main__":
126
- demo.launch()
 
 
 
 
 
3
  import gradio as gr
4
  import os
5
  import requests
6
+ import argparse
 
7
  from libra.eval import libra_eval
8
+ from libra.eval.run_libra import load_model
9
+
10
+ DEFAULT_MODEL_PATH = "X-iZhang/libra-v1.0-7b"
11
+
12
+ def get_model_short_name(model_path: str) -> str:
13
+ """
14
+ 提取模型路径最后一个 '/' 之后的部分,作为在下拉菜单中显示的名字。
15
+ 例如: "X-iZhang/libra-v1.0-7b" -> "libra-v1.0-7b"
16
+ """
17
+ return model_path.rstrip("/").split("/")[-1]
18
+
19
+ # 全局/或在main里定义都行,这里示例放在外层
20
+ loaded_models = {} # {model_key: reuse_model_object}
21
 
22
  def generate_radiology_description(
23
+ selected_model_name: str,
24
+ current_img_data,
25
+ prior_img_data,
26
+ use_no_prior: bool,
27
  prompt: str,
 
 
28
  temperature: float,
29
  top_p: float,
30
  num_beams: int,
31
+ max_new_tokens: int,
32
+ model_paths_dict: dict
33
  ) -> str:
34
+ """
35
+ 执行放射学报告推理:
36
+ 1) 根据下拉选的模型名称 -> 找到实际 model_path
37
+ 2) 确保用户选了 Current & Prior 图片
38
+ 3) 调用 libra_eval
39
+ """
40
+ real_model_path = model_paths_dict[selected_model_name]
41
 
42
+ # 若用户没选/没上传 Current Image,一定报错
43
+ if not current_img_data:
44
+ return "Error: Please select or upload the Current Image."
45
 
46
+ # 如果用户勾选了 without prior image,就把 prior_img_data 设为 current_img_data
47
+ if use_no_prior:
48
+ prior_img_data = current_img_data
49
+ else:
50
+ # 未勾选时,需要prior_img_data
51
+ if not prior_img_data:
52
+ return "Error: Please select or upload the Prior Image, or check 'Without Prior Image'."
53
 
54
+ # 若已经加载过该模型,则直接复用
55
+ if selected_model_name in loaded_models:
56
+ reuse_model = loaded_models[selected_model_name]
57
+ else:
58
+ reuse_model = load_model(real_model_path)
59
+ # 缓存起来
60
+ loaded_models[selected_model_name] = reuse_model
61
 
62
+ try:
63
  output = libra_eval(
64
+ libra_model=reuse_model,
65
+ image_file=[current_img_data, prior_img_data],
 
66
  query=prompt,
67
  temperature=temperature,
68
  top_p=top_p,
69
  num_beams=num_beams,
70
  length_penalty=1.0,
71
  num_return_sequences=1,
72
+ conv_mode="libra_v1",
73
  max_new_tokens=max_new_tokens
74
  )
 
75
  return output
76
  except Exception as e:
77
+ return f"An error occurred during model inference: {str(e)}"
78
 
79
+ def main():
80
+ # ========== 获取当前脚本 (app.py) 所在目录 ==========
81
+ cur_dir = os.path.abspath(os.path.dirname(__file__))
82
 
83
+ # ========== 准备本地示例图片的绝对路径 ==========
84
+ # 向上回退两级: app.py -> serve/ -> libra/ -> Libra/ (同级)
85
+ example_curent_path = os.path.join(cur_dir, "..", "..", "assets", "example_curent.jpg")
86
+ example_curent_path = os.path.abspath(example_curent_path)
87
 
88
+ example_prior_path = os.path.join(cur_dir, "..", "..", "assets", "example_prior.jpg")
89
+ example_prior_path = os.path.abspath(example_prior_path)
 
 
 
90
 
91
+ # Gradio Examples 要求:对单个 gr.Image 而言,每个示例写成 ["本地文件路径"]
92
+ IMAGE_EXAMPLES = [
93
+ [example_curent_path],
94
+ [example_prior_path]
95
+ ]
96
+ # ========== 命令行解析 (可选) ==========
97
+ parser = argparse.ArgumentParser(description="Demo for Radiology Image Description Generator (Local Examples)")
98
+ parser.add_argument(
99
+ "--model-path",
100
+ type=str,
101
+ default=DEFAULT_MODEL_PATH,
102
+ help="User-specified model path. If not provided, only default model is shown."
103
+ )
104
+ args = parser.parse_args()
105
+ cmd_model_path = args.model_path
106
 
107
+ # ========== 设置多模型下拉菜单 ==========
108
+ model_paths_dict = {}
109
+ user_key = get_model_short_name(cmd_model_path)
110
+ model_paths_dict[user_key] = cmd_model_path
111
+
112
+ # 如果用户传入的模型 != 默认模型,则加上默认模型选项
113
+ if cmd_model_path != DEFAULT_MODEL_PATH:
114
+ default_key = get_model_short_name(DEFAULT_MODEL_PATH)
115
+ model_paths_dict[default_key] = DEFAULT_MODEL_PATH
116
 
117
+ # (可选)若想预先加载模型,避免重复加载,可在此处:
118
+ # reuse_model = load_model(cmd_model_path)
119
+ # 然后在 generate_radiology_description 里改造传 reuse_model
120
 
121
+ # ========== 搭建 Gradio 界面 ==========
122
+ with gr.Blocks(title="Libra: Radiology Analysis with Direct URL Examples") as demo:
123
+ gr.Markdown("""
124
+ ## 🩻 Libra: Leveraging Temporal Images for Biomedical Radiology Analysis
125
+ [Project Page](https://x-izhang.github.io/Libra_v1.0/) | [Paper](https://arxiv.org/abs/2411.19378) | [Code](https://github.com/X-iZhang/Libra) | [Model](https://huggingface.co/X-iZhang/libra-v1.0-7b)
126
+
127
+ **Requires a GPU to run effectively!**
128
+ """)
129
+
130
+ # 下拉模型选择
131
+ model_dropdown = gr.Dropdown(
132
+ label="Select Model",
133
+ choices=list(model_paths_dict.keys()),
134
+ value=user_key,
135
+ interactive=True
136
  )
137
+
138
+ # 临床Prompt
139
+ prompt_input = gr.Textbox(
140
+ label="Clinical Prompt",
141
+ value="Provide a detailed description of the findings in the radiology image.",
142
+ lines=2,
143
+ info=(
144
+ "If clinical instructions are available, include them after the default prompt. "
145
+ "For example: “Provide a detailed description of the findings in the radiology image. "
146
+ "Following clinical context: Indication: chest pain, History: ...”"
147
+ )
148
  )
149
+
150
+ # Current & Prior 画像
151
+ with gr.Row():
152
+ with gr.Column():
153
+ gr.Markdown("### Current Image")
154
+ current_img = gr.Image(
155
+ label="Drop Or Upload Current Image",
156
+ type="filepath",
157
+ interactive=True
158
+ )
159
+
160
+ gr.Examples(
161
+ examples=IMAGE_EXAMPLES,
162
+ inputs=current_img,
163
+ label="Example Current Images"
164
+ )
165
+
166
+ with gr.Column():
167
+ gr.Markdown("### Prior Image")
168
+ prior_img = gr.Image(
169
+ label="Drop Or Upload Prior Image",
170
+ type="filepath",
171
+ interactive=True
172
+ )
173
+ # 新增一个复选框,勾选后表示「Without Prior Image」
174
+ with gr.Row():
175
+ gr.Examples(
176
+ examples=IMAGE_EXAMPLES,
177
+ inputs=prior_img,
178
+ label="Example Prior Images"
179
+ )
180
+ without_prior_checkbox = gr.Checkbox(
181
+ label="Without Prior Image",
182
+ value=False,
183
+ info="If checked, the current image will be used as the dummy prior image in the Libra model."
184
+ )
185
+
186
+
187
+ with gr.Accordion("Parameters Settings", open=False):
188
+ temperature_slider = gr.Slider(
189
+ label="Temperature",
190
+ minimum=0.1, maximum=1.0, step=0.1, value=0.9
191
+ )
192
+ top_p_slider = gr.Slider(
193
+ label="Top P",
194
+ minimum=0.1, maximum=1.0, step=0.1, value=0.8
195
+ )
196
+ num_beams_slider = gr.Slider(
197
+ label="Number of Beams",
198
+ minimum=1, maximum=20, step=1, value=1
199
+ )
200
+ max_tokens_slider = gr.Slider(
201
+ label="Max output tokens",
202
+ minimum=10, maximum=4096, step=10, value=128
203
+ )
204
+
205
+ output_text = gr.Textbox(
206
+ label="Generated Findings Section",
207
+ lines=5
208
  )
209
 
210
+ generate_button = gr.Button("Generate Findings Description")
211
+ generate_button.click(
212
+ fn=lambda model_name, c_img, p_img, no_prior, prompt, temp, top_p, beams, tokens: generate_radiology_description(
213
+ model_name,
214
+ c_img,
215
+ p_img,
216
+ no_prior,
217
+ prompt,
218
+ temp,
219
+ top_p,
220
+ beams,
221
+ tokens,
222
+ model_paths_dict
223
+ ),
224
+ inputs=[
225
+ model_dropdown, # model_name
226
+ current_img, # c_img
227
+ prior_img, # p_img
228
+ without_prior_checkbox, # no_prior
229
+ prompt_input, # prompt
230
+ temperature_slider,# temp
231
+ top_p_slider, # top_p
232
+ num_beams_slider, # beams
233
+ max_tokens_slider # tokens
234
+ ],
235
+ outputs=output_text
236
+ )
237
+
238
+ # 界面底部插入条款说明
239
+ gr.Markdown("""
240
+ ### Terms of Use
241
 
242
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
243
+
244
+ By accessing or using this demo, you acknowledge and agree to the following:
 
245
 
246
+ - **Research & Non-Commercial Purposes**: This demo is provided solely for research and demonstration. It must not be used for commercial activities or profit-driven endeavors.
247
+ - **Not Medical Advice**: All generated content is experimental and must not replace professional medical judgment.
248
+ - **Content Moderationt**: While we apply basic safety checks, the system may still produce inaccurate or offensive outputs.
249
+ - **Responsible Use**: Do not use this demo for any illegal, harmful, hateful, violent, or sexual purposes.
250
 
251
+ By continuing to use this service, you confirm your acceptance of these terms. If you do not agree, please discontinue use immediately.
252
+ """)
253
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  if __name__ == "__main__":
256
+ main()
257
+
258
+
259
+ # if __name__ == "__main__":
260
+ # demo.launch()