Commit 1935f9b

Initialize startprob/transmat of VI HMMs with Dirichlet-sampled estimates, similar to the EM HMM. (hmmlearn#506)

* Initialize startprob_posterior and transmat_posterior of the Variational HMMs similarly to the EM models (using a Dirichlet distribution).
* Bump the number of random initializations to demonstrate learning the best model in the Variational Inference example.
* Reduce the variational Gaussian tests, with no loss of coverage.
* Add a note about the variational Gaussian test, and set the random seed.
* Improve readability with consistency in the test code.
1 parent 6f23f82 commit 1935f9b
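
The gist of the change: the Variational HMMs keep a uniform Dirichlet prior over the start and transition probabilities, but the posterior is now seeded from a Dirichlet draw (scaled to the same pseudo-count mass) instead of being a copy of the prior, so different random_state values start the optimization from different points, much like the EM models. Below is a minimal, self-contained NumPy sketch of that seeding; the values of nc and lengths are illustrative and not taken from the diff.

import numpy as np
from sklearn.utils import check_random_state

nc = 4               # number of hidden states (illustrative)
lengths = [500]      # one observed sequence of 500 samples (illustrative)
rs = check_random_state(0)

# Priors stay uniform, exactly as before.
startprob_prior = np.full(nc, 1 / nc)
transmat_prior = np.full((nc, nc), 1 / nc)

# Posteriors are seeded from Dirichlet draws instead of copies of the prior,
# so each random_state explores a different starting point while keeping the
# same total pseudo-count mass.
startprob_posterior = rs.dirichlet(np.full(nc, 1 / nc)) * len(lengths)
transmat_posterior = rs.dirichlet(np.full(nc, 1 / nc), size=nc)
transmat_posterior *= sum(lengths) / nc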

File tree

5 files changed: +39 −51 lines


examples/plot_variational_inference.py

Lines changed: 3 additions & 1 deletion
@@ -66,7 +66,9 @@ def gaussian_hinton_diagram(startprob, transmat, means,
 rs = check_random_state(2022)
 sample_length = 500
 num_samples = 1
-num_inits = 1
+# With random initialization, it takes a few tries to find the
+# best solution
+num_inits = 5
 num_states = np.arange(1, 7)
 verbose = False
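
The comment added above refers to the example's restart strategy: fit several randomly initialized models and keep the best one. A hedged sketch of such a loop follows; the function name is made up, and ranking fits by score() on the variational model is an assumption rather than code from this commit.

import numpy as np
from hmmlearn import vhmm

def fit_best_variational_hmm(X, lengths, n_components, num_inits=5, seed=2022):
    # Fit num_inits randomly initialized variational HMMs and keep the one
    # that scores best on the training data (illustrative sketch only).
    best_model, best_score = None, -np.inf
    for i in range(num_inits):
        model = vhmm.VariationalGaussianHMM(
            n_components, n_iter=500, random_state=seed + i)
        model.fit(X, lengths)
        score = model.score(X, lengths)
        if score > best_score:
            best_model, best_score = model, score
    return best_model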

lib/hmmlearn/base.py

Lines changed: 10 additions & 7 deletions
@@ -1043,17 +1043,19 @@ def _init(self, X, lengths=None):
             these should be ``n_samples``.
         """
         self._check_and_set_n_features(X)
-        uniform_prior = 1 / self.n_components
-        # We could consider random initialization here as well
+        nc = self.n_components
+        uniform_prior = 1 / nc
+        random_state = check_random_state(self.random_state)
         if (self._needs_init("s", "startprob_posterior_")
                 or self._needs_init("s", "startprob_prior_")):
             if self.startprob_prior is None:
                 startprob_init = uniform_prior
             else:
                 startprob_init = self.startprob_prior

-            self.startprob_prior_ = np.full(self.n_components, startprob_init)
-            self.startprob_posterior_ = self.startprob_prior_ * len(lengths)
+            self.startprob_prior_ = np.full(nc, startprob_init)
+            self.startprob_posterior_ = random_state.dirichlet(
+                np.full(nc, uniform_prior)) * len(lengths)

         if (self._needs_init("t", "transmat_posterior_")
                 or self._needs_init("t", "transmat_prior_")):
@@ -1062,9 +1064,10 @@ def _init(self, X, lengths=None):
             else:
                 transmat_init = self.transmat_prior
             self.transmat_prior_ = np.full(
-                (self.n_components, self.n_components), transmat_init)
-            self.transmat_posterior_ = (
-                self.transmat_prior_ * sum(lengths) / self.n_components)
+                (nc, nc), transmat_init)
+            self.transmat_posterior_ = random_state.dirichlet(
+                np.full(nc, uniform_prior), size=nc)
+            self.transmat_posterior_ *= sum(lengths) / nc

         n_fit_scalars_per_param = self._get_n_fit_scalars_per_param()
         if n_fit_scalars_per_param is not None:
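
A quick way to convince yourself that the new initialization only randomizes how the posterior mass is spread, not how much of it there is: the scaled Dirichlet draws sum to the same totals the old uniform scheme used. A standalone check with made-up values for nc and lengths, not part of the patch:

import numpy as np
from sklearn.utils import check_random_state

nc, lengths = 4, [100, 100, 100]   # illustrative values
rs = check_random_state(42)

startprob_posterior = rs.dirichlet(np.full(nc, 1 / nc)) * len(lengths)
# A Dirichlet draw sums to 1, so the posterior still sums to len(lengths),
# exactly like the old uniform initialization.
assert np.isclose(startprob_posterior.sum(), len(lengths))

transmat_posterior = rs.dirichlet(np.full(nc, 1 / nc), size=nc)
transmat_posterior *= sum(lengths) / nc
# Each row is its own Dirichlet draw, so every row still carries
# sum(lengths) / nc pseudo-counts; only their spread is random.
assert np.allclose(transmat_posterior.sum(axis=1), sum(lengths) / nc)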

lib/hmmlearn/tests/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -85,3 +85,12 @@ def compare_variational_and_em_models(variational, em, sequences, lengths):
     vi_obs, vi_states = variational.sample(100, random_state=42)
     assert np.all(em_obs == vi_obs)
     assert np.all(em_states == vi_states)
+
+
+def vi_uniform_startprob_and_transmat(model, lengths):
+    nc = model.n_components
+    model.startprob_prior_ = np.full(nc, 1/nc)
+    model.startprob_posterior_ = np.full(nc, 1/nc) * len(lengths)
+    model.transmat_prior_ = np.full((nc, nc), 1/nc)
+    model.transmat_posterior_ = np.full((nc, nc), 1/nc) * sum(lengths)
+    return model

lib/hmmlearn/tests/test_variational_categorical.py

Lines changed: 4 additions & 2 deletions
@@ -4,7 +4,8 @@

 from hmmlearn import hmm, vhmm
 from . import (
-    assert_log_likelihood_increasing, compare_variational_and_em_models)
+    assert_log_likelihood_increasing, compare_variational_and_em_models,
+    vi_uniform_startprob_and_transmat)


 class TestVariationalCategorical:
@@ -218,8 +219,9 @@ def test_fit_and_compare_with_em(self, implementation):
         sequences, lengths = self.get_from_one_beal(7, 100, 1984)
         model = vhmm.VariationalCategoricalHMM(
             4, n_iter=500, random_state=1984,
+            init_params="e",
             implementation=implementation)
-
+        vi_uniform_startprob_and_transmat(model, lengths)
         model.fit(sequences, lengths)

         # The 1st hidden state will be "unused"

lib/hmmlearn/tests/test_variational_gaussian.py

Lines changed: 13 additions & 41 deletions
@@ -5,7 +5,7 @@
 from hmmlearn import hmm, vhmm
 from . import (
     assert_log_likelihood_increasing, compare_variational_and_em_models,
-    make_covar_matrix, normalized)
+    make_covar_matrix, normalized, vi_uniform_startprob_and_transmat)


 def get_mcgrory_titterington():
@@ -21,20 +21,6 @@ def get_mcgrory_titterington():
     return m1


-def get_mcgrory_titterington2d():
-    """ A subtle variation on the 1D Case..."""
-    m1 = hmm.GaussianHMM(4, init_params="", covariance_type="tied")
-    m1.n_features = 4
-    m1.startprob_ = np.array([1/4., 1/4., 1/4., 1/4.])
-    m1.transmat_ = np.array([[0.2, 0.2, 0.3, 0.3],
-                             [0.3, 0.2, 0.2, 0.3],
-                             [0.2, 0.3, 0.3, 0.2],
-                             [0.3, 0.3, 0.2, 0.2]])
-    m1.means_ = np.array([[-1.5, -1.5], [0, 0], [1.5, 1.5], [3., 3]])
-    m1.covars_ = np.sqrt([[0.25, 0], [0, .25]])
-    return m1
-
-
 def get_sequences(length, N, model, rs=None):
     sequences = []
     lengths = []
@@ -55,34 +41,41 @@ def test_random_fit(self, implementation, params='stmc', n_features=3,
                         n_components=3, **kwargs):
         h = hmm.GaussianHMM(n_components, self.covariance_type,
                             implementation=implementation, init_params="")
-        rs = check_random_state(None)
+        rs = check_random_state(1)
         h.startprob_ = normalized(rs.rand(n_components))
         h.transmat_ = normalized(
             rs.rand(n_components, n_components), axis=1)
         h.means_ = rs.randint(-20, 20, (n_components, n_features))
         h.covars_ = make_covar_matrix(
             self.covariance_type, n_components, n_features, random_state=rs)
-
-        lengths = [200] * 20
+        lengths = [200] * 5
         X, _state_sequence = h.sample(sum(lengths), random_state=rs)
         # Now learn a model
         model = vhmm.VariationalGaussianHMM(
-            n_components, n_iter=1000, tol=1e-9, random_state=rs,
+            n_components, n_iter=50, tol=1e-9, random_state=rs,
            covariance_type=self.covariance_type,
            implementation=implementation)
-        assert_log_likelihood_increasing(model, X, lengths, n_iter=100)
+
+        # Depending on the random seed, the model may converge rather quickly,
+        # and throw an assertion in this test, as the function we call
+        # computes each iteration independently by calling fit() `n_iter`
+        # times.
+        assert_log_likelihood_increasing(model, X, lengths, n_iter=10)

     @pytest.mark.parametrize("implementation", ["scaling", "log"])
     def test_fit_mcgrory_titterington1d(self, implementation):
         random_state = check_random_state(234234)
+        # Setup to assure convergence

         sequences, lengths = get_sequences(500, 1,
                                            model=get_mcgrory_titterington(),
                                            rs=random_state)
         model = vhmm.VariationalGaussianHMM(
             5, n_iter=1000, tol=1e-9, random_state=random_state,
+            init_params="mc",
             covariance_type=self.covariance_type,
             implementation=implementation)
+        vi_uniform_startprob_and_transmat(model, lengths)
         model.fit(sequences, lengths)
         # Perform one check that we are converging to the right answer
         assert (model.means_posterior_[-1][0]
@@ -101,27 +94,6 @@ def test_fit_mcgrory_titterington1d(self, implementation):

         compare_variational_and_em_models(model, em_hmm, sequences, lengths)

-    @pytest.mark.parametrize("implementation", ["scaling", "log"])
-    def test_fit_mcgrory_titterington2d(self, implementation):
-        sequences, lengths = get_sequences(100, 1,
-                                           model=get_mcgrory_titterington2d())
-
-        model = vhmm.VariationalGaussianHMM(
-            5, n_iter=1000, tol=1e-9, random_state=None,
-            covariance_type=self.covariance_type,
-            implementation=implementation)
-        model.fit(sequences, lengths)
-
-        em_hmm = hmm.GaussianHMM(n_components=model.n_components,
-                                 implementation=implementation,
-                                 covariance_type=self.covariance_type)
-        em_hmm.startprob_ = model.startprob_
-        em_hmm.transmat_ = model.transmat_
-        em_hmm.means_ = model.means_posterior_
-        em_hmm.covars_ = model._covars_
-
-        compare_variational_and_em_models(model, em_hmm, sequences, lengths)
-
     @pytest.mark.parametrize("implementation", ["scaling", "log"])
     def test_common_initialization(self, implementation):
         sequences, lengths = get_sequences(50, 10,
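
The comment added to test_random_fit above explains why its iteration budget was reduced: the helper it calls performs the iterations one fit() call at a time and expects the log likelihood to keep increasing, which becomes brittle once the model has already converged. Below is a loose, standalone analogue of that monotonicity idea, not the hmmlearn test helper itself; the synthetic data, model settings, and tolerance are made up, and score() on the variational model is assumed to be available.

import numpy as np
from hmmlearn import vhmm

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-2, 1, (100, 1)), rng.normal(2, 1, (100, 1))])
lengths = [200]

scores = []
for n_iter in range(1, 6):
    # Refit from scratch with a growing iteration budget.
    model = vhmm.VariationalGaussianHMM(2, n_iter=n_iter, random_state=7)
    model.fit(X, lengths)
    scores.append(model.score(X, lengths))

# Require the training score to be (weakly) increasing.  With some seeds this
# can trip even though the fit is fine, which is exactly the brittleness the
# updated test works around by keeping n_iter small.
assert np.all(np.diff(scores) >= -1e-6)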
