# Source code for trieste.utils.misc

# Copyright 2020 The Trieste Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from time import perf_counter
from types import TracebackType
from typing import Any, Callable, Generic, Mapping, NoReturn, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest
from typing_extensions import Final, final

from ..observer import OBJECTIVE
from ..types import Tag, TensorType

# Type variable used by decorators such as ``jit`` so that wrapping a
# function preserves its exact callable type for static type checkers.
C = TypeVar("C", bound=Callable[..., object])
""" A type variable bound to `typing.Callable`. """


def jit(apply: bool = True, **optimize_kwargs: Any) -> Callable[[C], C]:
    """
    A decorator that conditionally wraps a function with `tf.function`.

    :param apply: If `True`, the decorator is equivalent to `tf.function`. If `False`, the
        decorator does nothing.
    :param optimize_kwargs: Additional arguments to `tf.function`.
    :return: The decorator.
    """

    def decorator(func: C) -> C:
        # Leave the function untouched when compilation is disabled.
        if not apply:
            return func
        return tf.function(func, **optimize_kwargs)

    return decorator
def shapes_equal(this: TensorType, that: TensorType) -> TensorType:
    """
    Return a scalar tensor containing: `True` if ``this`` and ``that`` have equal runtime shapes,
    else `False`.
    """
    # Compare ranks first: Python ``and`` short-circuits, so the element-wise
    # shape comparison (which needs equal-length shape vectors) only runs when
    # the ranks already match.
    same_rank = tf.rank(this) == tf.rank(that)
    return same_rank and tf.reduce_all(tf.shape(this) == tf.shape(that))
def to_numpy(t: TensorType) -> "np.ndarray[Any, Any]":
    """
    Convert an array-like object to a NumPy array.

    :param t: An array-like object.
    :return: ``t`` as a NumPy array.
    """
    # Only tf tensors need explicit conversion; anything else is returned untouched.
    return t.numpy() if isinstance(t, tf.Tensor) else t
# Covariant so that, e.g., a Result of a subtype can be used where a Result
# of the supertype is expected.
ResultType = TypeVar("ResultType", covariant=True)
""" An unbounded covariant type variable. """
class Result(Generic[ResultType], ABC):
    """
    Represents the result of an operation that can fail with an exception. It contains either the
    operation return value (in an :class:`Ok`), or the exception raised (in an :class:`Err`).

    To check whether instances such as

    >>> res = Ok(1)
    >>> other_res = Err(ValueError("whoops"))

    contain a value, use :attr:`is_ok` (or :attr:`is_err`)

    >>> res.is_ok
    True
    >>> other_res.is_ok
    False

    We can access the value if it :attr:`is_ok` using :meth:`unwrap`.

    >>> res.unwrap()
    1

    Trying to access the value of a failed :class:`Result`, or :class:`Err`, will raise the wrapped
    exception

    >>> other_res.unwrap()
    Traceback (most recent call last):
        ...
    ValueError: whoops

    **Note:** This class is not intended to be subclassed other than by :class:`Ok` and
    :class:`Err`.
    """

    @property
    @abstractmethod
    def is_ok(self) -> bool:
        """`True` if this :class:`Result` contains a value, else `False`."""

    @property
    def is_err(self) -> bool:
        """
        `True` if this :class:`Result` contains an error, else `False`. The opposite of
        :attr:`is_ok`.
        """
        return not self.is_ok

    @abstractmethod
    def unwrap(self) -> ResultType:
        """
        :return: The contained value, if it exists.
        :raise Exception: If there is no contained value.
        """
@final
class Ok(Result[ResultType]):
    """The success variant of a :class:`Result`: holds the value produced by an
    evaluation that completed without raising."""

    def __init__(self, value: ResultType):
        """
        :param value: The result of a successful evaluation.
        """
        self._value = value

    @property
    def is_ok(self) -> bool:
        """`True` always."""
        return True

    def unwrap(self) -> ResultType:
        """
        :return: The wrapped value.
        """
        return self._value

    def __repr__(self) -> str:
        """Return a representation of the form ``Ok(<value>)``."""
        return "Ok({!r})".format(self._value)
@final
class Err(Result[NoReturn]):
    """The failure variant of a :class:`Result`: holds the exception raised by an
    evaluation that failed."""

    def __init__(self, exc: Exception):
        """
        :param exc: The exception that occurred.
        """
        self._exc = exc

    @property
    def is_ok(self) -> bool:
        """`False` always."""
        return False

    def unwrap(self) -> NoReturn:
        """
        :raise Exception: Always. Raises the wrapped exception.
        """
        raise self._exc

    def __repr__(self) -> str:
        """Return a representation of the form ``Err(<exception>)``."""
        return "Err({!r})".format(self._exc)
class DEFAULTS:
    """Default constants used in Trieste."""

    JITTER: Final[float] = 1e-6
    """
    The default jitter, typically used to stabilise computations near singular points, such as in
    Cholesky decomposition.
    """
