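"""Gradio demo for microsoft/maira-2.

Two tabs: draft radiology report generation (frontal + lateral chest X-rays
plus indication/technique/comparison text) and phrase grounding on a frontal
view. The model is gated, so a Hugging Face access token is required;
inference runs on CPU.
"""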
import gradio as gr
import torch
import requests
import tempfile
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Cache of (model, processor) tuples, keyed by the token used to load them.
_model_cache = {}


def load_model_and_processor(hf_token: str):
    """
    Loads the MAIRA-2 model and processor from Hugging Face using the provided
    token. The loaded objects are cached, keyed by the token, so repeated calls
    do not re-download the weights.
    """
    if hf_token in _model_cache:
        return _model_cache[hf_token]
    # Pin inference to CPU so the demo runs on hardware without a GPU.
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,  # `use_auth_token` is deprecated in recent transformers releases
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        token=hf_token,
    )
    model.eval()
    model.to(device)
    _model_cache[hf_token] = (model, processor)
    return model, processor


def get_sample_data() -> dict:
    """
    Downloads a sample chest X-ray study (frontal and lateral views from the
    NLM Open-i collection) along with sample report text and a sample phrase.
    """
    frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

    def download_and_open(url: str) -> Image.Image:
        response = requests.get(url, headers={"User-Agent": "MAIRA-2"}, stream=True)
        response.raise_for_status()  # Fail early on a bad download
        return Image.open(response.raw).convert("RGB")

    frontal = download_and_open(frontal_image_url)
    lateral = download_and_open(lateral_image_url)
    return {
        "frontal": frontal,
        "lateral": lateral,
        "indication": "Dyspnea.",
        "technique": "PA and lateral views of the chest.",
        "comparison": "None.",
        "phrase": "Pleural effusion.",
    }
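

# Both inference paths below follow the usage example on the microsoft/maira-2
# model card: preprocess with the processor's format_and_preprocess_* helpers,
# call model.generate(), strip the prompt tokens, then decode and post-process.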
def generate_report(hf_token, frontal, lateral, indication, technique, comparison, use_grounding):
    """
    Generates a draft radiology report using the MAIRA-2 model.
    Any missing image/text input is replaced with sample data.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")
    sample = get_sample_data()
    if frontal is None:
        frontal = sample["frontal"]
    if lateral is None:
        lateral = sample["lateral"]
    if not indication:
        indication = sample["indication"]
    if not technique:
        technique = sample["technique"]
    if not comparison:
        comparison = sample["comparison"]
    processed_inputs = processor.format_and_preprocess_reporting_input(
        current_frontal=frontal,
        current_lateral=lateral,
        prior_frontal=None,  # No prior study is used in this demo.
        indication=indication,
        technique=technique,
        comparison=comparison,
        prior_report=None,
        return_tensors="pt",
        get_grounding=use_grounding,
    )
    # Move all tensors to the CPU and drop any "image_sizes" entries, which
    # model.generate() does not accept as keyword arguments.
    processed_inputs = {
        k: v.to(device) for k, v in processed_inputs.items() if "image_sizes" not in k
    }
    # Grounded reports include extra box tokens, so allow a larger generation budget.
    max_tokens = 450 if use_grounding else 300
    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=max_tokens,
            use_cache=True,
        )
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    decoded_text = decoded_text.lstrip()  # Remove any leading whitespace
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction


def run_phrase_grounding(hf_token, frontal, phrase):
    """
    Grounds a phrase on the frontal image using the MAIRA-2 model.
    If the image or phrase is missing, sample data is used.
    """
    try:
        model, processor = load_model_and_processor(hf_token)
    except Exception as e:
        return f"Error loading model: {str(e)}"
    device = torch.device("cpu")
    sample = get_sample_data()
    if frontal is None:
        frontal = sample["frontal"]
    if not phrase:
        phrase = sample["phrase"]
    processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
        frontal_image=frontal,
        phrase=phrase,
        return_tensors="pt",
    )
    # Move all tensors to the CPU and drop any "image_sizes" entries, which
    # model.generate() does not accept as keyword arguments.
    processed_inputs = {
        k: v.to(device) for k, v in processed_inputs.items() if "image_sizes" not in k
    }
    with torch.no_grad():
        output_decoding = model.generate(
            **processed_inputs,
            max_new_tokens=150,
            use_cache=True,
        )
    prompt_length = processed_inputs["input_ids"].shape[-1]
    decoded_text = processor.decode(output_decoding[0][prompt_length:], skip_special_tokens=True)
    prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
    return prediction
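

# Note: when grounding is requested, convert_output_to_plaintext_or_grounded_sequence
# returns a grounded sequence (text paired with bounding boxes) rather than a plain
# string, as its name suggests; the gr.Textbox outputs below display such values
# via their str() representation.

# ---------------------------------------------------------------------------
# UI wrappers: adapt Gradio's filepath-based inputs to the PIL-based functions
# above.
# ---------------------------------------------------------------------------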
def login_ui(hf_token):
    """Authenticate the user by loading the model."""
    try:
        load_model_and_processor(hf_token)
        return "🔓 Login successful! You can now use the model."
    except Exception as e:
        return f"❌ Login failed: {str(e)}"


def generate_report_ui(hf_token, frontal_path, lateral_path, indication, technique, comparison,
                       prior_frontal_path, prior_lateral_path, prior_report, grounding):
    """
    Wrapper for generate_report that accepts file paths (from the UI) for images.
    Prior-study fields are accepted but ignored.
    """
    try:
        frontal = Image.open(frontal_path).convert("RGB") if frontal_path else None
        lateral = Image.open(lateral_path).convert("RGB") if lateral_path else None
    except Exception as e:
        return f"❌ Error loading images: {str(e)}"
    return generate_report(hf_token, frontal, lateral, indication, technique, comparison, grounding)


def run_phrase_grounding_ui(hf_token, frontal_path, phrase):
    """
    Wrapper for run_phrase_grounding that accepts a file path for the frontal image.
    """
    try:
        frontal = Image.open(frontal_path).convert("RGB") if frontal_path else None
    except Exception as e:
        return f"❌ Error loading image: {str(e)}"
    return run_phrase_grounding(hf_token, frontal, phrase)


def save_temp_image(img: Image.Image) -> str:
    """Save a PIL image to a temporary file and return the file path."""
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    temp_file.close()  # Close the handle; the file itself is kept for Gradio to read.
    return temp_file.name


def load_sample_findings():
    """
    Loads sample data for the report generation tab.
    Returns file paths for the current study images, sample text fields, and
    placeholder values for the prior study.
    """
    sample = get_sample_data()
    return [
        save_temp_image(sample["frontal"]),  # frontal image file path
        save_temp_image(sample["lateral"]),  # lateral image file path
        sample["indication"],
        sample["technique"],
        sample["comparison"],
        None,   # prior frontal (not used)
        None,   # prior lateral (not used)
        None,   # prior report (not used)
        False,  # grounding checkbox default
    ]


def load_sample_phrase():
    """
    Loads sample data for the phrase grounding tab.
    Returns a file path for the frontal image and a sample phrase.
    """
    sample = get_sample_data()
    return [save_temp_image(sample["frontal"]), sample["phrase"]]
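

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------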
with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
    gr.Markdown(
        """
        # MAIRA-2 Medical Assistant

        **Authentication required** - You need a Hugging Face account and access token to use this model.

        1. Get your access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
        2. Request model access at [https://huggingface.co/microsoft/maira-2](https://huggingface.co/microsoft/maira-2)
        3. Paste your token below to begin
        """
    )
    with gr.Row():
        hf_token = gr.Textbox(
            label="Hugging Face Token",
            placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
            type="password"
        )
        login_btn = gr.Button("Authenticate")
    login_status = gr.Textbox(label="Authentication Status", interactive=False)
    login_btn.click(
        login_ui,
        inputs=hf_token,
        outputs=login_status
    )
    with gr.Tabs():
        with gr.Tab("Report Generation"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Current Study")
                    frontal = gr.Image(label="Frontal View", type="filepath")
                    lateral = gr.Image(label="Lateral View", type="filepath")
                    indication = gr.Textbox(label="Clinical Indication")
                    technique = gr.Textbox(label="Imaging Technique")
                    comparison = gr.Textbox(label="Comparison")
                    gr.Markdown("## Prior Study (Optional, ignored by this demo)")
                    prior_frontal = gr.Image(label="Prior Frontal View", type="filepath")
                    prior_lateral = gr.Image(label="Prior Lateral View", type="filepath")
                    prior_report = gr.Textbox(label="Prior Report")
                    grounding = gr.Checkbox(label="Include Grounding")
                    sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    report_output = gr.Textbox(label="Generated Report", lines=10)
                    generate_btn = gr.Button("Generate Report")
            sample_btn.click(
                load_sample_findings,
                outputs=[frontal, lateral, indication, technique, comparison,
                         prior_frontal, prior_lateral, prior_report, grounding]
            )
            generate_btn.click(
                generate_report_ui,
                inputs=[hf_token, frontal, lateral, indication, technique, comparison,
                        prior_frontal, prior_lateral, prior_report, grounding],
                outputs=report_output
            )
        with gr.Tab("Phrase Grounding"):
            with gr.Row():
                with gr.Column():
                    pg_frontal = gr.Image(label="Frontal View", type="filepath")
                    phrase = gr.Textbox(label="Phrase to Ground")
                    pg_sample_btn = gr.Button("Load Sample Data")
                with gr.Column():
                    pg_output = gr.Textbox(label="Grounding Result", lines=3)
                    pg_btn = gr.Button("Find Phrase")
            pg_sample_btn.click(
                load_sample_phrase,
                outputs=[pg_frontal, phrase]
            )
            pg_btn.click(
                run_phrase_grounding_ui,
                inputs=[hf_token, pg_frontal, phrase],
                outputs=pg_output
            )
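
# Launch the Gradio interface at module level, as Hugging Face Spaces expects.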
demo.launch()