File size: 7,957 Bytes
5b4c740
 
 
 
 
 
 
 
 
 
 
 
0a9ad49
99b73a0
0a9ad49
 
fc90b14
287d863
 
 
99b73a0
287d863
0a9ad49
 
287d863
 
 
 
 
 
 
0a9ad49
 
 
 
 
 
 
 
 
 
99b73a0
 
 
287d863
0a9ad49
 
 
 
 
 
 
fc90b14
380e0a4
fc90b14
 
0a9ad49
 
 
 
287d863
99b73a0
0a9ad49
99b73a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287d863
99b73a0
 
287d863
 
0a9ad49
 
 
 
 
 
287d863
 
 
 
 
 
0a9ad49
287d863
0a9ad49
 
 
 
 
 
 
 
 
 
 
fc90b14
0a9ad49
287d863
 
0a9ad49
fc90b14
 
0a9ad49
 
fc90b14
354d315
0a9ad49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287d863
 
0a9ad49
99b73a0
0a9ad49
 
 
 
99b73a0
287d863
 
0a9ad49
 
 
 
 
99b73a0
 
 
 
 
287d863
99b73a0
fc90b14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
import os
from typing import Tuple

from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch

from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
from modeling import build_model
from modeling.BaseModel import BaseModel
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
from utilities.distributed import init_distributed


# 1x1x1 all-zero tensor: sentinel meaning "no mask predicted yet".
zero_tensor = torch.zeros(1, 1, 1)


@dataclass
class PredictionTarget:
    """One segmentation target plus its prediction state.

    Attributes:
        target: Name of the structure to segment (e.g. a prompt string).
        pred_mask: Predicted mask; defaults to a fresh 1x1x1 zero tensor
            as a "no prediction yet" placeholder. A default_factory is used
            so instances do not share one mutable tensor object.
        adjusted_p_value: Confidence p-value; -1.0 means "not computed yet".
    """

    target: str
    pred_mask: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, 1))
    adjusted_p_value: float = -1.0

class Model:
    """Serving wrapper around the BiomedParse segmentation model."""

    def init(self):
        # Framework-facing entry point: build and cache the model once.
        self._model = init_model()

    def predict(
        self, image: Image.Image, modality_type: str, targets: list[str]
    ) -> Tuple[Image.Image, str]:
        """Segment *targets* in *image* and summarize any missed ones.

        Returns the annotated image and a human-readable report: one line
        per target that was not found (with its adjusted p-value), or a
        success message when every target was detected.
        """
        annotated, not_found = predict(self._model, image, modality_type, targets)
        if not_found:
            report_lines = [
                f"{t.target} ({t.adjusted_p_value:.3f})" for t in not_found
            ]
            summary = "\n".join(report_lines)
        else:
            summary = "All targets were found!"
        return annotated, summary


def init_model() -> BaseModel:
    """Download the BiomedParse checkpoint, build the model, and warm it up.

    Fetches weights from the Hugging Face Hub (HF_TOKEN env var for auth),
    loads the inference config, moves the model to CUDA in eval mode, and
    pre-computes text embeddings for all biomedical classes + background.
    """
    # Pull the pretrained checkpoint from the Hub.
    checkpoint_path = hf_hub_download(
        repo_id="microsoft/BiomedParse",
        filename="biomedparse_v1.pt",
        token=os.getenv("HF_TOKEN"),
    )

    # Load inference options and initialize distributed settings.
    opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
    opt = init_distributed(opt)

    model = BaseModel(opt, build_model(opt)).from_pretrained(checkpoint_path).eval().cuda()

    # Warm up the language encoder so class-name embeddings are cached.
    with torch.no_grad():
        lang_encoder = model.model.sem_seg_head.predictor.lang_encoder
        lang_encoder.get_text_embeddings(BIOMED_CLASSES + ["background"], is_eval=True)

    return model


def predict(
    model: BaseModel, image: Image.Image, modality_type: str, targets: list[str]
) -> Tuple[Image.Image, list[PredictionTarget]]:
    """Segment *targets* in *image* and annotate the ones that were found.

    Args:
        model: Initialized BiomedParse model (see init_model).
        modality_type: Imaging modality passed to check_mask_stats.
        targets: Non-empty list of structure names to segment.

    Returns:
        (annotated image with overlays + legend for found targets,
         list of PredictionTarget for targets deemed not found).
    """
    assert len(targets) > 0, "At least one target is required"

    prediction_tasks = [PredictionTarget(target=target) for target in targets]

    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get predictions; one mask per target, in the order of `targets`.
    pred_mask = interactive_infer_image(model, image, targets)
    for i, pt in enumerate(prediction_tasks):
        pt.pred_mask = pred_mask[i]

    image_np = np.array(image)

    # Score each mask; check_mask_stats expects masks scaled to 0-255.
    for pt in prediction_tasks:
        adj_p_value = check_mask_stats(
            image_np, pt.pred_mask * 255, modality_type, pt.target
        )
        pt.adjusted_p_value = float(adj_p_value)

    pred_targets_found, pred_tasks_not_found = segregate_prediction_tasks(
        prediction_tasks, 0.05
    )

    # Generate visualization.
    # BUG FIX: use each found task's own mask. The old code took
    # pred_mask[i] for i in range(len(found)), i.e. the first k masks in
    # the ORIGINAL target order, pairing wrong masks with targets/colors
    # whenever an early target was not found.
    colors = generate_colors(len(pred_targets_found))
    masks = [1 * (pt.pred_mask > 0.5) for pt in pred_targets_found]
    pred_overlay = overlay_masks(image, masks, colors)

    pred_overlay = add_legend(pred_overlay, pred_targets_found, colors)

    return pred_overlay, pred_tasks_not_found


