
I'm using GPy for the first time. I embedded the 12-dimensional oil flow data in a latent space with GPLVM, and I want to obtain the coordinates of each data point in that latent space.

About the dataset

I downloaded the oil flow data from here as a gzipped MATLAB workspace; after decompressing it, you can use it as the file 3Class.mat.

The oil flow data consists of 12-dimensional feature vectors, each with a 3-dimensional label; the training split loaded below has 1000 samples.
The 3-dimensional labels are one-hot vectors of the form:
Class A [0,0,1]
Class B [0,1,0]
Class C [1,0,0]
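Because each label row is one-hot, the class index is simply the column that holds the 1. A minimal sketch, using a small made-up label array (not rows from 3Class.mat), showing that `nonzero()[1]` and `argmax(axis=1)` give the same indices, so [0,0,1] maps to 2 and [1,0,0] maps to 0:

```python
import numpy as np

# Made-up one-hot label rows in the same layout as DataTrnLbls
labels = np.array([
    [0, 0, 1],  # 1 in column 2
    [0, 1, 0],  # 1 in column 1
    [1, 0, 0],  # 1 in column 0
    [0, 0, 1],  # 1 in column 2 again
])

# nonzero() returns (row_indices, column_indices); the column index is the class id
class_ids = labels.nonzero()[1]

# For strictly one-hot rows, argmax along each row gives the same result
assert np.array_equal(class_ids, labels.argmax(axis=1))
print(class_ids)  # [2 1 0 2]
```

Note that `nonzero()[1]` only lines up with the samples if every row has exactly one nonzero entry; `argmax(axis=1)` always returns one index per row.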

import scipy.io

oil_flow_dataset = scipy.io.loadmat("3Class.mat")
oil_flow_dataset.keys()
# dict_keys(['DataTrnLbls', 'DataVdnLbls', 'DataTrn', 'DataTstFrctns', 'DataTrnFrctns', 'DataTst', 'DataVdn', 'DataVdnFrctns', 'DataTstLbls'])
oil_flow_dataset["DataTrnLbls"].shape
# (1000, 3)
oil_flow_dataset["DataTrn"].shape
# (1000, 12)

Trn stands for training, Vdn for validation, and Tst for test.

What I did

Running the following code embeds the oil flow data in a two-dimensional latent space with GPLVM and plots the result.

import numpy as np
import GPy
import scipy.io
import matplotlib.pyplot as plt

oil_flow_dataset = scipy.io.loadmat("3Class.mat")
observed_data = oil_flow_dataset["DataTrn"]
# Standardize each feature: zero mean, unit standard deviation
normalized_observed_data = (observed_data - observed_data.mean(axis=0)) / observed_data.std(axis=0)
# Convert the one-hot labels to class indices 0, 1, 2
GT = oil_flow_dataset["DataTrnLbls"].nonzero()[1]
# Fit a GPLVM with a 2-dimensional latent space
model = GPy.models.GPLVM(normalized_observed_data, input_dim=2)
model.optimize(messages=True, max_iters=1000)
model.plot_latent(labels=GT)
plt.savefig("gplvm.png")
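The normalization step is meant to put all 12 features on a comparable scale. Z-score standardization divides the centered data by the per-column standard deviation (`std`), not the variance. A minimal sketch on a small random array standing in for the observation matrix; the helper name `standardize` is hypothetical, and the zero-std guard is an addition not present in the question's code:

```python
import numpy as np

def standardize(data):
    """Center each column and scale it to unit standard deviation (z-score)."""
    mean = data.mean(axis=0)
    std = data.std(axis=0)
    # Guard against constant columns, which would otherwise divide by zero
    std = np.where(std == 0, 1.0, std)
    return (data - mean) / std

# Random 4x3 array standing in for the 1000x12 observed_data
rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=3.0, size=(4, 3))
z = standardize(data)

print(np.allclose(z.mean(axis=0), 0.0))  # True
print(np.allclose(z.std(axis=0), 1.0))   # True
```

Dividing by the variance instead would leave each column with standard deviation 1/σ rather than 1, so features with small spread would be amplified rather than equalized.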

What I want to do is obtain the coordinates of each data point (each △ marker in the plot_latent figure) in this two-dimensional latent space. In other words, I want to plot the latent space myself with plain matplotlib instead of using plot_latent.

  • Answer #1

    Thanks for trying GPLVM.

    As for your question, the latent variables appear to be accessible through model.X.
    The following code produces a plot similar to plot_latent.

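    import numpy as np
    import matplotlib.pyplot as plt
    # Jupyter/IPython only; remove this line in a plain Python script
    %matplotlib inline

    # Class index (0, 1, 2) of each training point, from the one-hot labels
    spec = oil_flow_dataset["DataTrnLbls"].nonzero()[1]
    # model.X holds the optimized latent coordinates, shape (1000, 2)
    latent = np.asarray(model.X)
    for c in range(3):
        arg = np.where(spec == c)[0]
        plt.plot(latent[arg, 1], latent[arg, 0], '.', c='C' + str(c))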
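The loop above picks out the latent points of one class at a time. The same grouping can be factored into a small helper; `latent_points_by_class` is a hypothetical name, and the example feeds it a made-up 2-D array instead of a fitted model so it runs without GPy. With a real model you would pass the latent coordinates from `model.X` (assuming, as in the answer, that it holds the N×2 latent positions) together with the class ids from `nonzero()[1]`.

```python
import numpy as np

def latent_points_by_class(latent, class_ids):
    """Return {class_id: (n_c, d) array of the latent coordinates in that class}."""
    latent = np.asarray(latent)
    class_ids = np.asarray(class_ids)
    # Boolean mask per class keeps the rows belonging to that class
    return {int(c): latent[class_ids == c] for c in np.unique(class_ids)}

# Made-up latent coordinates (N=5, 2-D) and class ids standing in for model.X and GT
latent = np.array([
    [0.0, 1.0],
    [2.0, 3.0],
    [4.0, 5.0],
    [6.0, 7.0],
    [8.0, 9.0],
])
class_ids = np.array([2, 0, 2, 1, 0])

groups = latent_points_by_class(latent, class_ids)
for c, points in groups.items():
    # points[:, 0] and points[:, 1] are the x/y values you would pass to matplotlib
    print(c, points.tolist())
```

Each group can then be drawn with, for example, `plt.scatter(points[:, 0], points[:, 1], label=f"class {c}")` followed by `plt.legend()`. Note that the answer's loop puts latent dimension 1 on the x-axis and dimension 0 on the y-axis; either order works as long as the axes are labeled accordingly.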

