Basic Usage

An example of basic usage of the `layg` package.

"""
An example use of the `layg` package
"""

import matplotlib.pyplot as plt  # type: ignore
import numpy as np  # type: ignore

# TODO: remove NOQA when isort is fixed
from layg import CholeskyNnEmulator as Emulator  # NOQA
from layg import emulate  # NOQA


def main():
    """
    Train an emulator for a toy log-likelihood and compare it to the truth.

    The toy log-likelihood is ``-x . x``. It is wrapped with ``layg.emulate``,
    so it exposes the ``train``, ``trained``, ``output_err``, ``true_func``
    and ``emulator`` attributes used below. After training, a few query
    points are printed together with the emulator's output. A plot comparing
    the true function with the emulated values along the diagonal of the
    unit hypercube is then saved to ``check.png``.

    Raises
    ------
    RuntimeError
        If ``ndim`` is not 1 or 2.
    """
    ndim = 2

    # Toy likelihood: -x.x for a single point (returned as a length-1 array),
    # or evaluated row by row for a batch of points.
    @emulate(Emulator)
    def loglike(x):
        if x.ndim != 1:
            return np.array([-np.dot(x0, x0) for x0 in x])
        return np.array([-np.dot(x, x)])

    def get_x(ndim):
        """
        Sample from a Gaussian with mean 0 and std 1
        """
        return np.random.normal(0.0, 1.0, size=ndim)

    # Training inputs drawn from N(0, 1), plus a few query points to print.
    if ndim == 1:
        Xtrain = np.array([get_x(ndim) for _ in range(1000)])
        xlist = np.array([np.linspace(-3.0, 3.0, 11)]).T
    elif ndim == 2:
        Xtrain = np.array([get_x(ndim) for _ in range(10000)])
        xlist = np.array([get_x(ndim) for _ in range(10)])
    else:
        raise RuntimeError(
            "This number of dimensions has not been implemented for testing yet."
        )

    # Training targets come from calling the decorated function before it is
    # trained (presumably this evaluates the true function -- confirm in
    # layg.emulate).
    Ytrain = np.array([loglike(X) for X in Xtrain])
    loglike.train(Xtrain, Ytrain)

    # Temporarily request error estimates (printed as "val, err"). Always
    # switch them back off so that a failing prediction does not leave the
    # decorated function in error-output mode.
    loglike.output_err = True
    try:
        for x in xlist:
            print("x", x)
            print("val, err", loglike(x))
    finally:
        loglike.output_err = False

    # Sanity check before plotting.
    assert loglike.trained

    # Evaluate along the diagonal x_0 = ... = x_{ndim-1} = t, for t in [0, 1].
    x_len = 100
    t = np.linspace(0, 1, x_len)
    x_data_plot = np.repeat(t[:, np.newaxis], ndim, axis=1)

    y_true = np.array([loglike.true_func(x) for x in x_data_plot])
    y_emul = np.array([loglike(x) for x in x_data_plot])
    y_emul_raw = np.array([loglike.emulator.emul_func(x) for x in x_data_plot])

    fig, ax = plt.subplots()
    try:
        ax.plot(x_data_plot[..., 0], y_true, label="true", color="black")
        ax.scatter(x_data_plot[..., 0], y_emul, label="emulated", marker="+")
        ax.scatter(
            x_data_plot[..., 0],
            y_emul_raw,
            label="emulated\n no error estimation",
            marker="+",
        )

        ax.legend()
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")

        fig.savefig("check.png")
    finally:
        # Release the figure. Otherwise repeated runs (e.g. via test_main)
        # accumulate open figures in pyplot's global state.
        plt.close(fig)


def test_main():
    """Smoke test: the example must run end to end without raising."""
    main()


if __name__ == "__main__":
    # Run the example when this file is executed directly as a script.
    main()

(Source code, png, hires.png, pdf)

../_images/example_basic.png