From d232d113fcfe3c520b49eab9fb12ea9a97f9c05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=89=9F=E8=89=AF=E5=A8=81?= <87120182+whiteode@users.noreply.github.com> Date: Fri, 26 Sep 2025 11:34:33 +0800 Subject: [PATCH] Implement nn_kernel function and visualize samples Added Python code for neural network kernel and Gaussian process prior samples visualization. --- chapter_gaussian-processes/gp-priors.md | 28 +++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/chapter_gaussian-processes/gp-priors.md b/chapter_gaussian-processes/gp-priors.md index f37b23a2c0..d95aa58e0c 100644 --- a/chapter_gaussian-processes/gp-priors.md +++ b/chapter_gaussian-processes/gp-priors.md @@ -148,6 +148,34 @@ In some cases, we can essentially evaluate this covariance function in closed fo The RBF kernel is _stationary_, meaning that it is _translation invariant_, and therefore can be written as a function of $\tau = x-x'$. Intuitively, stationarity means that the high-level properties of the function, such as rate of variation, do not change as we move in input space. The neural network kernel, however, is _non-stationary_. Below, we show sample functions from a Gaussian process with this kernel. We can see that the function looks qualitatively different near the origin. +```{.python .input} +def nn_kernel(x1, x2): + x1 = x1.flatten() + x2 = x2.flatten() + N, M = len(x1), len(x2) + cov_matrix = np.zeros((N, M)) + for i in range(N): + for j in range(M): + tilde_x_i = np.array([1, x1[i]]) + tilde_x_j = np.array([1, x2[j]]) + numerator = 2 * np.dot(tilde_x_i, tilde_x_j) + term_i = 1 + 2 * np.dot(tilde_x_i, tilde_x_i) + term_j = 1 + 2 * np.dot(tilde_x_j, tilde_x_j) + arg = np.clip(numerator / np.sqrt(term_i * term_j), -1.0, 1.0) + cov_matrix[i, j] = (2 / np.pi) * np.arcsin(arg) + return cov_matrix + +x_points = np.linspace(-5, 5, 100) +meanvec = np.zeros(len(x_points)) +covmat = nn_kernel(x_points, x_points) +prior_samples = np.random.multivariate_normal(meanvec, covmat, size=5) + +d2l.plt.plot(x_points, prior_samples.T, alpha=0.7) +d2l.plt.show() + +``` + + ## Summary