Source code for trieste.observer
# Copyright 2020 The Trieste Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Definitions and utilities for observers of objective functions. """
from __future__ import annotations
from typing import Callable, Mapping, Union
import tensorflow as tf
from typing_extensions import Final
from .data import Dataset
from .types import Tag, TensorType
[docs]SingleObserver = Callable[[TensorType], Dataset]
"""
Type alias for an observer of the objective function (that takes query points and returns an
unlabelled dataset).
"""
[docs]MultiObserver = Callable[[TensorType], Mapping[Tag, Dataset]]
"""
Type alias for an observer of the objective function (that takes query points and returns labelled
datasets).
"""
[docs]Observer = Union[SingleObserver, MultiObserver]
"""
Type alias for an observer, returning either labelled datasets or a single unlabelled dataset.
"""
[docs]OBJECTIVE: Final[Tag] = "OBJECTIVE"
"""
A tag typically used by acquisition rules to denote the data sets and models corresponding to the
optimization objective.
"""
def _is_finite(t: TensorType) -> TensorType:
return tf.logical_and(tf.math.is_finite(t), tf.logical_not(tf.math.is_nan(t)))
[docs]def filter_finite(query_points: TensorType, observations: TensorType) -> Dataset:
"""
:param query_points: A tensor of shape (N x M).
:param observations: A tensor of shape (N x 1).
:return: A :class:`~trieste.data.Dataset` containing all the rows in ``query_points`` and
``observations`` where the ``observations`` are finite numbers.
:raise ValueError or InvalidArgumentError: If ``query_points`` or ``observations`` have invalid
shapes.
"""
tf.debugging.assert_shapes([(observations, ("N", 1))])
mask = tf.reshape(_is_finite(observations), [-1])
return Dataset(tf.boolean_mask(query_points, mask), tf.boolean_mask(observations, mask))
[docs]def map_is_finite(query_points: TensorType, observations: TensorType) -> Dataset:
"""
:param query_points: A tensor.
:param observations: A tensor.
:return: A :class:`~trieste.data.Dataset` containing all the rows in ``query_points``,
along with the tensor result of mapping the elements of ``observations`` to: `1` if they are
a finite number, else `0`, with dtype `tf.uint8`.
:raise ValueError or InvalidArgumentError: If ``query_points`` and ``observations`` do not
satisfy the shape constraints of :class:`~trieste.data.Dataset`.
"""
return Dataset(query_points, tf.cast(_is_finite(observations), tf.uint8))