# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import gradio as gr
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from inference_utils.inference import interactive_infer_image


def overlay_masks(image, masks, colors):
    # Blend each mask region into the image (60% mask color, 40% original pixels)
    overlay = image.copy()
    overlay = np.array(overlay, dtype=np.uint8)
    for mask, color in zip(masks, colors):
        overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype(
            np.uint8
        )
    return Image.fromarray(overlay)


def generate_colors(n):
    # Sample n RGB colors (0-255 tuples) from matplotlib's "tab10" palette
    cmap = plt.get_cmap("tab10")
    colors = [tuple(int(255 * val) for val in cmap(i)[:3]) for i in range(n)]
    return colors


def init_model():
    # Download model
    model_file = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Initialize model
    conf_files = "configs/biomedparse_inference.yaml"
    opt = load_opt_from_config_files([conf_files])
    opt = init_distributed(opt)

    model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda()

    # Pre-compute text embeddings for the BiomedParse class vocabulary
    with torch.no_grad():
        model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
            BIOMED_CLASSES + ["background"], is_eval=True
        )
    return model


def predict(image, prompts):
    if not prompts:
        return None

    # Convert string input to list
    prompts = [p.strip() for p in prompts.split(",")]

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get predictions
    pred_mask = interactive_infer_image(model, image, prompts)

    # Generate visualization
    colors = generate_colors(len(prompts))
    pred_overlay = overlay_masks(
        image, [1 * (pred_mask[i] > 0.5) for i in range(len(prompts))], colors
    )

    return pred_overlay


def run():
    # Load the model once and keep it in a module-level global so predict() can reuse it
    global model
    model = init_model()

    demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(
                label="Prompts",
                placeholder="Enter prompts separated by commas (e.g., neoplastic cells, inflammatory cells)",
            ),
        ],
        outputs=gr.Image(type="pil", label="Prediction"),
        title="BiomedParse Demo",
        description="Upload a biomedical image and enter prompts (separated by commas) to detect specific features.",
        examples=[
            ["examples/144DME_as_F.jpeg", "edema"],
            ["examples/C3_EndoCV2021_00462.jpg", "polyp"],
            ["examples/covid_1585.png", "left lung"],
            ["examples/covid_1585.png", "right lung"],
            ["examples/covid_1585.png", "COVID-19 infection"],
            ["examples/ISIC_0015551.jpg", "lesion"],
            ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "lung nodule"],
            ["examples/LIDC-IDRI-0140_143_280_CT_lung.png", "COVID-19 infection"],
            [
                "examples/Part_1_516_pathology_breast.png",
                "connective tissue cells",
            ],
            [
                "examples/Part_1_516_pathology_breast.png",
                "neoplastic cells",
            ],
            [
                "examples/Part_1_516_pathology_breast.png",
                "neoplastic cells, inflammatory cells",
            ],
            ["examples/T0011.jpg", "optic disc"],
            ["examples/T0011.jpg", "optic cup"],
            ["examples/TCGA_HT_7856_19950831_8_MRI-FLAIR_brain.png", "glioma"],
        ],
    )

    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    run()