File size: 34,471 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 |
# mypy: allow-untyped-defs
import json
import logging
import os
import struct
from typing import Any, List, Optional
import torch
import numpy as np
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import (
HistogramProto,
Summary,
SummaryMetadata,
)
from tensorboard.compat.proto.tensor_pb2 import TensorProto
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
from tensorboard.plugins.custom_scalar import layout_pb2
from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
from ._convert_np import make_np
from ._utils import _prepare_video, convert_to_HWC
# Public API of this module: builders for TensorBoard `Summary` protobufs and
# the small helpers they rely on. Order is kept stable for `import *` users.
__all__ = [
    # bit-level float <-> int helpers used for TensorProto.half_val
    "half_to_int",
    "int_to_half",
    # hyperparameters, scalars and histograms
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    # images, boxes and video
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    # audio, layouts, text and raw tensors
    "audio",
    "custom_scalars",
    "text",
    "tensor_proto",
    # precision/recall curves
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    # 3D meshes / point clouds
    "mesh",
]

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def half_to_int(f: float) -> int:
    """Reinterpret the bits of a float as a signed 32-bit integer.

    Used to fill the ``half_val`` field of a ``TensorProto`` for
    ``torch.half`` / ``torch.bfloat16`` tensors. ``int_to_half`` performs
    the inverse conversion.

    NOTE(review): the value is packed with struct format ``"f"`` (32-bit
    single precision), so the returned integer is the float32 bit pattern,
    not a 16-bit half-precision pattern. Round-tripping through
    ``int_to_half`` is exact; confirm this matches what other readers of
    ``half_val`` expect.

    Args:
        f: The floating point value to encode.

    Returns:
        The signed integer whose native-order bytes equal those of ``f``
        packed as a C ``float``.
    """
    (as_int,) = struct.unpack("i", struct.pack("f", f))
    return as_int
def int_to_half(i: int) -> float:
    """Reinterpret the bits of a signed 32-bit integer as a float.

    Inverse of ``half_to_int``: the integer is packed as a C ``int`` and
    the same bytes are read back as a C ``float`` (32-bit single precision).

    Args:
        i: An integer previously produced by ``half_to_int``.

    Returns:
        The float whose native-order bytes equal those of ``i``.
    """
    raw_bytes = struct.pack("i", i)
    (as_float,) = struct.unpack("f", raw_bytes)
    return as_float
def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
    """Flatten ``t`` and encode every element with ``half_to_int``."""
    return list(map(half_to_int, t.flatten().tolist()))
def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
    """Flatten a complex tensor into interleaved ``[re0, im0, re1, im1, ...]``."""
    real_view = torch.view_as_real(t)
    return real_view.reshape(-1).tolist()
def _tensor_to_list(t: torch.Tensor) -> List[Any]:
    """Return all elements of ``t`` as a flat Python list (row-major order)."""
    flat = t.reshape(-1)
    return flat.tolist()
