Spaces:
Running
Running
# app.py | |
import torch | |
import gradio as gr | |
import os | |
import requests | |
import base64 | |
from libra.eval import libra_eval | |
def generate_radiology_description(
    prompt: str,
    uploaded_current: str,
    uploaded_prior: str,
    temperature: float,
    top_p: float,
    num_beams: int,
    max_new_tokens: int
) -> str:
    """Generate a radiology description for a current/prior image pair with Libra.

    This is the Gradio click handler. Every failure is turned into a
    user-facing message string instead of an exception, so the UI always has
    something to display.

    Args:
        prompt: Free-text instruction forwarded to ``libra_eval`` as ``query``.
        uploaded_current: File path of the current image (the UI uses
            ``gr.Image(type="filepath")``). It is falsy when nothing was uploaded.
        uploaded_prior: File path of the prior image, with the same conventions.
        temperature: Sampling temperature forwarded to ``libra_eval``.
        top_p: Nucleus-sampling threshold forwarded to ``libra_eval``.
        num_beams: Beam count. It is coerced to ``int`` because slider values
            may arrive as floats.
        max_new_tokens: Generation length cap, coerced to ``int`` likewise.

    Returns:
        The ``libra_eval`` output on success. Otherwise, a human-readable
        message describing the missing input or the error that occurred.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Validate user input before doing any model work.
    if not uploaded_current or not uploaded_prior:
        return "Please upload both current and prior images."
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    model_path = "X-iZhang/libra-v1.0-7b"
    conv_mode = "libra_v1"

    try:
        # Beam counts and token limits are integral quantities. Coerce them so
        # that a float such as 2.0 coming from a slider never reaches generation.
        num_beams = int(num_beams)
        max_new_tokens = int(max_new_tokens)
        logger.debug(
            "Calling libra_eval (num_beams=%d, max_new_tokens=%d)",
            num_beams,
            max_new_tokens,
        )
        output = libra_eval(
            model_path=model_path,
            model_base=None,
            image_file=[uploaded_current, uploaded_prior],
            query=prompt,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            length_penalty=1.0,
            num_return_sequences=1,
            conv_mode=conv_mode,
            max_new_tokens=max_new_tokens
        )
    except Exception as e:
        # This is the top-level UI boundary. Keep the full traceback in the
        # logs, and show only the message to the user.
        logger.exception("libra_eval failed")
        return f"An error occurred: {e}"

    # Do not log the generated text itself; it may contain patient findings.
    logger.debug("libra_eval finished")
    return output
# Build the Gradio UI. The layout, top to bottom, is:
# - title and instructions
# - prompt box
# - a row with the current/prior image uploads
# - a row of generation sliders
# - the output box and the trigger button
with gr.Blocks() as demo:
    gr.Markdown("# Libra Radiology Report Generator (Local Upload Only)")
    gr.Markdown(
        "Upload **Current** and **Prior** images below to generate a radiology description using the Libra model."
    )

    prompt_input = gr.Textbox(
        value="Describe the key findings in these two images.",
        label="Prompt",
    )

    # Uploads reach the handler as file paths on disk (type="filepath").
    with gr.Row():
        uploaded_current = gr.Image(type="filepath", label="Upload Current Image")
        uploaded_prior = gr.Image(type="filepath", label="Upload Prior Image")

    # Generation controls. Each slider's value is passed straight to the handler.
    with gr.Row():
        temperature_slider = gr.Slider(
            value=0.7, minimum=0.1, maximum=1.0, step=0.1, label="Temperature"
        )
        top_p_slider = gr.Slider(
            value=0.8, minimum=0.1, maximum=1.0, step=0.1, label="Top P"
        )
        num_beams_slider = gr.Slider(
            value=2, minimum=1, maximum=20, step=1, label="Number of Beams"
        )
        max_tokens_slider = gr.Slider(
            value=128, minimum=10, maximum=4096, step=10, label="Max New Tokens"
        )

    output_text = gr.Textbox(lines=10, label="Generated Description")
    generate_button = gr.Button("Generate Description")

    # The order here must match the handler's positional parameters.
    handler_inputs = [
        prompt_input,
        uploaded_current,
        uploaded_prior,
        temperature_slider,
        top_p_slider,
        num_beams_slider,
        max_tokens_slider,
    ]
    generate_button.click(
        generate_radiology_description,
        inputs=handler_inputs,
        outputs=output_text,
    )
def _main() -> None:
    """Launch the Gradio ``demo`` app with its default settings."""
    demo.launch()


if __name__ == "__main__":
    _main()