# Key and value type variables used by ``map_values`` below.
K = TypeVar("K")
""" An unbound type variable. """

U = TypeVar("U")
""" An unbound type variable. """

V = TypeVar("V")
""" An unbound type variable. """
def map_values(f: Callable[[U], V], mapping: Mapping[K, U]) -> Mapping[K, V]:
    """
    Apply ``f`` to each value in ``mapping`` and return the result. If ``f`` does not modify its
    argument, :func:`map_values` does not modify ``mapping``. For example:

    >>> import math
    >>> squares = {'a': 1, 'b': 4, 'c': 9}
    >>> map_values(math.sqrt, squares)['b']
    2.0
    >>> squares
    {'a': 1, 'b': 4, 'c': 9}

    :param f: The function to apply to the values in ``mapping``.
    :param mapping: A mapping.
    :return: A new mapping, whose keys are the same as ``mapping``, and values are the result of
        applying ``f`` to each value in ``mapping``.
    """
    # keys() and values() iterate in the same order, so pairing them is safe.
    return dict(zip(mapping.keys(), map(f, mapping.values())))
T = TypeVar("T")
""" An unbound type variable. """


def get_value_for_tag(
    mapping: Optional[Mapping[Tag, T]], *tags: Tag
) -> Tuple[Optional[Tag], Optional[T]]:
    """Return the value from a mapping for the first tag found from a sequence of tags.

    :param mapping: A mapping from tags to values.
    :param tags: A tag or a sequence of tags. Sequence is searched in order. If no tags are
        provided, the default tag OBJECTIVE is used.
    :return: The chosen tag and value of the tag in the mapping, or None for each if the mapping
        is None.
    :raises ValueError: If none of the tags are in the mapping and the mapping is not None.
    """
    search_tags = tags if tags else (OBJECTIVE,)

    if mapping is None:
        return None, None

    for candidate in search_tags:
        if candidate in mapping:
            return candidate, mapping[candidate]

    raise ValueError(f"none of the tags '{search_tags}' found in mapping")


@dataclass(frozen=True)
class LocalizedTag:
    """Manage a tag for a local model or dataset. These have a global tag and a local index."""

    global_tag: Tag  # the global portion of the tag
    local_index: Optional[int]  # the local index of the tag, or None for a purely global tag

    def __post_init__(self) -> None:
        # Reject negative indices up front; None (no local part) is allowed.
        index = self.local_index
        if index is not None and index < 0:
            raise ValueError(f"local index must be non-negative, got {self.local_index}")

    @property
    def is_local(self) -> bool:
        """Return True if the tag is a local tag."""
        return self.local_index is not None

    @staticmethod
    def from_tag(tag: Union[Tag, LocalizedTag]) -> LocalizedTag:
        """Return a LocalizedTag from a given tag."""
        return tag if isinstance(tag, LocalizedTag) else LocalizedTag(tag, None)


def ignoring_local_tags(mapping: Mapping[Tag, T]) -> Mapping[Tag, T]:
    """
    Filter out local tags from a mapping, returning a new mapping with only global tags.

    :param mapping: A mapping from tags to values.
    :return: A new mapping with only global tags.
    """
    filtered = {}
    for key, value in mapping.items():
        if not LocalizedTag.from_tag(key).is_local:
            filtered[key] = value
    return filtered
class Timer:
    """
    Context manager that measures the wall-clock duration of the enclosed block.
    For example:

    >>> from time import sleep
    >>> with Timer() as timer: sleep(2.0)
    >>> timer.time  # doctest: +SKIP
    2.0
    """

    def __enter__(self) -> Timer:
        # Capture the start timestamp; ``start`` stays available afterwards.
        self.start = perf_counter()
        return self

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Record the end timestamp and the elapsed duration in seconds.
        self.end = perf_counter()
        self.time = self.end - self.start
def flatten_leading_dims(
    x: TensorType, output_dims: int = 2
) -> Tuple[TensorType, Callable[[TensorType], TensorType]]:
    """
    Flattens the leading dimensions of `x` (all but the last `output_dims` dimensions), and
    returns a function that can be used to restore them (typically after first manipulating the
    flattened tensor). The returned tensor has rank ``output_dims``: the trailing
    ``output_dims - 1`` dimensions of ``x`` are preserved and all leading dimensions are
    collapsed into the first.

    :param x: A tensor whose rank is at least ``output_dims``.
    :param output_dims: The rank of the flattened tensor. Must be positive and no greater
        than the rank of ``x``.
    :return: A tuple of the flattened tensor and a function mapping a tensor whose leading
        dimension matches the flattened size back to the original leading (batch) shape.
    :raise tf.errors.InvalidArgumentError: If ``output_dims`` is not positive, or is greater
        than the rank of ``x``.
    """
    tf.debugging.assert_positive(output_dims, message="output_dims must be positive")
    tf.debugging.assert_less_equal(
        # fixed message typo: was "must no greater than"
        output_dims, tf.rank(x), message="output_dims must be no greater than tensor rank"
    )

    x_batched_shape = tf.shape(x)
    # Leading dims to collapse vs trailing dims to keep; when output_dims == 1
    # everything is collapsed into a single dimension.
    batch_shape = x_batched_shape[: -output_dims + 1] if output_dims > 1 else x_batched_shape
    input_shape = x_batched_shape[-output_dims + 1 :] if output_dims > 1 else []
    x_flat_shape = tf.concat([[-1], input_shape], axis=0)

    def unflatten(y: TensorType) -> TensorType:
        # Restore the saved batch shape, keeping y's own trailing dims (which may
        # differ from x's, e.g. after a model evaluation).
        y_flat_shape = tf.shape(y)
        output_shape = y_flat_shape[1:]
        y_batched_shape = tf.concat([batch_shape, output_shape], axis=0)
        y_batched = tf.reshape(y, y_batched_shape)
        return y_batched

    return tf.reshape(x, x_flat_shape), unflatten