# type maps: torch dtype -> (TensorProto dtype name, TensorProto value field,
# function that flattens a tensor into a Python list for that field).
# Several keys are aliases of the same dtype object (torch.half is
# torch.float16, torch.int is torch.int32, torch.cfloat is torch.complex64,
# ...), so aliased entries collapse into one dict entry and must agree.
#
# TensorProto stores DT_UINT8 values in `int_val` (`uint32_val` is reserved
# for DT_UINT32), so every DT_UINT8 entry below uses `int_val`. torch.uint8
# must appear only once: a second literal entry silently overrides the first.
_TENSOR_TYPE_MAP = {
    torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
    torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
    torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
    torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
    torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
    torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
    torch.short: ("DT_INT16", "int_val", _tensor_to_list),
    torch.int: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
    torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
    torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
    torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
    torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
    torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
    torch.quint8: ("DT_UINT8", "int_val", _tensor_to_list),
    torch.quint4x2: ("DT_UINT8", "int_val", _tensor_to_list),
}
def _calc_scale_factor(tensor):
    """Return the multiplier that maps ``tensor`` values onto [0, 255].

    uint8 data is assumed to already be in [0, 255] (factor 1); any other
    dtype is assumed to be in [0, 1] (factor 255). Non-numpy inputs are
    converted with ``.numpy()`` just to inspect their dtype.
    """
    array = tensor if isinstance(tensor, np.ndarray) else tensor.numpy()
    if array.dtype == np.uint8:
        return 1
    return 255
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one rectangle, plus an optional caption, onto a PIL image.

    Args:
        image: PIL image that is drawn on via ``ImageDraw.Draw``.
        xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
        display_str: Caption text; falsy values skip the caption.
        color: Outline colour, also used as the caption background.
        color_text: Caption text colour.
        thickness: Outline width in pixels.

    Returns:
        The same ``image`` object that was passed in.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    # Closed polyline around the box, starting and ending at (xmin, ymin).
    outline = [
        (xmin, ymin),
        (xmin, ymax),
        (xmax, ymax),
        (xmax, ymin),
        (xmin, ymin),
    ]
    draw.line(outline, width=thickness, fill=color)
    if not display_str:
        return image

    # The caption sits in a filled strip whose bottom edge is the box bottom.
    bbox_left, bbox_top, bbox_right, bbox_bottom = font.getbbox(display_str)
    text_width = bbox_right - bbox_left
    text_height = bbox_bottom - bbox_top
    margin = np.ceil(0.05 * text_height)
    draw.rectangle(
        [(xmin, ymax - text_height - 2 * margin), (xmin + text_width, ymax)],
        fill=color,
    )
    draw.text(
        (xmin + margin, ymax - text_height - margin),
        display_str,
        fill=color_text,
        font=font,
    )
    return image
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output three `Summary` protocol buffers needed by hparams plugin.

    `Experiment` keeps the metadata of an experiment, such as the name of the
    hyperparameters and the name of the metrics.
    `SessionStartInfo` keeps key-value pairs of the hyperparameters.
    `SessionEndInfo` describes status of the experiment (always
    STATUS_SUCCESS here).

    Args:
      hparam_dict: A dictionary that contains names of the hyperparameters
        and their values. Values may be int, float, str, bool or
        torch.Tensor (the tensor's first element is logged as a number);
        ``None`` values are skipped.
      metric_dict: A dictionary that contains names of the metrics
        and their values. Only the keys are used here, as metric tags.
      hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
        contains names of the hyperparameters and all discrete values they can
        hold. Each list must contain values of the same type as the matching
        entry of `hparam_dict`.

    Returns:
      A tuple ``(experiment, session_start_info, session_end_info)`` of
      `Summary` protobufs.

    Raises:
      TypeError: if an argument is not a dictionary, or a discrete domain is
        not a list of values of the hyperparameter's type.
      ValueError: if a hyperparameter value has an unsupported type.
    """
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other HParamInfo / MetricInfo / Experiment fields
    # (display names, descriptions, interval domains, dataset types).

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
            )

    hps = []
    ssi = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue

        if isinstance(v, torch.Tensor):
            # reshape(-1) makes 0-d (scalar) tensors indexable; a bare `[0]`
            # on the 0-d numpy array returned for them raises IndexError.
            ssi.hparams[k].number_value = float(make_np(v).reshape(-1)[0])
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue

        # bool must be tested before int: bool is a subclass of int, so the
        # numeric check would otherwise swallow booleans.
        if isinstance(v, bool):
            value_field, data_type = "bool_value", "DATA_TYPE_BOOL"
        elif isinstance(v, (int, float)):
            value_field, data_type = "number_value", "DATA_TYPE_FLOAT64"
        elif isinstance(v, str):
            value_field, data_type = "string_value", "DATA_TYPE_STRING"
        else:
            raise ValueError(
                "value should be one of int, float, str, bool, or torch.Tensor"
            )

        setattr(ssi.hparams[k], value_field, v)
        domain_discrete: Optional[struct_pb2.ListValue] = None
        if k in hparam_domain_discrete:
            domain_discrete = struct_pb2.ListValue(
                values=[
                    struct_pb2.Value(**{value_field: d})
                    for d in hparam_domain_discrete[k]
                ]
            )
        hps.append(
            HParamInfo(
                name=k,
                type=DataType.Value(data_type),
                domain_discrete=domain_discrete,
            )
        )

    def _plugin_summary(tag, **plugin_data_fields):
        # Wrap one HParamsPluginData payload in a Summary value under `tag`.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **plugin_data_fields)
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=content.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    ssi_summary = _plugin_summary(SESSION_START_INFO_TAG, session_start_info=ssi)

    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp = Experiment(hparam_infos=hps, metric_infos=mts)
    exp_summary = _plugin_summary(EXPERIMENT_TAG, experiment=exp)

    sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
    sei_summary = _plugin_summary(SESSION_END_INFO_TAG, session_end_info=sei)

    return exp_summary, ssi_summary, sei_summary
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    Args:
      name: A name for the generated node. Will also serve as the series name
        in TensorBoard.
      tensor: A real numeric value or tensor holding exactly one element
        (it is squeezed to 0 dimensions).
      collections: Unused; accepted for API compatibility.
      new_style: If True, store the value in the `tensor` field with
        "scalars" plugin metadata (could lead to faster data loading);
        otherwise store it in the legacy `simple_value` field.
      double_precision: Only used with `new_style`: store the value as
        DT_DOUBLE instead of DT_FLOAT.

    Returns:
      A `Summary` protobuf with one value.

    Raises:
      AssertionError: If the squeezed input is not 0-dimensional.
    """
    tensor = make_np(tensor).squeeze()
    assert (
        tensor.ndim == 0
    ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
    # A Python float is double precision, so nothing is lost converting here.
    value = float(tensor)
    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        proto = TensorProto(float_val=[value], dtype="DT_FLOAT")
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="scalars")
    )
    return Summary(value=[Summary.Value(tag=name, tensor=proto, metadata=metadata)])
