This repository was archived by the owner on Jul 14, 2024. It is now read-only.

figure out why tfp MVN works but distrax does not #47

@murphyk

In https://github.com/probml/JSL/blob/main/jsl/demos/hmm_lillypad.py we build the HMM with a TFP observation distribution wrapped by distrax:

    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.as_distribution(
                  tfp.substrates.jax.distributions.MultivariateNormalFullCovariance(
                      loc=mu_collection, covariance_matrix=cov_collection)))

but it fails when I switch to the native distrax distribution:

    hmm = HMM(trans_dist=distrax.Categorical(probs=A),
              init_dist=distrax.Categorical(probs=initial_probs),
              obs_dist=distrax.MultivariateNormalFullCovariance(
                  loc=mu_collection, covariance_matrix=cov_collection))

Why?
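
One way to narrow this down might be to compare the shape metadata of the two observation distributions, since a mismatch in batch or event shape is a common reason for one wrapper to work where another does not. The sketch below is only a diagnostic aid, not a diagnosis: `summarize`, `diff_summaries`, and `StubDist` are hypothetical names introduced here, and the helper assumes only that each object exposes `batch_shape`, `event_shape`, and `log_prob`. `StubDist` is a stand-in so the snippet runs without distrax or TFP installed.

    ```python
    import numpy as np


    def summarize(dist, x):
        """Collect shape information for a distribution-like object.

        Assumes `dist` exposes `batch_shape`, `event_shape`, and `log_prob`.
        """
        return {
            "type": type(dist).__name__,
            "batch_shape": tuple(dist.batch_shape),
            "event_shape": tuple(dist.event_shape),
            "log_prob_shape": tuple(np.shape(dist.log_prob(x))),
        }


    def diff_summaries(a, b):
        """Return the fields (other than 'type') on which two summaries differ."""
        return {k: (a[k], b[k]) for k in a if k != "type" and a[k] != b[k]}


    class StubDist:
        """Hypothetical stand-in with a rank-1 event, used only for illustration."""

        def __init__(self, batch_shape, event_shape):
            self.batch_shape = batch_shape
            self.event_shape = event_shape

        def log_prob(self, x):
            # Broadcast the sample's leading dimensions against the batch shape.
            sample_shape = np.shape(x)[:-1]
            return np.zeros(np.broadcast_shapes(self.batch_shape, sample_shape))


    x = np.zeros(2)
    batched = summarize(StubDist((3,), (2,)), x)
    unbatched = summarize(StubDist((), (2,)), x)
    print(diff_summaries(batched, unbatched))
    ```

With the objects from this issue, one would call `summarize` on each of the two `obs_dist` constructions with the same observation `x` and inspect the output of `diff_summaries`. An empty result would suggest the difference lies elsewhere, for example in which methods the `HMM` class calls on `obs_dist`.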

Labels: bug (Something isn't working)
