Spaces:
Sleeping
Sleeping
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from dataclasses import dataclass | |
import os | |
from typing import Tuple | |
from PIL import Image, ImageDraw, ImageFont | |
from huggingface_hub import hf_hub_download | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from inference_utils.inference import interactive_infer_image | |
from inference_utils.output_processing import check_mask_stats | |
from modeling import build_model | |
from modeling.BaseModel import BaseModel | |
from utilities.arguments import load_opt_from_config_files | |
from utilities.constants import BIOMED_CLASSES | |
from utilities.distributed import init_distributed | |
zero_tensor = torch.zeros(1, 1, 1) | |
class PredictionTarget: | |
target: str | |
pred_mask: torch.Tensor = zero_tensor | |
adjusted_p_value: float = -1.0 | |
class Model: | |
def init(self): | |
self._model = init_model() | |
def predict( | |
self, image: Image.Image, modality_type: str, targets: list[str] | |
) -> Tuple[Image.Image, str]: | |
image_annotated, prediction_targets_not_found = predict( | |
self._model, image, modality_type, targets | |
) | |
targets_not_found_str = ( | |
"\n".join( | |
f"{t.target} ({t.adjusted_p_value:.3f})" | |
for t in prediction_targets_not_found | |
) | |
if prediction_targets_not_found | |
else "All targets were found!" | |
) | |
return image_annotated, targets_not_found_str | |
def init_model() -> BaseModel: | |
# Download model | |
model_file = hf_hub_download( | |
repo_id="microsoft/BiomedParse", | |
filename="biomedparse_v1.pt", | |
token=os.getenv("HF_TOKEN"), | |
) | |
# Initialize model | |
conf_files = "configs/biomedparse_inference.yaml" | |
opt = load_opt_from_config_files([conf_files]) | |
opt = init_distributed(opt) | |
model = BaseModel(opt, build_model(opt)).from_pretrained(model_file).eval().cuda() | |
with torch.no_grad(): | |
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings( | |
BIOMED_CLASSES + ["background"], is_eval=True | |
) | |
return model | |
def predict( | |
model: BaseModel, image: Image.Image, modality_type: str, targets: list[str] | |
) -> Tuple[Image.Image, list[PredictionTarget]]: | |
assert len(targets) > 0, "At least one target is required" | |
prediction_tasks = [PredictionTarget(target=target) for target in targets] | |
# Convert to RGB if needed | |
if image.mode != "RGB": | |
image = image.convert("RGB") | |
# Get predictions | |
pred_mask = interactive_infer_image(model, image, targets) | |
for i, pt in enumerate(prediction_tasks): | |
pt.pred_mask = pred_mask[i] | |
image_np = np.array(image) | |
for pt in prediction_tasks: | |
adj_p_value = check_mask_stats( | |
image_np, pt.pred_mask * 255, modality_type, pt.target | |
) | |
pt.adjusted_p_value = float(adj_p_value) | |
pred_targets_found, pred_tasks_not_found = segregate_prediction_tasks( | |
prediction_tasks, 0.05 | |
) | |
# Generate visualization | |
colors = generate_colors(len(pred_targets_found)) | |
masks = [1 * (pred_mask[i] > 0.5) for i in range(len(pred_targets_found))] | |
pred_overlay = overlay_masks(image, masks, colors) | |
pred_overlay = add_legend(pred_overlay, pred_targets_found, colors) | |
return pred_overlay, pred_tasks_not_found | |
def segregate_prediction_tasks( | |
prediction_tasks: list[PredictionTarget], p_value_threshold: float | |
) -> tuple[list[PredictionTarget], list[PredictionTarget]]: | |
"""Segregates Prediction Tasks by p-value | |
Prediction tasks with a p-value higher than p_value_threshold go into the targets_found list. | |
Otherwise, they go into the targets_not_found list. | |
""" | |
targets_found = [] | |
targets_not_found = [] | |
for pt in prediction_tasks: | |
if pt.adjusted_p_value > p_value_threshold: | |
targets_found.append(pt) | |
else: | |
targets_not_found.append(pt) | |
return targets_found, targets_not_found | |
def generate_colors(n: int) -> list[Tuple[int, int, int]]: | |
cmap = plt.get_cmap("tab10") | |
colors = [ | |
(int(255 * cmap(i)[0]), int(255 * cmap(i)[1]), int(255 * cmap(i)[2])) | |
for i in range(n) | |
] | |
return colors | |
def overlay_masks( | |
image: Image.Image, | |
masks: list[np.ndarray], | |
colors: list[Tuple[int, int, int]], | |
) -> Image.Image: | |
overlay = image.copy() | |
overlay = np.array(overlay, dtype=np.uint8) | |
for mask, color in zip(masks, colors): | |
overlay[mask > 0] = (overlay[mask > 0] * 0.4 + np.array(color) * 0.6).astype( | |
np.uint8 | |
) | |
return Image.fromarray(overlay) | |
def add_legend( | |
image: Image.Image, | |
pred_targets_found: list[PredictionTarget], | |
colors: list[Tuple[int, int, int]], | |
) -> Image.Image: | |
if len(pred_targets_found) == 0: | |
return image | |
# Convert to numpy for manipulation | |
pred_overlay = np.array(image) | |
# Calculate dimensions based on image resolution | |
image_width = pred_overlay.shape[1] | |
font_size = max(16, int(image_width * 0.02)) # Scale with image width, minimum 16px | |
box_size = int(font_size * 1.5) # Color box proportional to font | |
entry_height = int(box_size * 1.5) # Space between entries | |
legend_padding = int(font_size * 0.75) # Padding scales with font | |
# Calculate total legend height | |
legend_height = entry_height * len(pred_targets_found) | |
total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding | |
# Create new image with space for legend | |
new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8) | |
new_image[: pred_overlay.shape[0], :] = pred_overlay | |
new_image[pred_overlay.shape[0] :] = 255 # White background for legend | |
# Convert to PIL once for all legend entries | |
img_pil = Image.fromarray(new_image) | |
draw = ImageDraw.Draw(img_pil) | |
# Try to load a system font with proper scaling | |
font = None | |
system_fonts = [ | |
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", # Linux | |
"/System/Library/Fonts/Helvetica.ttc", # macOS | |
"C:\\Windows\\Fonts\\arial.ttf", # Windows | |
] | |
for font_path in system_fonts: | |
try: | |
font = ImageFont.truetype(font_path, font_size) | |
break | |
except (OSError, IOError): | |
continue | |
if font is None: | |
# Fallback to default font if no system fonts are available | |
font = ImageFont.load_default() | |
# Get font metrics for proper vertical centering | |
bbox = font.getbbox("Aj") # Use tall characters to get true height | |
font_height = bbox[3] - bbox[1] # bottom - top | |
# Draw legend entries | |
start_y = pred_overlay.shape[0] + legend_padding | |
for i, task in enumerate(pred_targets_found): | |
# Draw color box | |
box_x = legend_padding | |
box_y = start_y + i * entry_height | |
box_coords = (box_x, box_y, box_x + box_size, box_y + box_size) | |
draw.rectangle(box_coords, fill=colors[i]) | |
# Draw text (vertically centered with color box) | |
text_y = box_y + (box_size - font_height) // 2 # Center text with box | |
# Format text with truncated p-value | |
p_value_truncated = "{:.2f}".format(task.adjusted_p_value) | |
legend_text = f"{task.target} ({p_value_truncated})" | |
draw.text( | |
(box_x + box_size + legend_padding, text_y), | |
legend_text, | |
fill=(0, 0, 0), | |
font=font, | |
) | |
return img_pil | |