def segregate_prediction_tasks(
    prediction_tasks: list[PredictionTarget], p_value_threshold: float
) -> tuple[list[PredictionTarget], list[PredictionTarget]]:
    """Split prediction tasks into (found, not found) by adjusted p-value.

    A task counts as found when its adjusted_p_value is strictly greater
    than p_value_threshold; every other task goes into the second list.
    Relative order within each list is preserved.
    """
    found = [
        task
        for task in prediction_tasks
        if task.adjusted_p_value > p_value_threshold
    ]
    missing = [
        task
        for task in prediction_tasks
        if not task.adjusted_p_value > p_value_threshold
    ]
    return found, missing


def generate_colors(n: int) -> list[Tuple[int, int, int]]:
    """Return *n* RGB tuples (0-255) drawn from matplotlib's tab10 palette.

    tab10 only has 10 entries; integer indices past the end clamp to the
    last color, which would make every target beyond the tenth identical.
    Wrapping with modulo cycles through the palette instead, keeping
    adjacent targets distinguishable.
    """
    cmap = plt.get_cmap("tab10")
    colors = []
    for i in range(n):
        r, g, b, _ = cmap(i % cmap.N)  # wrap instead of clamping past entry 10
        colors.append((int(255 * r), int(255 * g), int(255 * b)))
    return colors


def overlay_masks(
    image: Image.Image,
    masks: list[np.ndarray],
    colors: list[Tuple[int, int, int]],
) -> Image.Image:
    """Blend each binary mask onto *image* with its paired RGB color.

    Pixels inside a mask become 40% original + 60% mask color; pixels
    outside every mask are left untouched. Masks are applied in order,
    so later masks overwrite earlier ones where they overlap.
    """
    canvas = np.array(image.copy(), dtype=np.uint8)
    for mask, rgb in zip(masks, colors):
        inside = mask > 0
        blended = canvas[inside] * 0.4 + np.array(rgb) * 0.6
        canvas[inside] = blended.astype(np.uint8)
    return Image.fromarray(canvas)


def add_legend(
    image: Image.Image,
    pred_targets_found: list[PredictionTarget],
    colors: list[Tuple[int, int, int]],
) -> Image.Image:
    """Append a legend strip below *image*, one row per found target.

    Each row shows the target's overlay color box and its name with the
    adjusted p-value truncated to two decimals. `colors[i]` must correspond
    to `pred_targets_found[i]` (same pairing used by overlay_masks).
    Returns the input unchanged when there is nothing to list.
    """
    if len(pred_targets_found) == 0:
        return image

    # Convert to numpy for manipulation
    pred_overlay = np.array(image)

    # Calculate dimensions based on image resolution
    image_width = pred_overlay.shape[1]
    font_size = max(16, int(image_width * 0.02))  # Scale with image width, minimum 16px
    box_size = int(font_size * 1.5)  # Color box proportional to font
    entry_height = int(box_size * 1.5)  # Space between entries
    legend_padding = int(font_size * 0.75)  # Padding scales with font

    # Calculate total legend height
    legend_height = entry_height * len(pred_targets_found)
    total_height = pred_overlay.shape[0] + legend_height + 2 * legend_padding

    # Create new image with space for legend
    new_image = np.zeros((total_height, pred_overlay.shape[1], 3), dtype=np.uint8)
    new_image[: pred_overlay.shape[0], :] = pred_overlay
    new_image[pred_overlay.shape[0] :] = 255  # White background for legend

    # Convert to PIL once for all legend entries
    img_pil = Image.fromarray(new_image)
    draw = ImageDraw.Draw(img_pil)

    # Try to load a system font with proper scaling; paths are checked in
    # order, so Linux wins on a machine that somehow has several of these.
    font = None
    system_fonts = [
        "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
        "/System/Library/Fonts/Helvetica.ttc",  # macOS
        "C:\\Windows\\Fonts\\arial.ttf",  # Windows
    ]
    for font_path in system_fonts:
        try:
            font = ImageFont.truetype(font_path, font_size)
            break
        except (OSError, IOError):
            continue

    if font is None:
        # Fallback to default font if no system fonts are available
        # (note: the PIL default bitmap font ignores font_size scaling).
        font = ImageFont.load_default()

    # Get font metrics for proper vertical centering
    bbox = font.getbbox("Aj")  # Use tall characters to get true height
    font_height = bbox[3] - bbox[1]  # bottom - top

    # Draw legend entries
    start_y = pred_overlay.shape[0] + legend_padding
    for i, task in enumerate(pred_targets_found):
        # Draw color box
        box_x = legend_padding
        box_y = start_y + i * entry_height
        box_coords = (box_x, box_y, box_x + box_size, box_y + box_size)
        draw.rectangle(box_coords, fill=colors[i])

        # Draw text (vertically centered with color box)
        text_y = box_y + (box_size - font_height) // 2  # Center text with box
        # Format text with truncated p-value
        p_value_truncated = "{:.2f}".format(task.adjusted_p_value)
        legend_text = f"{task.target} ({p_value_truncated})"
        draw.text(
            (box_x + box_size + legend_padding, text_y),
            legend_text,
            fill=(0, 0, 0),
            font=font,
        )

    return img_pil