Custom Emulators

This example shows how to build a custom emulator by defining a subclass of layg.emulator.BaseEmulator.

The emulator simply learns the mean of the supplied training values, and uses the standard deviation of the cross-validation errors as its error estimate.

In this example the emulated function is very simple: it returns real numbers drawn from a Gaussian distribution with some mean.

from typing import Callable

import matplotlib.pyplot as plt  # type: ignore
import numpy as np  # type: ignore

from layg import BaseEmulator, emulate  # NOQA


class MeanEmulator(BaseEmulator):
    """
    An emulator that returns the mean of the training values

    The error estimate is the standard deviation of the error in the cross validation data.
    This emulator is not very useful other than as an example of how to write one.
    """

    def set_emul_func(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
        """
        Store an emulation function that always returns the mean of ``y_train``

        ``x_train`` is accepted but not used: the prediction is the same for every input.
        """
        # The mean is fixed once the training data is known, so evaluate it
        # here rather than on every call of the emulation function.
        y_mean = np.mean(y_train)
        self.emul_func: Callable[[np.ndarray], np.ndarray] = lambda x: y_mean

    def set_emul_error_func(self, x_cv: np.ndarray, y_cv_err: np.ndarray) -> None:
        """
        Store an error function that always returns the standard deviation of ``y_cv_err``

        ``x_cv`` is accepted but not used: the error estimate is the same for every input.
        """
        # As for the mean: compute the spread once, not per call.
        err_std = y_cv_err.std()
        self.emul_error: Callable[[np.ndarray], np.ndarray] = lambda x: err_std


# Centre of the Gaussian sampled by `noise`: a length-1 array drawn once at
# import time from the interval [2, 3).
MEAN = np.random.uniform(size=1) + 2


@emulate(MeanEmulator)
def noise(x: np.ndarray) -> np.ndarray:
    """
    Draw one sample from a narrow Gaussian centred on ``MEAN``

    The argument ``x`` does not enter the calculation.
    The scatter is small enough that the emulated value is always used.
    """

    # loc=MEAN, scale=1e-2, size=1
    sample = np.random.normal(MEAN, 1e-2, 1)
    return sample


def main() -> None:
    """
    Plot some output from this emulator

    Calls ``noise`` at ``noise.init_train_thresh`` random points (the training
    phase), switches on ``noise.output_err`` so that each call returns a value
    together with an error estimate, evaluates ``noise`` at ``NUM_TEST`` further
    random points, and plots the training values next to the emulated values
    with their error bars.  The figure is saved next to this file when
    ``__file__`` is defined.
    """

    NUM_TRAIN = noise.init_train_thresh
    NUM_TEST = 20
    XDIM = 1

    # Train the emulator
    x_train = np.random.uniform(size=(NUM_TRAIN, XDIM))
    y_train = np.array([noise(x) for x in x_train])

    # Output error estimates
    noise.output_err = True

    # Get values from the trained emulator
    x_emu = np.random.uniform(size=(NUM_TEST, XDIM))

    # One row per test point, so a scalar or a length-1 result can be stored
    y_emu = np.zeros_like(x_emu)
    y_err = np.zeros_like(x_emu)

    for i, x in enumerate(x_emu):
        y_emu[i], y_err[i] = noise(x)

    # Plot the results
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.scatter(x_train[:, 0], y_train, marker="+", label="training values")
    # Pass x, y and yerr all as 1-D columns: mixing (NUM_TEST, 1) positions with
    # a (NUM_TEST,) error array depends on how errorbar broadcasts its inputs.
    ax.errorbar(
        x_emu[:, 0],
        y_emu[:, 0],
        yerr=y_err[:, 0],
        linestyle="None",
        marker="o",
        capsize=3,
        label="emulator",
        color="red",
    )

    ax.legend()

    # `__file__` is undefined when running in sphinx
    try:
        fig.savefig(__file__ + ".png")
    except NameError:
        pass


def test_main() -> None:
    """
    Pytest entry point: run the example end to end
    """
    main()


# Run the example when this file is executed as a script
if __name__ == "__main__":
    main()

(Source code, png, hires.png, pdf)

../_images/example_custom_emulator.png