# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the unit tests for the utils.py file."""
import numpy as np
from PIL import Image
import torch

# pylint: disable=g-bad-import-order
from modeling.model import utils
def test_scoremap2bbox():
    """Test the scoremap2bbox function.

    Builds a 10x10 score map whose one-pixel border is zero and whose
    interior (rows/columns 1..8) is filled by four adjacent 4x4 quadrants
    holding the values 1-4.  `utils.scoremap2bbox` is called with 0.5 as its
    second argument (presumably the score threshold).  Because the four
    quadrants touch, exactly one box is expected, with coordinates
    (1, 1, 9, 9).
    """
    scoremap = np.zeros((10, 10))
    scoremap[1:5, 1:5] = 1
    scoremap[5:9, 5:9] = 2
    scoremap[5:9, 1:5] = 3
    scoremap[1:5, 5:9] = 4

    bbox, len_bboxes = utils.scoremap2bbox(scoremap, 0.5)

    assert len_bboxes == 1
    assert bbox[0, 0] == 1
    assert bbox[0, 1] == 1
    assert bbox[0, 2] == 9
    assert bbox[0, 3] == 9
def test_mask2chw():
    """Test the mask2chw function.

    Builds a 10x10 array with four adjacent 4x4 quadrants holding the values
    1-4, converts it to a torch tensor and passes it to `utils.mask2chw`.
    The result must unpack into a two-element center plus a height and a
    width; the expected values are center (2, 2) and size 4x4.
    """
    mask_array = np.zeros((10, 10))
    mask_array[1:5, 1:5] = 1
    mask_array[5:9, 5:9] = 2
    mask_array[5:9, 1:5] = 3
    mask_array[1:5, 5:9] = 4
    mask = torch.tensor(mask_array)

    mask_center, mask_height, mask_width = utils.mask2chw(mask)

    assert len(mask_center) == 2
    assert mask_center[0] == 2
    assert mask_center[1] == 2
    assert mask_height == 4
    assert mask_width == 4
def test_unpad():
    """Test the unpad function.

    Two cases are exercised on a 10x10x1 image:

    * `pad=(1, 1, 8, 8)`: the result is expected to be 8 wide and 8 high.
    * `pad=None`: the result is expected to equal the input in all 100
      elements.
    """
    image = np.zeros((10, 10, 1))
    image[1:5, 1:5] = 1
    image[5:9, 5:9] = 2
    image[5:9, 1:5] = 3
    image[1:5, 5:9] = 4

    unpad_image = utils.unpad(image, pad=(1, 1, 8, 8))
    # Width is the length of a single row.
    assert len(unpad_image[0]) == 8, 'The width of the image is not 8.'
    # Height is the number of rows.  The previous check used
    # `len(unpad_image[1])`, which is the width of the second row and so
    # never actually verified the height.
    assert len(unpad_image) == 8, 'The height of the image is not 8.'

    unpad_image = utils.unpad(image, None)
    assert (unpad_image == image).sum() == 100
def test_apply_visual_prompts():
    """Test the apply_visual_prompts function.

    Applies a 'circle' visual prompt with thickness 1 to a 5x5 image of ones
    and checks that the result, converted to a NumPy array, matches `target`
    in all 25 positions.  `target` keeps the value 1 everywhere except on a
    diamond-shaped outline of 255s.
    """
    image = np.ones((5, 5))
    mask = np.array([
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 1.0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ])
    target = np.array([
        [1, 1, 255, 1, 1],
        [1, 255, 1, 255, 1],
        [255, 1, 1, 1, 255],
        [1, 255, 1, 255, 1],
        [1, 1, 255, 1, 1],
    ])
    # NOTE: this overrides the literal above; the effective mask is the 4x4
    # block covering rows/columns 1..4, not just the single center pixel.
    mask[1:5, 1:5] = 1

    prompted_image = utils.apply_visual_prompts(
        image, mask, visual_prompt_type='circle', thickness=1
    )
    prompted_array = np.array(prompted_image)

    assert (prompted_array == target).sum() == 25
def test_reshape_transform():
    """Test the reshape_transform function.

    Passes a zero tensor of shape (101, 10, 32) together with height=10 and
    width=10, and expects a 4-D result of shape (10, 32, 10, 10).
    """
    tokens = torch.zeros((101, 10, 32))

    reshaped = utils.reshape_transform(tokens, height=10, width=10)

    b, c, h, w = reshaped.shape
    assert b == 10
    assert c == 32
    assert h == 10
    assert w == 10
def test_img_ms_and_flip():
    """Test the img_ms_and_flip function.

    Converts a 120x150 array to a PIL image and calls
    `utils.img_ms_and_flip` with the sizes 120 and 150, a single scale of 1.2
    and a patch size of 16.  The last two dimensions of the first returned
    element are expected to equal the scaled sizes rounded up to the next
    multiple of the patch size.
    """
    pixels = np.zeros((120, 150))
    pixels[1:5, 1:5] = 1
    pixels[5:9, 5:9] = 2
    pixels[5:9, 1:5] = 3
    pixels[1:5, 5:9] = 4
    pil_image = Image.fromarray(pixels)

    outputs = utils.img_ms_and_flip(
        pil_image, 120, 150, scales=[1.2], patch_size=16
    )
    first_output = outputs[0]

    h, w = first_output.shape[-2:]
    assert h == int(np.ceil(1.2 * 120 / 16) * 16)
    assert w == int(np.ceil(1.2 * 150 / 16) * 16)
# Allow running the tests directly as a script, without a test runner.
if __name__ == '__main__':
    test_scoremap2bbox()
    test_mask2chw()
    test_unpad()
    test_apply_visual_prompts()
    test_reshape_transform()
    test_img_ms_and_flip()