def tensor_proto(tag, tensor):
    """Outputs a `Summary` protocol buffer containing the full tensor.

    The generated Summary has a TensorProto containing every element of the
    input tensor, with its dtype and shape.

    Args:
      tag: A name for the generated node. Will also serve as the series name
        in TensorBoard; also used in the unsupported-dtype error message.
      tensor: The `torch.Tensor` to convert. Its dtype must be a key of
        `_TENSOR_TYPE_MAP`.

    Returns:
      A `Summary` whose single value holds the tensor as a `TensorProto`,
      tagged with the "tensor" plugin.

    Raises:
      ValueError: If the tensor is 2GB or larger (the protocol buffer hard
        limit), or its dtype is not supported.
    """
    if tensor.numel() * tensor.itemsize >= (1 << 31):
        raise ValueError(
            "tensor is bigger than protocol buffer's hard limit of 2GB in size"
        )
    if tensor.dtype not in _TENSOR_TYPE_MAP:
        raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")

    dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
    # The value field (float_val, int_val, half_val, ...) depends on dtype,
    # hence the keyword unpacking.
    proto = TensorProto(
        dtype=dtype,
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
        ),
        **{field_name: conversion_fn(tensor)},
    )
    smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(plugin_name="tensor"))
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=proto)])
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a pre-computed histogram.

    Unlike `histogram`, no binning is done here: every statistic is copied
    verbatim into a `HistogramProto`
    ([`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)).

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      min: A float or int min value
      max: A float or int max value
      num: Int number of values
      sum: Float or int sum of all values
      sum_squares: Float or int sum of squares for all values
      bucket_limits: Numeric sequence with the upper value of each bucket
      bucket_counts: Numeric sequence with the number of values per bucket

    Returns:
      A `Summary` protobuf with one histogram value.
    """
    stats = dict(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    histo = HistogramProto(**stats)
    value = Summary.Value(tag=name, histo=histo)
    return Summary(value=[value])
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram of `values`.

    The generated
    [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
    has one summary value containing a histogram for `values`.

    Args:
      name: A name for the generated node. Will also serve as a series name in
        TensorBoard.
      values: A real numeric tensor or array of any shape. It is converted to
        a float numpy array before binning.
      bins: Forwarded to `make_histogram` (and from there to `np.histogram`).
      max_bins: Optional bucket cap forwarded to `make_histogram`.

    Returns:
      A `Summary` protobuf with one histogram value.
    """
    as_float = make_np(values).astype(float)
    histo = make_histogram(as_float, bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=histo)])
def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
      values: A non-empty numpy array of any shape; it is flattened.
      bins: Passed straight to `np.histogram` (bin count, bin edges, or a
        binning-strategy name).
      max_bins: Optional upper bound on the number of buckets. When
        `np.histogram` yields more, adjacent buckets are merged in equally
        sized groups so that at most `max_bins` remain.

    Returns:
      A `HistogramProto` restricted to the span of non-empty buckets, with
      one empty bucket on the left so the leftmost limit is represented.

    Raises:
      ValueError: If `values` is empty, or no bucket survives trimming.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        # Ceil division: merging groups of `subsampling` buckets leaves
        # ceil(num_bins / subsampling) <= max_bins buckets. Floor division
        # could exceed the cap (e.g. 10 buckets, max_bins=4 -> 5 buckets).
        subsampling = -(-num_bins // max_bins)
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Zero-pad so the counts split evenly into groups.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the
    # leftmost limit included, we include an empty bin left. If start == 0,
    # we need to add an empty one left, otherwise we can just include the bin
    # left to the first nonzero-count bin:
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer with a single PNG-encoded image.

    Args:
      tag: A name for the generated node. Will also serve as a series name in
        TensorBoard; used verbatim as the summary value tag.
      tensor: Image data in the layout named by `dataformats`; it is
        converted with `convert_to_HWC`. The scale is chosen from the input
        dtype: uint8 data is used as-is (values expected in [0, 255]), any
        other dtype is multiplied by 255 (values expected in [0, 1]). The
        result is clipped to [0, 255] and cast to uint8.
      rescale: Factor applied to height and width before encoding.
      dataformats: Layout string describing `tensor` (default "NCHW").

    Returns:
      A `Summary` protobuf with one image value.
    """
    hwc = convert_to_HWC(make_np(tensor), dataformats)
    # Pick the factor from the original dtype, before the float cast below.
    factor = _calc_scale_factor(hwc)
    pixels = (hwc.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image overlaid with boxes.

    Args:
      tag: Tag of the resulting summary value.
      tensor_image: Image data in the layout named by `dataformats`. uint8
        data is kept as-is; other dtypes are multiplied by 255; the result
        is clipped to [0, 255] and cast to uint8.
      tensor_boxes: Box coordinates, passed to `make_image` as `rois`.
      rescale: Factor applied to height and width before encoding.
      dataformats: Layout string describing `tensor_image` (default "CHW").
      labels: Optional per-box captions, passed through to `make_image`.

    Returns:
      A `Summary` protobuf with one image value.
    """
    hwc = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    scaled = hwc.astype(np.float32) * _calc_scale_factor(hwc)
    pixels = scaled.clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale, rois=boxes, labels=labels)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box of ``boxes`` onto ``disp_image`` with a red outline.

    Args:
      disp_image: The image handed to `_draw_single_box` (a PIL image when
        called from `make_image`).
      boxes: Array indexed as ``boxes[i, 0:4] == (xmin, ymin, xmax, ymax)``
        ("xyxy" format), one row per box.
      labels: Optional sequence; ``labels[i]`` becomes the caption of box i.

    Returns:
      The image returned by the last `_draw_single_box` call, or
      ``disp_image`` unchanged when there are no boxes.
    """
    for idx in range(boxes.shape[0]):
        caption = None if labels is None else labels[idx]
        disp_image = _draw_single_box(
            disp_image,
            boxes[idx, 0],
            boxes[idx, 1],
            boxes[idx, 2],
            boxes[idx, 3],
            display_str=caption,
            color="Red",
        )
    return disp_image
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to Image protobuf.

    Args:
      tensor: uint8 array unpacked as ``(height, width, channel)``.
      rescale: Factor applied to height and width (LANCZOS resampling)
        before PNG encoding.
      rois: Optional boxes drawn with `draw_boxes` before resizing.
      labels: Optional per-box captions for `rois`.

    Returns:
      A `Summary.Image` holding the PNG bytes. Its height and width describe
      the encoded (resized) image, so they match the PNG payload.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    image = image.resize((scaled_width, scaled_height), Image.Resampling.LANCZOS)
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()
    return Summary.Image(
        # Report the dimensions of the encoded image, not of the input array;
        # they differ whenever rescale != 1.
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
def video(tag, tensor, fps=4):
    """Output a `Summary` protocol buffer with a GIF-encoded video.

    Args:
      tag: Tag of the resulting summary value.
      tensor: Video data; reshaped by `_prepare_video` into the frame
        layout `make_video` consumes. uint8 data is kept as-is, other dtypes
        are multiplied by 255; the result is clipped and cast to uint8.
      fps: Frames per second of the encoded GIF.

    Returns:
      A `Summary` with one image value (the image is None when `make_video`
      cannot encode because moviepy is unavailable).
    """
    frames = _prepare_video(make_np(tensor))
    # uint8 input is already in [0, 255], so it must not be scaled again.
    factor = _calc_scale_factor(frames)
    frames = (frames.astype(np.float32) * factor).clip(0, 255).astype(np.uint8)
    encoded = make_video(frames, fps)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def make_video(tensor, fps):
    """Encode a sequence of frames as an animated GIF `Summary.Image`.

    Args:
      tensor: uint8 frames, unpacked as ``(time, height, width, channels)``.
      fps: Frames per second of the GIF.

    Returns:
      A `Summary.Image` holding the GIF bytes, or None when moviepy (or its
      `editor` module) cannot be imported.
    """
    try:
        import moviepy  # noqa: F401
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy
    except ImportError:
        print(
            "moviepy is installed, but can't import moviepy.editor.",
            "Some packages could be missing [imageio, requests]",
        )
        return
    import tempfile

    _, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # Only a unique path is needed: close our handle right away instead of
    # leaving it to the garbage collector (delete=False keeps the file).
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        filename = tmp.name
    try:
        try:  # newer version of moviepy use logger instead of progress_bar argument.
            clip.write_gif(filename, verbose=False, logger=None)
        except TypeError:
            try:  # older version of moviepy does not support progress_bar argument.
                clip.write_gif(filename, verbose=False, progress_bar=False)
            except TypeError:
                clip.write_gif(filename, verbose=False)

        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Clean up even if encoding or reading failed.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` protocol buffer with a mono 16-bit WAV clip.

    Args:
      tag: Tag of the resulting summary value.
      tensor: Samples; squeezed, and must be 1-D afterwards. Amplitudes
        outside [-1, 1] are clipped (a warning is printed).
      sample_rate: Sample rate in Hz written into the WAV header.

    Returns:
      A `Summary` with one audio value (content type "audio/wav").

    Raises:
      AssertionError: If the squeezed input is not 1-dimensional.
    """
    import io
    import wave

    samples = make_np(tensor).squeeze()
    if abs(samples).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        samples = samples.clip(-1, 1)
    assert samples.ndim == 1, "input tensor should be 1 dimensional."
    # Scale [-1, 1] to little-endian int16 PCM.
    pcm = (samples * np.iinfo(np.int16).max).astype("<i2")

    with io.BytesIO() as buffer:
        # The wave writer does not close a file object it did not open, so
        # the buffer is still readable after this inner block.
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(sample_rate)
            writer.writeframes(pcm.data)
        encoded = buffer.getvalue()

    clip = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=encoded,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=clip)])
def custom_scalars(layout):
    """Build the config `Summary` for the TensorBoard custom_scalars plugin.

    Args:
      layout: Mapping ``{category_title: {chart_title: spec}}`` where
        ``spec[0]`` is the chart kind and ``spec[1]`` its list of tags. Kind
        ``"Margin"`` requires exactly three tags (value, lower, upper); any
        other kind produces a multiline chart over the tags.

    Returns:
      A `Summary` with one value tagged "custom_scalars__config__" holding
      the serialized `Layout` in a string tensor.
    """

    def _build_chart(title, spec):
        # Translate one (kind, tags) spec into a layout_pb2.Chart.
        tags = spec[1]
        if spec[0] == "Margin":
            assert len(tags) == 3
            margin = layout_pb2.MarginChartContent(
                series=[
                    layout_pb2.MarginChartContent.Series(
                        value=tags[0], lower=tags[1], upper=tags[2]
                    )
                ]
            )
            return layout_pb2.Chart(title=title, margin=margin)
        multiline = layout_pb2.MultilineChartContent(tag=tags)
        return layout_pb2.Chart(title=title, multiline=multiline)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(title, spec) for title, spec in charts.items()],
        )
        for category_title, charts in layout.items()
    ]
    config = layout_pb2.Layout(category=categories)
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars")
    )
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[config.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    value = Summary.Value(tag="custom_scalars__config__", tensor=payload, metadata=smd)
    return Summary(value=[value])
def text(tag, text):
    """Output a `Summary` protocol buffer for the TensorBoard text plugin.

    Args:
      tag: Base tag; the value is stored under ``tag + "/text_summary"``.
      text: The string to log; it is UTF-8 encoded.

    Returns:
      A `Summary` with one value holding a 1-element string tensor.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="text", content=plugin_content)
    )
    one_element = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    payload = TensorProto(
        dtype="DT_STRING",
        string_val=[text.encode(encoding="utf_8")],
        tensor_shape=one_element,
    )
    full_tag = tag + "/text_summary"
    return Summary(value=[Summary.Value(tag=full_tag, metadata=smd, tensor=payload)])
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a `Summary` for the pr_curves plugin from pre-computed counts.

    Args:
      tag: Tag of the resulting summary value.
      tp, fp, tn, fn, precision, recall: Equal-length 1-D sequences, stacked
        row-wise in this order into a 2-D float tensor.
      num_thresholds: Recorded in the plugin metadata; capped at 127.
      weights: Unused; accepted for API compatibility.

    Returns:
      A `Summary` with one value holding the stacked curve data.
    """
    # weird, value > 127 breaks protobuf
    num_thresholds = min(num_thresholds, 127)
    stacked = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    rows, cols = stacked.shape[0], stacked.shape[1]
    payload = TensorProto(
        dtype="DT_FLOAT",
        float_val=stacked.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=rows), TensorShapeProto.Dim(size=cols)]
        ),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=payload)])
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Compute a precision/recall curve and wrap it for the pr_curves plugin.

    Args:
      tag: Tag of the resulting summary value.
      labels: Ground-truth labels, forwarded to `compute_curve`.
      predictions: Prediction scores, forwarded to `compute_curve`.
      num_thresholds: Number of thresholds; capped at 127.
      weights: Optional example weights, forwarded to `compute_curve`.

    Returns:
      A `Summary` with one value holding the 2-D array returned by
      `compute_curve` as a float tensor.
    """
    # weird, value > 127 breaks protobuf
    capped = min(num_thresholds, 127)
    curve = compute_curve(labels, predictions, num_thresholds=capped, weights=weights)
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=capped
    ).SerializeToString()
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    rows, cols = curve.shape[0], curve.shape[1]
    payload = TensorProto(
        dtype="DT_FLOAT",
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=rows), TensorShapeProto.Dim(size=cols)]
        ),
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=payload)])
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision/recall statistics at evenly spaced thresholds.

    Args:
      labels: Array-like of ground-truth labels (converted to float64; 1
        marks a positive, 0 a negative).
      predictions: Array-like of scores, expected in [0, 1], same length as
        `labels`.
      num_thresholds: Number of thresholds. ``None`` (the default) means 127,
        the cap used by `pr_curve` / `pr_curve_raw`.
      weights: Optional scalar or array of example weights (default 1.0).

    Returns:
      A ``(6, num_thresholds)`` array whose rows are tp, fp, tn, fn,
      precision and recall.
    """
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        # `None - 1` below would raise; use the cap the callers apply.
        num_thresholds = 127
    if weights is None:
        weights = 1.0
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    # Clamp denominators to avoid division by zero.
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Create a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Tensor (or array-like, converted with `torch.as_tensor`) to
        display in summary.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      A `Summary.Value` holding the data as a DT_FLOAT `TensorProto`.
    """
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)
    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )
    proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        # Emit every dimension so the shape always agrees with float_val;
        # hard-coding three dims crashed on lower-rank input and truncated
        # the shape of higher-rank input.
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in tensor.shape]
        ),
    )
    tensor_summary = Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=proto,
        metadata=tensor_metadata,
    )
    return tensor_summary
def _get_json_config(config_dict):
    """Serialize ``config_dict`` to JSON with sorted keys; ``None`` gives "{}"."""
    if config_dict is None:
        return "{}"
    return json.dumps(config_dict, sort_keys=True)
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Output a merged `Summary` protocol buffer with a mesh/point cloud.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for
        each vertex, or None.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle, or None.
      config_dict: Dictionary with ThreeJS classes names and configuration
        (None for an empty config).
      display_name: If set, will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data.
        Markdown is supported.

    Returns:
      A `Summary` with one value per non-None input, in the order vertices,
      faces, colors.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)
    candidates = (
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]
    # The same bitmask (all present parts) is attached to every component.
    components = metadata.get_components_bitmask([kind for _, kind in present])
    summaries = [
        _get_tensor_summary(
            tag,
            display_name,
            description,
            data,
            kind,
            components,
            json_config,
        )
        for data, kind in present
    ]
    return Summary(value=summaries)
|