Commit af69d42

Re-implement GMMHMM based upon:
Watanabe, Shinji, and Jen-Tzung Chien. Bayesian Speech and Language Processing. Cambridge University Press, 2015. This approach appears simpler than the current implementation and will make the VariationalGMMHMM easier to implement later on.
1 parent e01a10e commit af69d42

3 files changed, +109 -86 lines changed
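
Read as equations, the weight and mean updates implemented by the diff below are (a sketch in ad-hoc notation rather than the book's; gamma_{tjk} is the posterior responsibility of state j, mixture k at time t, i.e. post_comp_mix in the code, alpha is weights_prior, lambda is means_weight and m is means_prior):

    \hat{w}_{jk} = \frac{\sum_t \gamma_{tjk} + (\alpha_{jk} - 1)}
                        {\sum_{k'} \bigl( \sum_t \gamma_{tjk'} + (\alpha_{jk'} - 1) \bigr)},
    \qquad
    \hat{\mu}_{jk} = \frac{\sum_t \gamma_{tjk}\, x_t + \lambda_{jk}\, m_{jk}}
                          {\sum_t \gamma_{tjk} + \lambda_{jk}}

With the MLE defaults (alpha = 1, lambda = 0) these reduce to the usual EM updates; the covariance update is handled analogously in _do_mstep.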

src/hmmlearn/_emissions.py

Lines changed: 19 additions & 30 deletions
@@ -134,10 +134,11 @@ def _initialize_sufficient_statistics(self):
         stats = super()._initialize_sufficient_statistics()
         stats['post'] = np.zeros(self.n_components)
         stats['obs'] = np.zeros((self.n_components, self.n_features))
-        stats['obs**2'] = np.zeros((self.n_components, self.n_features))
         if self.covariance_type in ('tied', 'full'):
             stats['obs*obs.T'] = np.zeros((self.n_components, self.n_features,
                                            self.n_features))
+        elif self.covariance_type in ('diag', 'spherical'):
+            stats['obs**2'] = np.zeros((self.n_components, self.n_features))
         return stats
 
     def _accumulate_sufficient_statistics(
@@ -181,7 +182,7 @@ def _generate_sample_from_state(self, state, random_state):
         )
 
 
-class BaseGMMHMM(BaseHMM):
+class BaseGMMHMM(_AbstractHMM):
 
     def _get_n_fit_scalars_per_param(self):
         nc = self.n_components
@@ -222,11 +223,9 @@ def _compute_log_likelihood(self, X):
     def _initialize_sufficient_statistics(self):
         stats = super()._initialize_sufficient_statistics()
         stats['post_mix_sum'] = np.zeros((self.n_components, self.n_mix))
-        stats['post_sum'] = np.zeros(self.n_components)
-
         if 'm' in self.params:
-            lambdas, mus = self.means_weight, self.means_prior
-            stats['m_n'] = lambdas[:, :, None] * mus
+            stats['m_n'] = np.zeros(
+                (self.n_components, self.n_mix, self.n_features))
         if 'c' in self.params:
             stats['c_n'] = np.zeros_like(self.covars_)
 
@@ -254,7 +253,7 @@ def _accumulate_sufficient_statistics(self, stats, X, lattice,
 
         post_mix = np.zeros((n_samples, self.n_components, self.n_mix))
         for p in range(self.n_components):
-            log_denses = self._compute_log_weighted_gaussian_densities(X, p)
+            log_denses = self._log_density_for_sufficient_statistics(X, p)
             log_normalize(log_denses, axis=-1)
             with np.errstate(under="ignore"):
                 post_mix[:, p, :] = np.exp(log_denses)
@@ -263,33 +262,23 @@ def _accumulate_sufficient_statistics(self, stats, X, lattice,
         post_comp_mix = post_comp[:, :, None] * post_mix
 
         stats['post_mix_sum'] += post_comp_mix.sum(axis=0)
-        stats['post_sum'] += post_comp.sum(axis=0)
-
         if 'm' in self.params:  # means stats
             stats['m_n'] += np.einsum('ijk,il->jkl', post_comp_mix, X)
 
         if 'c' in self.params:  # covariance stats
-            centered = X[:, None, None, :] - self.means_
-
-            def outer_f(x):  # Outer product over features.
-                return x[..., :, None] * x[..., None, :]
-
-            if self.covariance_type == 'full':
-                centered_dots = outer_f(centered)
-                c_n = np.einsum('ijk,ijklm->jklm', post_comp_mix,
-                                centered_dots)
-            elif self.covariance_type == 'diag':
-                centered2 = np.square(centered, out=centered)  # reuse
-                c_n = np.einsum('ijk,ijkl->jkl', post_comp_mix, centered2)
-            elif self.covariance_type == 'spherical':
-                # Faster than (x**2).sum(-1).
-                centered_norm2 = np.einsum('...i,...i', centered, centered)
-                c_n = np.einsum('ijk,ijk->jk', post_comp_mix, centered_norm2)
-            elif self.covariance_type == 'tied':
-                centered_dots = outer_f(centered)
-                c_n = np.einsum('ijk,ijklm->jlm', post_comp_mix, centered_dots)
-
-            stats['c_n'] += c_n
+            if self.covariance_type == "full":
+                stats['c_n'] += np.einsum(
+                    'ijk,il,im->jklm', post_comp_mix, X, X)
+            elif self.covariance_type == "tied":
+                stats['c_n'] += np.einsum(
+                    'ijk,il,im->jlm', post_comp_mix, X, X)
+            elif self.covariance_type == "diag":
+                stats['c_n'] += np.einsum(
+                    'ijk,il->jkl', post_comp_mix, X**2)
+            elif self.covariance_type == "spherical":
+                stats['c_n'] += np.einsum(
+                    'ijk,il->jk', post_comp_mix, X**2)
+
 
     def _generate_sample_from_state(self, state, random_state):
         cur_weights = self.weights_[state]
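
For reference, a minimal NumPy sketch (not part of the commit; all names and data below are synthetic) showing that the uncentered statistic accumulated above, together with the responsibilities and the first-order sums, determines the old centered statistic, which is why the centering can be deferred to the M-step:

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_components, n_mix, n_features = 50, 3, 2, 4
post_comp_mix = rng.random((n_samples, n_components, n_mix))
X = rng.normal(size=(n_samples, n_features))
means = rng.normal(size=(n_components, n_mix, n_features))

# New accumulation ('full'): sum_t gamma[t, c, k] * x_t x_t^T.
c_n_new = np.einsum('ijk,il,im->jklm', post_comp_mix, X, X)

# Old accumulation: sum_t gamma[t, c, k] * (x_t - mu_ck)(x_t - mu_ck)^T.
centered = X[:, None, None, :] - means
c_n_old = np.einsum('ijk,ijkl,ijkm->jklm', post_comp_mix, centered, centered)

# Re-centering the new statistic around the same means recovers the old one.
post = post_comp_mix.sum(axis=0)                   # sum_t gamma
obs = np.einsum('ijk,il->jkl', post_comp_mix, X)   # sum_t gamma * x_t
recentered = (c_n_new
              - np.einsum('jkl,jkm->jklm', obs, means)
              - np.einsum('jkl,jkm->jkml', obs, means)
              + np.einsum('jk,jkl,jkm->jklm', post, means, means))
assert np.allclose(recentered, c_n_old)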

src/hmmlearn/hmm.py

Lines changed: 88 additions & 54 deletions
@@ -387,10 +387,23 @@ def _do_mstep(self, stats):
                         (cvweight + stats['post'][:, None, None]))
 
 
-class GMMHMM(_emissions.BaseGMMHMM):
+class GMMHMM(_emissions.BaseGMMHMM, BaseHMM):
     """
     Hidden Markov Model with Gaussian mixture emissions.
 
+    Note:
+        The implementation supports both Maximum Likelihood Estimation (MLE)
+        and Maximum a-posteriori (MAP) approximation. By default, the various
+        priors are configured such that the MLE is learned. To configure the
+        model to make MAP estimation, set the various priors to 0.
+
+        This implementation is based upon:
+        Watanabe, Shinji, and Jen-Tzung Chien. Bayesian Speech and Language
+        Processing. Cambridge University Press, 2015.
+
+    TODO:
+        Sources for MAP priors for spherical and tied covariance
+
     Attributes
     ----------
     monitor_ : ConvergenceMonitor
@@ -417,6 +430,7 @@ class GMMHMM(_emissions.BaseGMMHMM):
         * (n_components, n_mix, n_features) if "diag",
         * (n_components, n_mix, n_features, n_features) if "full"
         * (n_components, n_features, n_features) if "tied".
+
     """
 
     def __init__(self, n_components=1, n_mix=1,
@@ -592,25 +606,29 @@ def compute_cv():
 
     def _init_covar_priors(self):
         if self.covariance_type == "full":
+            # Page 157 of Bayesian Speech and Language Processing
             if self.covars_prior is None:
                 self.covars_prior = 0.0
             if self.covars_weight is None:
-                self.covars_weight = -(1.0 + self.n_features + 1.0)
+                self.covars_weight = (1.0 + self.n_features)
         elif self.covariance_type == "tied":
+            # TODO - Source for these
             if self.covars_prior is None:
                 self.covars_prior = 0.0
             if self.covars_weight is None:
                 self.covars_weight = -(self.n_mix + self.n_features + 1.0)
         elif self.covariance_type == "diag":
+            # Page 158 of Bayesian Speech and Language Processing
             if self.covars_prior is None:
-                self.covars_prior = -1.5
+                self.covars_prior = 0
             if self.covars_weight is None:
-                self.covars_weight = 0.0
+                self.covars_weight = 2
         elif self.covariance_type == "spherical":
+            # TODO - Source for these
             if self.covars_prior is None:
-                self.covars_prior = -(self.n_mix + 2.0) / 2.0
+                self.covars_prior = 0
             if self.covars_weight is None:
-                self.covars_weight = 0.0
+                self.covars_weight = -(self.n_mix + 2.0) / 2.0
 
     def _fix_priors_shape(self):
         nc = self.n_components
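
A hypothetical usage sketch: the priors chosen by _init_covar_priors are only filled in when the constructor arguments are left as None, so they can be overridden per covariance type. The numeric values below are illustrative, not recommendations.

from hmmlearn.hmm import GMMHMM

model = GMMHMM(n_components=3, n_mix=2, covariance_type="diag",
               covars_prior=1e-2, covars_weight=3.0, random_state=0)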
@@ -731,6 +749,9 @@ def _check(self):
                     _log.warning("Covariance of state #%d, mixture #%d "
                                  "has a null eigenvalue.", i, j)
 
+    def _log_density_for_sufficient_statistics(self, X, component):
+        return self._compute_log_weighted_gaussian_densities(X, component)
+
     def _do_mstep(self, stats):
         super()._do_mstep(stats)
         nf = self.n_features
@@ -740,12 +761,16 @@ def _do_mstep(self, stats):
         if 'w' in self.params:
             alphas_minus_one = self.weights_prior - 1
             w_n = stats['post_mix_sum'] + alphas_minus_one
-            w_d = (stats['post_sum'] + alphas_minus_one.sum(axis=1))[:, None]
+            w_d = w_n.sum(axis=-1)[:, None]
             self.weights_ = w_n / w_d
 
         # Maximizing means
         if 'm' in self.params:
-            m_n = stats['m_n']
+            m_n = stats['m_n'] + np.einsum(
+                "cm,cmi->cmi",
+                self.means_weight,
+                self.means_prior
+            )
             m_d = stats['post_mix_sum'] + self.means_weight
             # If a component has zero weight, then replace nan (0/0?) means
             # by 0 (0/1). The actual value is irrelevant as the component will
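
A small check of the new denominator (synthetic numbers, nothing here is library code; it assumes the default weights_prior of 1, so alphas_minus_one is zero): because w_d is simply the row sum of w_n, each state's mixture weights come out exactly normalized.

import numpy as np

rng = np.random.default_rng(1)
post_mix_sum = rng.random((3, 2)) * 10     # stands in for stats['post_mix_sum']
alphas_minus_one = np.zeros((3, 2))        # weights_prior - 1 with the default prior
w_n = post_mix_sum + alphas_minus_one
w_d = w_n.sum(axis=-1)[:, None]
weights = w_n / w_d
assert np.allclose(weights.sum(axis=-1), 1.0)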
@@ -757,57 +782,66 @@ def _do_mstep(self, stats):
 
         # Maximizing covariances
         if 'c' in self.params:
-            lambdas, mus = self.means_weight, self.means_prior
-            centered_means = self.means_ - mus
-
-            def outer_f(x):  # Outer product over features.
-                return x[..., :, None] * x[..., None, :]
-
             if self.covariance_type == 'full':
-                centered_means_dots = outer_f(centered_means)
-
-                psis_t = np.transpose(self.covars_prior, axes=(0, 1, 3, 2))
-                nus = self.covars_weight
-
-                c_n = psis_t + lambdas[:, :, None, None] * centered_means_dots
-                c_n += stats['c_n']
-                c_d = (
-                    stats['post_mix_sum'] + 1 + nus + nf + 1
-                )[:, :, None, None]
-
+                # Pages 156-157 of Bayesian Speech and Language Processing
+                c_n = (self.covars_prior
+                       + stats['c_n']
+                       + np.einsum("ck,cki,ckj->ckij",
+                                   self.means_weight,
+                                   self.means_prior,
+                                   self.means_prior)
+                       - np.einsum("ck,cki,ckj->ckij",
+                                   stats['post_mix_sum'] + self.means_weight,
+                                   self.means_,
+                                   self.means_))
+                # Note that when self.covars_weight = 0
+                # and c_d <= 0, then we will have a failure. This is discussed
+                # on page 156 of the above book.
+                c_d = stats['post_mix_sum'] + self.covars_weight
+                c_d -= self.n_features - 1
+                c_d = c_d[:, :, None, None]
+            elif self.covariance_type == 'tied':
+                # inferred from 'full'
+                c_n = (self.covars_prior
+                       + stats['c_n']
+                       + np.einsum("ck,cki,ckj->cij",
+                                   self.means_weight,
+                                   self.means_prior,
+                                   self.means_prior)
+                       - np.einsum("ck,cki,ckj->cij",
+                                   stats['post_mix_sum'] + self.means_weight,
+                                   self.means_,
+                                   self.means_))
+                c_d = stats['post_mix_sum'].sum(axis=-1) + self.covars_weight
+                c_d += (nm + nf + 1.0)
+                c_d = c_d[:, None, None]
             elif self.covariance_type == 'diag':
-                alphas = self.covars_prior
-                betas = self.covars_weight
-                centered_means2 = centered_means ** 2
-
-                c_n = lambdas[:, :, None] * centered_means2 + 2 * betas
-                c_n += stats['c_n']
-                c_d = stats['post_mix_sum'][:, :, None] + 1 + 2 * (alphas + 1)
-
+                # Pages 157-158 of Bayesian Speech and Language Processing
+                c_n = (self.covars_prior
+                       + stats['c_n']
+                       + np.einsum("ck,cki->cki",
+                                   self.means_weight,
+                                   self.means_prior**2)
+                       - np.einsum("ck,cki->cki",
+                                   stats['post_mix_sum'] + self.means_weight,
+                                   self.means_**2))
+                c_d = (stats['post_mix_sum'][:, :, None]
+                       + self.covars_weight
+                       - 2)
             elif self.covariance_type == 'spherical':
-                centered_means_norm2 = np.einsum(  # Faster than (x**2).sum(-1)
-                    '...i,...i', centered_means, centered_means)
-
-                alphas = self.covars_prior
-                betas = self.covars_weight
-
-                c_n = lambdas * centered_means_norm2 + 2 * betas
-                c_n += stats['c_n']
-                c_d = nf * (stats['post_mix_sum'] + 1) + 2 * (alphas + 1)
-
-            elif self.covariance_type == 'tied':
-                centered_means_dots = outer_f(centered_means)
-
-                psis_t = np.transpose(self.covars_prior, axes=(0, 2, 1))
-                nus = self.covars_weight
-
-                c_n = np.einsum('ij,ijkl->ikl',
-                                lambdas, centered_means_dots) + psis_t
-                c_n += stats['c_n']
-                c_d = (stats['post_sum'] + nm + nus + nf + 1)[:, None, None]
+                # inferred from 'diag'
+                c_n = (self.covars_prior
+                       + stats['c_n']
+                       + np.einsum("ck,cki->ck",
+                                   self.means_weight,
+                                   self.means_prior**2)
+                       - np.einsum("ck,cki->ck",
+                                   stats['post_mix_sum'] + self.means_weight,
+                                   self.means_**2)) / nf
+                c_d = stats['post_mix_sum'] + self.covars_weight + (nm + 2)/2
 
             self.covars_ = c_n / c_d
-
+            assert not np.isnan(self.covars_).any(), self.covars_
 
 class MultinomialHMM(_emissions.BaseMultinomialHMM):
     """

src/hmmlearn/tests/test_gmm_hmm_new.py

Lines changed: 2 additions & 2 deletions
@@ -260,6 +260,6 @@ def test_chunked(sellf, covtype, init_params='mcw'):
         model2.fit(data, lengths=[200] * 5)
 
         assert_allclose(model1.means_, model2.means_, rtol=0, atol=1e-2)
-        assert_allclose(model1.covars_, model2.covars_, rtol=0, atol=1e-3)
-        assert_allclose(model1.weights_, model2.weights_, rtol=0, atol=1e-3)
+        assert_allclose(model1.covars_, model2.covars_, rtol=0, atol=1e-2)
+        assert_allclose(model1.weights_, model2.weights_, rtol=0, atol=1e-2)
         assert_allclose(model1.transmat_, model2.transmat_, rtol=0, atol=1e-2)
