# Source code for gpflux.experiment_support.tensorboard

#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
TensorBoard event iterator
"""
from dataclasses import dataclass
from typing import Any, Iterator, List, Type, Union

import tensorflow as tf
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import tensor_util


@dataclass
class Event:
    """
    Lightweight record holding a single TensorBoard summary value.

    Instances are plain data holders; no validation is performed on the
    fields.
    """

    # Summary name, e.g. "loss" or "lengthscales".
    tag: str
    # Step number associated with the summary.
    step: int
    # The decoded value of the summary.
    value: Any
    # Type of ``value`` (e.g. ``float``).
    dtype: Type
def tensorboard_event_iterator(
    file_pattern: Union[str, List[str], tf.Tensor]
) -> Iterator[Event]:
    """
    Iterator yielding preprocessed tensorboard Events.

    The event files matching ``file_pattern`` are listed without shuffling,
    so the order in which events are yielded is the same on every run.
    Every entry of each record's ``summary.value`` is yielded as one
    :class:`Event`; records with no summary values yield nothing.

    :param file_pattern: A string, a list of strings, or a `tf.Tensor` of
        string type (scalar or vector), representing the filename glob (i.e.
        shell wildcard) pattern(s) that will be matched.
    :return: An iterator over :class:`Event` objects, one per summary value.
    :raises ValueError: If a summary value cannot be read.
    """

    def get_scalar_value(value: Any) -> Any:
        # `value.simple_value` returns 0.0 for np.ndarray values, so try to
        # decode `value.tensor` with `MakeNdarray` first; that breaks for
        # non-tensor summaries. `.item()` also fails for tensors holding more
        # than one element, so those take the fallback path as well.
        try:
            return tensor_util.MakeNdarray(value.tensor).item()
        except Exception:
            # NOTE(review): for non-scalar tensors and non-tensor summaries the
            # fallback may just yield the proto default 0.0 — confirm whether
            # such values should be skipped rather than reported as 0.0.
            try:
                return value.simple_value
            except Exception as err:
                raise ValueError("Unable to read value from tensor") from err

    # `list_files` shuffles the matched filenames by default, which would make
    # the order of yielded events differ between runs when several event files
    # match; disable shuffling to keep iteration deterministic.
    event_files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    serialized_examples = tf.data.TFRecordDataset(event_files)
    for serialized_example in serialized_examples:
        event = event_pb2.Event.FromString(serialized_example.numpy())
        for value in event.summary.value:
            v = get_scalar_value(value)
            yield Event(tag=value.tag, step=event.step, value=v, dtype=type(v))