def get_variables(object: Any) -> tuple[tf.Variable, ...]:
    """
    Return the sequence of variables in an object. This is essentially a reimplementation
    of the `variables` property of tf.Module but doesn't require that we, or any of our
    substructures, inherit from that.

    :param object: Any object; its attributes, and theirs recursively, are searched.
        (NOTE(review): the parameter name shadows the ``object`` builtin — kept for
        backward compatibility with keyword callers.)
    :return: A sequence of variables of the object (sorted by attribute name) followed by
        variables from all submodules recursively (breadth first).
    """

    def _is_variable(obj: Any) -> bool:
        # Predicate handed to _flatten: keep only tf.Variable leaves.
        return isinstance(obj, tf.Variable)

    return tuple(_flatten(object, predicate=_is_variable, expand_composites=True))


# Attribute names that tf.Module's own flattening skips (internal checkpoint
# bookkeeping); we skip them too so traversal matches tf.Module's behaviour.
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
    ("_self_unconditional_checkpoint_dependencies", "_self_unconditional_dependency_names")
)


def _flatten(  # type: ignore[no-untyped-def]
    model,
    recursive=True,
    predicate=None,
    attribute_traversal_key=None,
    with_path=False,
    expand_composites=False,
):
    """
    Flattened attribute values in sorted order by attribute name. This is taken verbatim
    from tensorflow core but uses a modified _flatten_module.
    """
    if predicate is None:
        # Default predicate accepts every leaf.
        predicate = lambda _: True  # noqa: E731

    return _flatten_module(
        model,
        recursive=recursive,
        predicate=predicate,
        attributes_to_ignore=_TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites,
    )


def _flatten_module(  # type: ignore[no-untyped-def]
    module,
    recursive,
    predicate,
    attribute_traversal_key,
    attributes_to_ignore,
    with_path,
    expand_composites,
    module_path=(),
    seen=None,
):
    """
    Implementation of `flatten`. This is a reimplementation of the equivalent function in
    tf.Module so that we can extract the list of variables from a Trieste model wrapper
    without the need to inherit from it.
    """
    if seen is None:
        # Track visited object ids so shared/cyclic structures are emitted once.
        seen = {id(module)}

    # [CHANGED] Differently from the original version, here we catch an exception
    # as some of the components of the wrapper do not implement __dict__
    try:
        module_dict = vars(module)
    except TypeError:
        module_dict = {}
    submodules = []

    for key in sorted(module_dict, key=attribute_traversal_key):
        if key in attributes_to_ignore:
            continue

        prop = module_dict[key]
        try:
            # Expand nested structures (lists, dicts, composites) into (path, leaf) pairs.
            leaves = nest.flatten_with_tuple_paths(prop, expand_composites=expand_composites)
        except Exception:  # pylint: disable=broad-except
            leaves = []

        for leaf_path, leaf in leaves:
            leaf_path = (key,) + leaf_path

            if not with_path:
                # De-duplicate by identity only in the no-path mode, matching tf.Module.
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

            if predicate(leaf):
                if with_path:
                    yield module_path + leaf_path, leaf
                else:
                    yield leaf

            # [CHANGED] Differently from the original, here we skip checking whether the leaf
            # is a module, since the trieste models do NOT inherit from tf.Module
            if recursive:  # and _is_module(leaf):
                # Walk direct properties first then recurse.
                submodules.append((module_path + leaf_path, leaf))

    for submodule_path, submodule in submodules:
        subvalues = _flatten_module(
            submodule,
            recursive=recursive,
            predicate=predicate,
            attribute_traversal_key=attribute_traversal_key,
            attributes_to_ignore=_TF_MODULE_IGNORED_PROPERTIES,
            with_path=with_path,
            expand_composites=expand_composites,
            module_path=submodule_path,
            seen=seen,
        )

        for subvalue in subvalues:
            # Predicate is already tested for these values.
            yield subvalue