Skip to content

Commit 850bcc8

Browse files
committed
bugfix, requires GPflow>=2.1.4
1 parent 2ef1cd3 commit 850bcc8

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
requirements = [
99
"numpy",
1010
"scipy",
11-
"gpflow>=2.1.0",
11+
"gpflow>=2.1.4",
1212
"tensorflow>=2.0",
1313
]
1414

1515
setup(
1616
name="vbpp",
17-
version="0.0.2",
17+
version="0.1.0",
1818
author="ST John",
1919
# author_email="",
2020
description="Variational Bayes for Point Processes using GPflow",

vbpp/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(
141141
def _Psi_matrix(self):
142142
Ψ = tf_calc_Psi_matrix(self.kernel, self.inducing_variable, self.domain)
143143
psi_jitter_matrix = self.psi_jitter * tf.eye(
144-
len(self.inducing_variable), dtype=default_float()
144+
self.inducing_variable.num_inducing, dtype=default_float()
145145
)
146146
return Ψ + psi_jitter_matrix
147147

@@ -217,7 +217,7 @@ def _elbo_integral_term(self, Kuu):
217217
# int_var_fx = γ |T| + trace_terms
218218
# trace_terms = - Tr(Kzz⁻¹ Ψ) + Tr(Kzz⁻¹ S Kzz⁻¹ Ψ)
219219
trace_terms = tf.reduce_sum(
220-
(Rinv_L_LT_RinvT - tf.eye(len(self.inducing_variable), dtype=default_float()))
220+
(Rinv_L_LT_RinvT - tf.eye(self.inducing_variable.num_inducing, dtype=default_float()))
221221
* Rinv_Ψ_RinvT
222222
)
223223

0 commit comments

Comments
 (0)