This repository was archived by the owner on Jun 14, 2024. It is now read-only.
36 changes: 36 additions & 0 deletions docs/design_proposals/design_2_mcmc.md
@@ -0,0 +1,36 @@
# Possible design extensions to support MCMC sampling algorithms

Eric Meissner (2018-11-12)

## Motivation

To support the first MCMC (Markov chain Monte Carlo) inference algorithm in MXFusion, a few issues need to be resolved. These stem primarily from the Markovian part of MCMC, which requires a one-step time-series-style approach to generate proposal samples from the previous samples. This document discusses designs for those issues that are extensible to future MCMC algorithms, but don't necessarily solve the complete set of time-series modelling issues. I will take the Metropolis-Hastings algorithm as the standard in this document, as it is the first MCMC inference we'll be implementing and it is more general (and more difficult to implement in MXFusion) than, say, Gibbs sampling or Hamiltonian Monte Carlo.

### Problem Setup
The two primary design choices to be made are:
1. Where are MCMC samples stored / transferred to future inference algorithms? What is the interface for this, and how are they initialized?
2. How do we take the most recent sample in the chain, ```z_t```, and use it to generate the next sample, ```z_t+1```?
* This is non-trivial in MXFusion because it builds a circular reference into the proposal's FactorGraph. ```z_t+1``` depends on ```z_t```, but in a Model (without time-series support), saying a Variable depends on itself, ```m.z = Distribution(inputs=[m.z, ...])```, doesn't work right now.
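The self-reference issue can be sketched outside MXFusion: the chain works fine once the dependence of ```z_t+1``` on ```z_t``` is routed through an ordinary input rather than a variable that depends on itself. This is plain Python, not MXFusion API; ```propose``` and its variance value are illustrative only.

```python
import random

# A Gaussian random-walk proposal where z_{t+1} depends on the *previous*
# sample through a separate function input, not a self-referencing variable.
def propose(z_prev, variance=1e-2):
    """Draw z_{t+1} ~ Normal(z_prev, variance)."""
    return random.gauss(z_prev, variance ** 0.5)

z = 0.0
chain = [z]
for _ in range(5):
    z = propose(z)  # the old sample feeds in as an ordinary input
    chain.append(z)
```

This is the same cut the mapping variables make in the FactorGraph: each step sees the previous sample as data, never as part of its own graph.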


## Proposed Changes


For problem (1), I propose we extend the InferenceParameters class (either as a subclass or a flag during initialization controlled by the Inference method) to include storing parameters for latent variables (LV), which is what the output samples of an MCMC algorithm are. Throughout the Inference method, these LV parameters would keep an up-to-date list of samples generated. In this solution, MCMC samples are easily serialized with no extra effort, and the existing TransferInference class can be extended readily to support reuse of these samples in later algorithms.
* Downside: this doesn't support multi-chain generation and storage out of the box without storing the parameters somewhere outside the Inference loop, but that could easily be added in the future as needed. Another downside is that it complicates InferenceParameters somewhat, though either a subclass or a flag that changes its behavior to include LV generation would keep that complexity contained within InferenceParameters.
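A minimal stand-in for the extended InferenceParameters might look as follows. This is a hypothetical plain-Python sketch, not the MXFusion class; the names ```is_mcmc``` and ```store_sample``` are illustrative, not a committed API.

```python
# Hypothetical sketch: an InferenceParameters-like container that, when
# flagged for MCMC, also keeps an up-to-date record of LV samples.
class InferenceParametersSketch:
    def __init__(self, is_mcmc=False):
        self.is_mcmc = is_mcmc
        self._params = {}   # ordinary parameter storage (uuid -> value)
        self.samples = {}   # uuid -> list of LV samples (MCMC only)

    def store_sample(self, uuid, value):
        if not self.is_mcmc:
            raise ValueError("sample storage is only enabled for MCMC")
        self.samples.setdefault(uuid, []).append(value)

params = InferenceParametersSketch(is_mcmc=True)
for t in range(3):
    params.store_sample("z", float(t))
# params.samples now holds the chain for "z" and can be serialized
# alongside the rest of the parameters.
```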

For problem (2), I propose:
* Making InferenceAlgorithm.compute() a method that takes in the variables for timestep ```t``` (LVs ```theta_t``` and parameters) and outputs the LV samples for the next timestep in the chain, ```theta_t+1```.
* The Inference class handles initializing parameters and LVs for ```t=0```, calling the InferenceAlgorithm.compute() in a loop for the number of samples requested, and managing the correct timestep of LV sample values that are passed into InferenceAlgorithm.compute().
* Treat the proposal distribution as a Posterior FactorGraph, leveraging the standard draw_samples and log_pdf methods.
* Introducing mapping variables into this proposal distribution that cut the connections needed to go from ```theta_t``` to ```theta_t+1```.
* The transfer of values one timestep to the next for LVs (i.e. to generate a sample at timestep ```3```, it needs as input the sample from timestep ```2``` which was the output at that timestep) is handled in the Inference class.
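Taken together, the loop the Inference class would run reduces to the following plain-Python Metropolis-Hastings sketch (standard-normal target, symmetric random-walk proposal). None of this is MXFusion API; ```compute``` here merely stands in for the role InferenceAlgorithm.compute() would play.

```python
import math
import random

def log_pdf(z):
    return -0.5 * z * z  # unnormalized standard-normal log-density

def compute(z_t, variance=1.0):
    """One MH transition from z_t to z_{t+1}: propose, then accept/reject."""
    z_prop = random.gauss(z_t, variance ** 0.5)
    # The proposal is symmetric, so its terms cancel in the acceptance ratio.
    log_alpha = log_pdf(z_prop) - log_pdf(z_t)
    if random.random() < math.exp(min(0.0, log_alpha)):
        return z_prop, True    # proposal accepted
    return z_t, False          # chain stays at the current state

random.seed(0)
z, samples, accepted = 0.0, [], 0
for t in range(1000):              # the Inference class owns this loop
    z, was_accepted = compute(z)   # feed step t's output back in at t+1
    samples.append(z)
    accepted += was_accepted
```

The outer loop is exactly the value-transfer responsibility described above: each call receives the previous step's output as its input.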

A working proof of concept is attached in the original pull request for this. It does not have a correctly extended InferenceParameters class, nor does it store the sample values in InferenceParameters directly. It successfully trained the Getting Started tutorial and the PPCA tutorial after adding a prior to ```m.w``` and using ```variance=1e-4``` for the proposal distributions.


## Rejected Alternatives

An alternative solution to problem (1) is to not store the MCMC samples in the Inference object at all, and simply return the samples after the Inference method completes. The main downsides are that the user then has to maintain those samples for use in future algorithms and serialize them manually. I also don't see an easy way to initialize and store the latent variables correctly **during Inference** without replicating code from InferenceParameters in the MCMCInference class.

An alternative solution to problem (2) is ...
2 changes: 1 addition & 1 deletion mxfusion/components/distributions/distribution.py
@@ -55,7 +55,7 @@ def replicate_self(self, attribute_map=None):

def log_pdf(self, F, variables, targets=None):
"""
Computes the logarithm of the probability density/mass function (PDF/PMF) of the distribution.
Computes the logarithm of the probability density/mass function (PDF/PMF) of the distribution.
The inputs and outputs variables are fetched from the *variables* argument according to their UUIDs.

:param F: the MXNet computation mode (mxnet.symbol or mxnet.ndarray).
8 changes: 5 additions & 3 deletions mxfusion/inference/inference.py
@@ -43,7 +43,8 @@ class Inference(object):
"""

def __init__(self, inference_algorithm, constants=None,
hybridize=False, dtype=None, context=None):
hybridize=False, dtype=None, context=None,
is_mcmc=False):

self.dtype = dtype if dtype is not None else get_default_dtype()
self.mxnet_context = context if context is not None else get_default_device()
@@ -52,13 +53,14 @@ def __init__(self, inference_algorithm, constants=None,
self._inference_algorithm = inference_algorithm
self.params = InferenceParameters(constants=constants,
dtype=self.dtype,
context=self.mxnet_context)
context=self.mxnet_context,
is_mcmc=is_mcmc)
self._initialized = False

def print_params(self):
"""
Returns a string with the inference parameters nicely formatted for display, showing which model they came from and their name + uuid.

Format:
> infr.print_params()
Variable(1ab23)(name=y) - (Model/Posterior(123ge2)) - (first mxnet values/shape)
12 changes: 9 additions & 3 deletions mxfusion/inference/inference_parameters.py
@@ -40,7 +40,7 @@ class InferenceParameters(object):
:param context: The MXNet context
:type context: {mxnet.cpu or mxnet.gpu}
"""
def __init__(self, constants=None, dtype=None, context=None):
def __init__(self, constants=None, dtype=None, context=None, is_mcmc=False):
self.dtype = dtype if dtype is not None else get_default_dtype()
self.mxnet_context = context if context is not None else get_default_device()
self._constants = {}
@@ -51,6 +51,7 @@ def __init__(self, constants=None, dtype=None, context=None):
for k, v in constants.items()}
self._constants.update(constant_uuids)
self._params = ParameterDict()
self._is_mcmc = is_mcmc

def update_constants(self, constants):
"""
@@ -85,8 +86,13 @@ def initialize_params(self, graphs, observed_uuid):
self._constants[var.uuid] = var.constant

excluded = set(self._constants.keys()).union(observed_uuid)
for var in g.get_parameters(excluded=excluded,
include_inherited=False):
to_init = g.get_parameters(excluded=excluded,
include_inherited=False)
if self._is_mcmc:
to_init = to_init + g.get_latent_variables(observed_uuid)
for var in set(to_init):
if self._is_mcmc and var in self._params:
continue
var_shape = realize_shape(var.shape, self._constants)
init = initializer.Constant(var.initial_value_before_transformation) \
if var.initial_value is not None else None
192 changes: 192 additions & 0 deletions mxfusion/inference/mh_sampling.py
@@ -0,0 +1,192 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

from ..common.exceptions import InferenceError
from ..common.config import DEFAULT_DTYPE
from ..components.variables import Variable
from .variational import StochasticVariationalInference
from .inference_alg import SamplingAlgorithm
from .inference import Inference
from .map import MAP
from ..components.distributions import Normal
from ..models import Posterior
import mxnet as mx
from ..components.distributions.random_gen import MXNetRandomGenerator

class MetropolisHastingsAlgorithm(SamplingAlgorithm):
"""
The Metropolis-Hastings MCMC sampling algorithm.

:param model: the definition of the probabilistic model
:type model: Model
:param proposal: the proposal distribution to draw comparison samples against.
:type proposal: Distribution
:param observed: A list of observed variables
:type observed: [Variable]
:param num_samples: the number of samples to draw at each step
:type num_samples: int
:param target_variables: (optional) the target variables to sample
:type target_variables: [UUID]
"""
def __init__(self, model, observed, proposal=None, num_samples=1,
target_variables=None, variance=None, rand_gen=None, dtype=None):
self._rand_gen = MXNetRandomGenerator if rand_gen is None else \
rand_gen
self._dtype = dtype if dtype is not None else DEFAULT_DTYPE
self._copy_map = {} # {current: prior}
self._proposals_chosen = 0
self.variance = variance if variance is not None else mx.nd.array([1.], dtype=self._dtype)
if proposal is None:
proposal = Posterior(model)
for rv in model.get_latent_variables(observed):
rv_previous = Variable(shape=rv.shape)
self._copy_map[rv.uuid] = rv_previous.uuid
proposal[rv].set_prior(Normal(mean=rv_previous, variance=self.variance))
self._reversed_copy_map = {v:k for k,v in self._copy_map.items()}

super(MetropolisHastingsAlgorithm, self).__init__(
model=model, observed=observed, num_samples=num_samples,
target_variables=target_variables, extra_graphs=[proposal])

def compute(self, F, variables):
"""
Returns Metropolis-Hastings samples.
:param variables: {'uuid': latest sample}
:rtype: {'uuid': next sample}
"""
x = variables.copy()
x = {k: v for k,v in x.items() if k not in self._copy_map}
# draw proposal samples using the last steps output
x_proposal = self.proposal.draw_samples(F, variables=x, num_samples=1)
x_proposal_full = x_proposal.copy()
# compute new ratio
x_proposal_full.update({k:v for k,v in x.items() if k not in x_proposal})

# swapped: swap the uuids of the new and old latent variables
swapped = {k:v for k,v in x_proposal_full.items() if k not in self._copy_map and k not in self._reversed_copy_map}
swapped.update({self._copy_map[k]: v for k,v in x_proposal_full.items() if k in self._copy_map})
swapped.update({self._reversed_copy_map[k]: v for k,v in x_proposal_full.items() if k in self._reversed_copy_map})

proposal_new = self.proposal.log_pdf(F, x_proposal_full)
proposal_old = self.proposal.log_pdf(F, swapped)
model_new = self.model.log_pdf(F, x_proposal_full)
model_old = self.model.log_pdf(F, variables)

alpha = (model_new - model_old)
alpha += (proposal_old - proposal_new)
r_min = F.exp(F.minimum(mx.nd.array([0], dtype=self._dtype),alpha))

unif_sample = self._rand_gen.sample_uniform(0,1)

# return this step's samples based on ratio
if unif_sample < r_min:
is_proposal = True
return_choice = x_proposal_full
else:
return_choice = variables
is_proposal = False
return_subset = {k:v for k,v in return_choice.items() if k in x_proposal}
return return_subset, is_proposal

@property
def proposal(self):
"""
Return the proposal distribution
"""
return self._extra_graphs[0]


class MCMCInference(Inference):
"""
The abstract class for MCMC-based inference methods.
An inference method consists of a few components: the applied inference algorithm, the model definition, and the inference parameters.

:param inference_algorithm: The applied inference algorithm
:type inference_algorithm: InferenceAlgorithm
:param graphs: a list of graph definitions required by the inference method. It includes the model definition and necessary posterior approximation.
:type graphs: [FactorGraph]
:param observed: A list of observed variables
:type observed: [Variable]
:param constants: Specify a list of model variables as constants
:type constants: {Variable: mxnet.ndarray}
:param hybridize: Whether to hybridize the MXNet Gluon block of the inference method.
:type hybridize: boolean
:param dtype: data type for internal numerical representation
:type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'}
:param context: The MXNet context
:type context: {mxnet.cpu or mxnet.gpu}
"""
def __init__(self, inference_algorithm, constants=None,
hybridize=False, dtype=None, context=None):
super(MCMCInference, self).__init__(
inference_algorithm=inference_algorithm, constants=constants,
hybridize=hybridize, dtype=dtype, context=context, is_mcmc=True)
self._number_proposals = 0
self.samples = {}

def create_executor(self):
"""
Return a MXNet Gluon block responsible for the execution of the inference method.
"""
infr = self._inference_algorithm.create_executor(
data_def=self.observed_variable_UUIDs, params=self.params,
var_ties=self.params.var_ties, rv_scaling=None)
if self._hybridize:
infr.hybridize()
infr.initialize(ctx=self.mxnet_context)
return infr

def run(self, optimizer='adam', learning_rate=1e-3, max_iter=2000,
verbose=False, n_prints=10, **kwargs):
"""
Run the inference method.

:param optimizer: the choice of optimizer (default: 'adam')
:type optimizer: str
:param learning_rate: the learning rate of the gradient optimizer (default: 0.001)
:type learning_rate: float
:param max_iter: the maximum number of iterations of gradient optimization
:type max_iter: int
:param verbose: whether to print per-iteration messages.
:type verbose: boolean
:param **kwargs: The keyword arguments specify the data for inferences. The key of each argument is the name of the corresponding
variable in model definition and the value of the argument is the data in numpy array format.
"""
data = [kwargs[v] for v in self.observed_variable_names]
self.initialize(**kwargs)

self.params._params.initialize(ctx=self.mxnet_context)
infr = self.create_executor()
iter_step = max(max_iter // n_prints, 1)
self._number_proposals = 0
number_proposals = 0
for i in range(max_iter):
sample, is_proposal = infr(mx.nd.zeros(1), *data)
if is_proposal:
number_proposals += 1
for k,v in sample.items():
if k not in self.samples:
self.samples[k] = v
else:
self.samples[k] = mx.nd.concat(self.samples[k], v, dim=0)
v_shaped = mx.nd.reshape(v, shape=self.params._params.get(k).data().shape)
self.params._params.get(k).set_data(v_shaped)
self.params._params.get(self.inference_algorithm._copy_map[k]).set_data(v_shaped)
if i > 0 and i % iter_step == 0:
self._number_proposals += number_proposals
if verbose:
print("{}th step. Acceptance rate so far: {:.3f}, latest {}".format(i, self._number_proposals / i, number_proposals / iter_step))
number_proposals = 0

return self.samples
12 changes: 12 additions & 0 deletions mxfusion/models/factor_graph.py
@@ -239,6 +239,7 @@ def draw_samples(self, F, variables, num_samples=1, targets=None):
"""
Draw samples from the target variables of the Factor Graph. If the ``targets`` argument is None, draw samples from all the variables
that are *not* in the conditional variables. If the ``targets`` argument is given, this method returns a list of samples of variables in the order of the target argument, otherwise it returns a dict of samples where the keys are the UUIDs of variables and the values are the samples.
Adds drawn samples directly into the variables argument.

:param F: the MXNet computation mode (``mxnet.symbol`` or ``mxnet.ndarray``).
:param variables: The set of variables
@@ -455,6 +456,17 @@ def get_constants(self):
"""
return [v for v in self.variables.values() if v.type == VariableType.CONSTANT]

def get_latent_variables(self, observed):
"""
Get the latent variables of the model.

:param observed: a list of observed variables.
:type observed: [UUID]
:returns: the list of latent variables.
:rtype: [Variable]
"""
return [v for v in self.variables.values() if v.type == VariableType.RANDVAR and v.uuid not in observed]


@staticmethod
def reconcile_graphs(current_graphs, primary_previous_graph, secondary_previous_graphs=None, primary_current_graph=None):
11 changes: 0 additions & 11 deletions mxfusion/models/model.py
@@ -30,17 +30,6 @@ def __init__(self, name=None, verbose=False):
"""
super(Model, self).__init__(name=name, verbose=verbose)

def get_latent_variables(self, observed):
"""
Get the latent variables of the model.

:param observed: a list of observed variables.
:type observed: [UUID]
:returns: the list of latent variables.
:rtype: [Variable]
"""
return [v for v in self.variables.values() if v.type == VariableType.RANDVAR and v.uuid not in observed]

def _replicate_class(self, **kwargs):
"""
Returns a new instance of the derived FactorGraph's class.