diff --git a/examples/notebooks/Variational Continual Learning.ipynb b/examples/notebooks/Variational Continual Learning.ipynb new file mode 100644 index 0000000..58a5c67 --- /dev/null +++ b/examples/notebooks/Variational Continual Learning.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import gzip\n", + "import sys\n", + "\n", + "import mxfusion as mf\n", + "import mxnet as mx\n", + "\n", + "import logging\n", + "logging.getLogger().setLevel(logging.DEBUG) # logging to stdout\n", + "\n", + "# Set the compute context, GPU is available otherwise CPU\n", + "ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class SplitMnistGenerator:\n", + " def __init__(self, data, batch_size):\n", + " self.data = data\n", + " self.batch_size = batch_size\n", + "\n", + " def __iter__(self):\n", + " for i in range(5):\n", + " idx_train_0 = np.where(self.data['train_label'] == i * 2)[0]\n", + " idx_train_1 = np.where(self.data['train_label'] == i * 2 + 1)[0]\n", + " idx_test_0 = np.where(self.data['test_label'] == i * 2)[0]\n", + " idx_test_1 = np.where(self.data['test_label'] == i * 2 + 1)[0]\n", + " \n", + " x_train = np.vstack((self.data['train_data'][idx_train_0], self.data['train_data'][idx_train_1]))\n", + " y_train = np.vstack((np.ones((idx_train_0.shape[0], 1)), -np.ones((idx_train_1.shape[0], 1))))\n", + "\n", + " x_test = np.vstack((self.data['test_data'][idx_test_0], self.data['test_data'][idx_test_1]))\n", + " y_test = np.vstack((np.ones((idx_test_0.shape[0], 1)), -np.ones((idx_test_1.shape[0], 1))))\n", + " \n", + " batch_size = x_train.shape[0] if self.batch_size is None else self.batch_size \n", + " train_iter = mx.io.NDArrayIter(x_train, y_train, batch_size, shuffle=True)\n", + "\n", + " batch_size = x_test.shape[0] if self.batch_size is None else self.batch_size \n", + " test_iter = mx.io.NDArrayIter(x_test, y_test, batch_size)\n", + " \n", + " yield train_iter, test_iter\n", + " return\n", + "\n", + "mnist = mx.test_utils.get_mnist()\n", + "in_dim = np.prod(mnist['train_data'][0].shape)\n", + "\n", + "gen = SplitMnistGenerator(mnist, batch_size=None)\n", + "for task_id, (train, test) in enumerate(gen):\n", + " print(\"Task\", task_id)\n", + " print(\"Train data shape\" ,train.data[0][1].shape)\n", + " print(\"Train label shape\" ,train.label[0][1].shape)\n", + " print(\"Test data shape\" ,test.data[0][1].shape)\n", + " print(\"Test label shape\" ,test.label[0][1].shape)\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rand_from_batch(x_coreset, y_coreset, x_train, y_train, coreset_size):\n", + " \"\"\" Random coreset selection \"\"\"\n", + " # Randomly select from (x_train, y_train) and add to current coreset (x_coreset, y_coreset)\n", + " idx = np.random.choice(x_train.shape[0], coreset_size, False)\n", + " x_coreset.append(x_train[idx,:])\n", + " y_coreset.append(y_train[idx,:])\n", + " x_train = np.delete(x_train, idx, axis=0)\n", + " y_train = np.delete(y_train, idx, axis=0)\n", + " return x_coreset, y_coreset, x_train, y_train \n", + "\n", + "def k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):\n", + " \"\"\" K-center coreset selection \"\"\"\n", + " # Select K centers from (x_train, y_train) and add to current coreset (x_coreset, y_coreset)\n", + " dists = np.full(x_train.shape[0], np.inf)\n", + " current_id = 0\n", + " dists = update_distance(dists, x_train, current_id)\n", + " idx = [ current_id ]\n", + "\n", + " for i in range(1, coreset_size):\n", + " current_id = np.argmax(dists)\n", + " dists = update_distance(dists, x_train, current_id)\n", + " idx.append(current_id)\n", + "\n", + " x_coreset.append(x_train[idx,:])\n", + " y_coreset.append(y_train[idx,:])\n", + " x_train = np.delete(x_train, idx, axis=0)\n", + " y_train = np.delete(y_train, idx, axis=0)\n", + " return x_coreset, y_coreset, x_train, y_train\n", + "\n", + "def update_distance(dists, x_train, current_id):\n", + " for i in range(x_train.shape[0]):\n", + " current_dist = np.linalg.norm(x_train[i,:]-x_train[current_id,:])\n", + " dists[i] = np.minimum(current_dist, dists[i])\n", + " return dists" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_vcl(network_shape, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True):\n", + " x_coresets, y_coresets = [], []\n", + " x_testsets, y_testsets = [], []\n", + "\n", + " all_acc = np.array([])\n", + "\n", + " for task_id, (train, test) in enumerate(data_gen):\n", + " x_testsets.append(test.data[0][1])\n", + " y_testsets.append(test.label[0][1])\n", + "\n", + " # Set the readout head to train\n", + " head = 0 if single_head else task_id\n", + " # bsize = x_train.shape[0] if (batch_size is None) else batch_size\n", + "\n", + " # Train network with maximum likelihood to initialize first model\n", + " if task_id == 0:\n", + " ml_model = VanillaNN(network_shape)\n", + " ml_model.train(x_train, y_train, task_id, no_epochs, bsize)\n", + " mf_weights = ml_model.get_weights()\n", + " mf_variances = None\n", + " ml_model.close_session()\n", + "\n", + " # Select coreset if needed\n", + " if coreset_size > 0:\n", + " x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)\n", + "\n", + " # Train on non-coreset data\n", + " mf_model = MFVINN(network_shape, prev_means=mf_weights, prev_log_variances=mf_variances)\n", + " mf_model.train(x_train, y_train, head, no_epochs, bsize)\n", + " mf_weights, mf_variances = mf_model.get_weights()\n", + "\n", + " # Incorporate coreset data and make prediction\n", + " acc = utils.get_scores(mf_model, x_testsets, y_testsets, x_coresets, y_coresets, hidden_size, no_epochs, single_head, batch_size)\n", + " all_acc = utils.concatenate_results(acc, all_acc)\n", + "\n", + " mf_model.close_session()\n", + "\n", + " return all_acc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class BaseNN:\n", + " def __init__(self, network_shape):\n", + " # input and output placeholders\n", + " self.task_idx = mx.sym.Variable(name='task_idx', dtype=np.float32)\n", + " self.net = None\n", + " \n", + " def train(self, train_iter, val_iter, ctx):\n", + " # data = mx.sym.var('data')\n", + " # Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)\n", + "# data = mx.sym.flatten(data=data)\n", + " \n", + " # create a trainable module on compute context\n", + " self.model = mx.mod.Module(symbol=self.net, context=ctx)\n", + " self.model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n", + " init = mx.init.Xavier(factor_type=\"in\", magnitude=2.34)\n", + " self.model.init_params(initializer=init, force_init=True)\n", + " self.model.fit(train_iter, # train data\n", + " eval_data=val_iter, # validation data\n", + " optimizer='adam', # use SGD to train\n", + " optimizer_params={'learning_rate': 0.001}, # use fixed learning rate\n", + " eval_metric='acc', # report accuracy during training\n", + " batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches\n", + " num_epoch=10) # train for at most 50 dataset passes\n", + " # predict accuracy of mlp\n", + " acc = mx.metric.Accuracy()\n", + " self.model.score(test_iter, acc)\n", + " return acc\n", + "\n", + " def prediction_prob(self, test_iter, task_idx):\n", + " # task_idx??\n", + " prob = self.model.predict(test_iter)\n", + " return prob\n", + " \n", + "class VanillaNN(BaseNN):\n", + " def __init__(self, network_shape, prev_weights=None, learning_rate=0.001):\n", + " super(VanillaNN, self).__init__(network_shape)\n", + "\n", + " # Create net\n", + " self.net = mx.gluon.nn.HybridSequential(prefix='vanilla_')\n", + " with self.net.name_scope():\n", + " for layer in network_shape[1:-1]:\n", + " self.net.add(mx.gluon.nn.Dense(layer, activation=\"relu\"))\n", + " # Last layer for classification\n", + " self.net.add(mx.gluon.nn.Dense(network_shape[-1], flatten=True, in_units=network_shape[-2]))\n", + " self.loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n", + " self.net.initialize(mx.init.Xavier(magnitude=2.34))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Hyperparameters\n", + "network_shape = (in_dim, 256, 256, 2) # binary classification\n", + "batch_size = None\n", + "no_epochs = 120\n", + "single_head = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run vanilla VCL\n", + "mx.random.seed(42)\n", + "np.random.seed(42)\n", + "\n", + "coreset_size = 0\n", + "data_gen = SplitMnistGenerator(mnist, batch_size)\n", + "vcl_result = run_vcl(network_shape, no_epochs, data_gen, rand_from_batch, coreset_size, batch_size, single_head)\n", + "print(vcl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run random coreset VCL\n", + "mx.random.seed(42)\n", + "np.random.seed(42)\n", + "\n", + "coreset_size = 40\n", + "data_gen = SplitMnistGenerator(mnist, batch_size)\n", + "rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen, \n", + " coreset.rand_from_batch, coreset_size, batch_size, single_head)\n", + "print(rand_vcl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run k-center coreset VCL\n", + "mx.random.seed(42)\n", + "np.random.seed(42)\n", + "\n", + "data_gen = SplitMnistGenerator(mnist, batch_size)\n", + "kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen, \n", + " coreset.k_center, coreset_size, batch_size, single_head)\n", + "print(kcen_vcl_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot average accuracy\n", + "vcl_avg = np.nanmean(vcl_result, 1)\n", + "rand_vcl_avg = np.nanmean(rand_vcl_result, 1)\n", + "kcen_vcl_avg = np.nanmean(kcen_vcl_result, 1)\n", + "utils.plot('results/split.jpg', vcl_avg, rand_vcl_avg, kcen_vcl_avg)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/variational_continual_learning/coresets.py b/examples/variational_continual_learning/coresets.py new file mode 100644 index 0000000..a61463e --- /dev/null +++ b/examples/variational_continual_learning/coresets.py @@ -0,0 +1,162 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import numpy as np +import mxnet as mx +from mxnet.io import NDArrayIter, DataIter, DataBatch +from abc import ABCMeta, abstractmethod +import itertools + + +class MultiIter(DataIter): + def __init__(self, iter_list): + super().__init__() + self.iterators = [] if iter_list is None else iter_list + + def __next__(self): + if len(self.iterators) == 0: + raise StopIteration + + if len(self.iterators) == 1: + return next(self.iterators[0]) + + data = [] + labels = [] + for iterator in self.iterators: + batch = next(iterator) + data.append(batch.data) + labels.append(batch.label) + return DataBatch(data=mx.nd.concat(*data, axis=0), label=mx.nd.concat(*labels, axis=0), pad=0) + + def __len__(self): + return len(self.iterators) + + def __getitem__(self, item): + return self.iterators[item] + + def reset(self): + for i in self.iterators: + i.reset() + + @property + def provide_data(self): + return list(itertools.chain(map(lambda i: i.provide_data, self.iterators))) + + @property + def provide_label(self): + return list(itertools.chain(map(lambda i: i.provide_label, self.iterators))) + + def append(self, iterator): + if not isinstance(iterator, (DataIter, NDArrayIter)): + raise ValueError("Expected either a DataIter or NDArray object, received: {}".format(type(iterator))) + self.iterators.append(iterator) + + +class Coreset(metaclass=ABCMeta): + """ + Abstract base class for coresets + """ + def __init__(self): + """ + Initialise the coreset + """ + self.iterator = None + self.reset() + + @abstractmethod + def selector(self, data): + pass + + def update(self, iterator): + data, labels = iterator.data[0][1].asnumpy(), iterator.label[0][1].asnumpy() + idx = self.selector(data) + self.iterator.append(NDArrayIter(data=data[idx, :], label=labels[idx], shuffle=False, batch_size=len(idx))) + + data = np.delete(data, idx, axis=0) + labels = np.delete(labels, idx, axis=0) + batch_size = min(iterator.batch_size, data.shape[0]) + + return NDArrayIter(data=data, label=labels, shuffle=False, batch_size=batch_size) + + def reset(self): + self.iterator = MultiIter([]) + + @staticmethod + def merge(coreset): + # For sizes 0 and 1 just return the original coreset + if len(coreset.iterator) <= 1: + return coreset + merged = coreset.__class__(coreset_size=coreset.coreset_size) + merged.append() + raise NotImplementedError + + +class Vanilla(Coreset): + """ + Vanilla coreset that is always size 0 + """ + def __init__(self): + super().__init__() + self.coreset_size = 0 + + def update(self, iterator): + return iterator + + def selector(self, data): + raise NotImplementedError + + +class Random(Coreset): + """ + Randomly select from (data, labels) and add to current coreset + """ + def __init__(self, coreset_size): + """ + Initialise the coreset + :param coreset_size: Size of the coreset + :type coreset_size: int + """ + super().__init__() + if coreset_size == 0: + raise ValueError("Coreset size should be > 0") + self.coreset_size = coreset_size + + def selector(self, data): + return np.random.choice(data.shape[0], self.coreset_size, False) + + +class KCenter(Random): + """ + Select k centers from (data, labels) and add to current coreset + """ + def selector(self, data): + dists = np.full(data.shape[0], np.inf) + current_id = 0 + + # TODO: This looks horribly inefficient + dists = self.update_distance(dists, data, current_id) + idx = [current_id] + + for i in range(1, self.coreset_size): + current_id = np.argmax(dists) + dists = self.update_distance(dists, data, current_id) + idx.append(current_id) + return idx + + @staticmethod + def update_distance(dists, data, current_id): + for i in range(data.shape[0]): + current_dist = np.linalg.norm(data[i, :] - data[current_id, :]) + dists[i] = np.minimum(current_dist, dists[i]) + return dists diff --git a/examples/variational_continual_learning/experiment.py b/examples/variational_continual_learning/experiment.py new file mode 100644 index 0000000..c57727c --- /dev/null +++ b/examples/variational_continual_learning/experiment.py @@ -0,0 +1,207 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import numpy as np + +from examples.variational_continual_learning.models import VanillaNN, BayesianNN +from examples.variational_continual_learning.coresets import Coreset + + +class Experiment: + def __init__(self, network_shape, num_epochs, learning_rate, optimizer, data_generator, + coreset, batch_size, single_head, ctx, verbose): + self.network_shape = network_shape + self.original_network_shape = network_shape # Only used when resetting + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.optimizer = optimizer + self.data_generator = data_generator + self.coreset = coreset + self.batch_size = batch_size + self.single_head = single_head + self.context = ctx + self.verbose = verbose + + # The following are to keep lint happy: + self.overall_accuracy = None + self.test_iterators = None + self.vanilla_model = None + self.bayesian_model = None + + self.task_ids = [] + + @property + def model_params(self): + return dict( + network_shape=self.network_shape, + learning_rate=self.learning_rate, + optimizer=self.optimizer, + max_iter=self.num_epochs, + ctx=self.context, + verbose=self.verbose + ) + + def reset(self): + self.coreset.reset() + self.network_shape = self.original_network_shape + self.overall_accuracy = np.array([]) + self.test_iterators = dict() + self.task_ids = [] + + print("Creating Vanilla Model") + self.vanilla_model = VanillaNN(**self.model_params) + + def new_task(self, task): + if self.single_head and self.bayesian_model is not None: + return + + if len(self.task_ids) > 0: + self.network_shape = self.network_shape[0:-1] + (self.network_shape[-1] + (task.number_of_classes,),) + + self.task_ids.append(task.task_id) + + # TODO: Would be nice if we could use the same object here + self.bayesian_model = BayesianNN(**self.model_params) + + def run(self): + self.reset() + + # To begin with, set the priors to None. + # We will in fact use the results of maximum likelihood as the first prior + priors = None + + for task in self.data_generator: + print("Task: ", task.task_id) + self.test_iterators[task.task_id] = task.test_iterator + + # Set the readout head to train_iterator + head = 0 if self.single_head else task.task_id + + # Update the coreset, and update the train iterator to remove the coreset data + train_iterator = self.coreset.update(task.train_iterator) + + label_shape = train_iterator.provide_label[0].shape + batch_size = label_shape[0] if self.batch_size is None else self.batch_size + + # Train network with maximum likelihood to initialize first model + if len(self.task_ids) == 0: + print("Training non-Bayesian neural network as starting point") + self.vanilla_model.train( + train_iterator=train_iterator, + validation_iterator=task.test_iterator, + head=head, + epochs=5, + batch_size=batch_size) + + priors = self.vanilla_model.net.collect_params() + train_iterator.reset() + + self.new_task(task) + + # Train on non-coreset data + print("Training main model") + self.bayesian_model.train( + train_iterator=train_iterator, + validation_iterator=task.test_iterator, + head=head, + epochs=self.num_epochs, + batch_size=self.batch_size, + priors=priors) + + # Set the priors for the next round of inference to be the current posteriors + priors = self.bayesian_model.posteriors + # print("Number of variables in priors: {}".format(len(priors.items()))) + + # Incorporate coreset data and make prediction + acc = self.get_scores() + print("Accuracies after task {}: [{}]".format(task.task_id, ", ".join(map("{:.3f}".format, acc)))) + self.overall_accuracy = concatenate_results(acc, self.overall_accuracy) + + def get_coreset(self, task_id): + """ + For multi-headed models gets the coreset for the given task id. + For single-headed models this will return a merged coreset + :param task_id: The task id + :return: iterator for the coreset + """ + if self.single_head: + # TODO: Cache the results if this is expensive? + iterator = Coreset.merge(self.coreset).iterator + else: + iterator = self.coreset.iterator + + if len(iterator) > 0: + return iterator[task_id] + return None + + def fine_tune(self, task_id): + """ + Fine tune the latest trained model using the coreset(s) + :param task_id: the task id + :return: the fine tuned prediction model + """ + coreset_iterator = self.get_coreset(task_id) + + if coreset_iterator is None: + print("Empty coreset: Using main model as prediction model for task {}".format(task_id)) + return self.bayesian_model + + coreset_iterator.reset() + batch_size = coreset_iterator.provide_label[0].shape[0] + prediction_model = BayesianNN(**self.model_params) + + priors = self.bayesian_model.posteriors + print("Number of variables in priors: {}".format(len(priors))) + + print("Fine tuning prediction model for task {}".format(task_id)) + prediction_model.train( + train_iterator=coreset_iterator, + validation_iterator=None, + head=task_id, + epochs=self.num_epochs, + batch_size=batch_size, + priors=priors) + + return prediction_model + + def get_scores(self): + scores = [] + # TODO: different learning rate and max iter here? + + for task_id, test_iterator in self.test_iterators.items(): + test_iterator.reset() + + head = 0 if self.single_head else task_id + prediction_model = self.fine_tune(task_id) + + print("Generating predictions for task {}".format(task_id)) + predictions = prediction_model.prediction_prob(test_iterator, head) + predicted_means = np.mean(predictions, axis=0) + predicted_labels = np.argmax(predicted_means, axis=1) + test_labels = test_iterator.label[0][1].asnumpy() + mt = test_labels.shape[0] + score = len(np.where(np.abs(predicted_labels[:mt] - test_labels) < 1e-10)[0]) * 1.0 / mt + scores.append(score) + return scores + + +def concatenate_results(score, all_score): + if all_score.size == 0: + all_score = np.reshape(score, (1, -1)) + else: + new_arr = np.empty((all_score.shape[0], all_score.shape[1] + 1)) + new_arr[:] = np.nan + new_arr[:, :-1] = all_score + all_score = np.vstack((new_arr, score)) + return all_score diff --git a/examples/variational_continual_learning/mlp.py b/examples/variational_continual_learning/mlp.py new file mode 100644 index 0000000..656beb0 --- /dev/null +++ b/examples/variational_continual_learning/mlp.py @@ -0,0 +1,64 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +from mxnet.gluon import Block +from mxnet.gluon.contrib.nn import Concurrent +from mxnet.gluon.nn import Dense, Sequential + + +class MLPSequential(Sequential): + def __init__(self, prefix, network_shape, single_head, **kwargs): + super().__init__(prefix=prefix, **kwargs) + + self.network_shape = network_shape + self.single_head = single_head + + with self.name_scope(): + for i in range(1, len(self.network_shape) - 1): + self.add(Dense(self.network_shape[i], activation="relu", in_units=self.network_shape[i - 1])) + + # Last layer for classification - one per head for multi-head networks + if self.single_head: + self.add(Dense(self.network_shape[-1], in_units=self.network_shape[-2])) + else: + for label_shape in self.network_shape[-1]: + self.add(Dense(label_shape, in_units=self.network_shape[-2])) + + +class MLP(Block): + def __init__(self, prefix, network_shape, single_head, **kwargs): + super().__init__(prefix=prefix, **kwargs) + + self.single_head = single_head + + with self.name_scope(): + self.hidden = Sequential() + for i in range(1, len(network_shape) - 1): + self.hidden.add(Dense(network_shape[i], activation="relu", in_units=network_shape[i - 1])) + + if single_head: + self.head = Dense(network_shape[-1], in_units=network_shape[-2]) + else: + self.concurrent = Concurrent() + for label_shape in network_shape[-1]: + self.concurrent.add(Dense(label_shape, in_units=network_shape[-2])) + + def forward(self, x): + for i in range(len(self.hidden)): + x = self.hidden[i](x) + + if self.single_head: + return self.head(x) + else: + return tuple(map(lambda h: h(x), self.concurrent)) diff --git a/examples/variational_continual_learning/mnist.py b/examples/variational_continual_learning/mnist.py new file mode 100644 index 0000000..5469795 --- /dev/null +++ b/examples/variational_continual_learning/mnist.py @@ -0,0 +1,97 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import numpy as np +from mxnet.io import NDArrayIter + + +class Task: + def __init__(self, task_id, task_details, train_iterator, test_iterator, number_of_classes): + self.task_id = task_id + self.task_details = task_details + self.train_iterator = train_iterator + self.test_iterator = test_iterator + self.number_of_classes = number_of_classes + + +class TaskGenerator: + def __init__(self, data, batch_size, tasks): + self.data = data + self.batch_size = batch_size + self.tasks = tasks + + +class SplitTaskGenerator(TaskGenerator): + def __iter__(self): + """ + Iterate over tasks + :return: the next task + :rtype: NDArrayIter + """ + for i, task in enumerate(self.tasks): + idx_train_0 = np.where(self.data['train_label'] == task[0])[0] + idx_train_1 = np.where(self.data['train_label'] == task[1])[0] + idx_test_0 = np.where(self.data['test_label'] == task[0])[0] + idx_test_1 = np.where(self.data['test_label'] == task[1])[0] + + x_train = np.vstack((self.data['train_data'][idx_train_0], self.data['train_data'][idx_train_1])) + y_train = np.hstack((np.ones((idx_train_0.shape[0],)), np.zeros((idx_train_1.shape[0],)))) + + x_test = np.vstack((self.data['test_data'][idx_test_0], self.data['test_data'][idx_test_1])) + y_test = np.hstack((np.ones((idx_test_0.shape[0],)), np.zeros((idx_test_1.shape[0],)))) + + batch_size = self.batch_size or x_train.shape[0] + train_iter = NDArrayIter(x_train, y_train, batch_size, shuffle=True) + + batch_size = self.batch_size or x_test.shape[0] + test_iter = NDArrayIter(x_test, y_test, batch_size) + + yield Task(i, task, train_iter, test_iter, number_of_classes=2) + return + + +class PermutedTaskGenerator(TaskGenerator): + def __iter__(self): + """ + Iterate over tasks + :return: the next task + :rtype: NDArrayIter + """ + for i, task in enumerate(self.tasks): + x_train = self.data['train_data'] + y_train = self.data['train_label'] + + x_test = self.data['test_data'] + y_test = self.data['test_label'] + + permutation = np.random.permutation(x_train.shape[1]) + + x_train = x_train[:, permutation] + x_test = x_test[:, permutation] + + # Convert to one hot encodings + # y_train = np.eye(10)[y_train] + # y_test = np.eye(10)[y_test] + + batch_size = self.batch_size or x_train.shape[0] + train_iter = NDArrayIter(x_train, y_train, batch_size, shuffle=True) + + batch_size = self.batch_size or x_test.shape[0] + test_iter = NDArrayIter(x_test, y_test, batch_size) + + # number_of_classes = y_train.shape[1] + number_of_classes = len(np.unique(y_train)) + + yield Task(i, task, train_iter, test_iter, number_of_classes) + return diff --git a/examples/variational_continual_learning/models.py b/examples/variational_continual_learning/models.py new file mode 100644 index 0000000..7b0e67b --- /dev/null +++ b/examples/variational_continual_learning/models.py @@ -0,0 +1,405 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== +import mxnet as mx +from mxnet.gluon import Trainer, ParameterDict +from mxnet.gluon.loss import SoftmaxCrossEntropyLoss +from mxnet.initializer import Xavier +from mxnet.metric import Accuracy + +from mxfusion import Model, Variable +from mxfusion.components import MXFusionGluonFunction +from mxfusion.components.distributions import Normal, Categorical +from mxfusion.inference import BatchInferenceLoop, create_Gaussian_meanfield, GradBasedInference, \ + StochasticVariationalInference, VariationalPosteriorForwardSampling, MinibatchInferenceLoop, \ + GradIteratorBasedInference + +import numpy as np + +from abc import ABC, abstractmethod + +from .mlp import MLP + + +class BaseNN(ABC): + prefix = None + + def __init__(self, network_shape, learning_rate, optimizer, max_iter, ctx, verbose): + self.model = None + self.network_shape = network_shape + self.learning_rate = learning_rate + self.optimizer = optimizer + self.max_iter = max_iter + self.loss = None + self.ctx = ctx + self.model = None + self.net = None + self.inference = None + self.create_net() + self.loss = SoftmaxCrossEntropyLoss() + self.verbose = verbose + + @property + def single_head(self): + if isinstance(self.network_shape[-1], int): + return True + if isinstance(self.network_shape[-1], (tuple, list)): + return False + raise ValueError("Unsupported network shape") + + @property + def num_heads(self): + return 1 if self.single_head else len(self.network_shape[-1]) + + def create_net(self): + # Create net + self.net = MLP(self.prefix, self.network_shape, self.single_head) + self.net.initialize(Xavier(magnitude=2.34), ctx=self.ctx) + + def forward(self, data): + # Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height) + data = mx.nd.flatten(data).as_in_context(self.ctx) + output = self.net(data) + return output + + def evaluate_accuracy(self, data_iterator, head=0): + """ + Evaluate the accuracy of the model on the given data iterator + :param data_iterator: data iterator + :param head: the head of the network (for multi-head models) + :return: accuracy + :rtype: float + """ + acc = Accuracy() + for i, batch in enumerate(data_iterator): + if self.single_head: + output = self.forward(batch.data[0]) + else: + output = self.forward(batch.data[0])[head] + + labels = batch.label[0].as_in_context(self.ctx) + predictions = mx.nd.argmax(output, axis=1) + acc.update(preds=predictions, labels=labels) + return acc.get()[1] + + @abstractmethod + def train(self, train_iterator, validation_iterator, head, batch_size, epochs, priors=None): + raise NotImplementedError + + def prediction_prob(self, test_iter, task_idx): + # TODO task_idx?? + prob = self.model.predict(test_iter) + return prob + + def get_weights(self): + params = self.net.collect_params() + return params + + @staticmethod + def print_status(epoch, loss, train_accuracy=float("nan"), validation_accuracy=float("nan")): + print("Epoch {:4d}. Loss: {:8.2f}, Train accuracy {:.3f}, Validation accuracy {:.3f}" + .format(epoch, loss, train_accuracy, validation_accuracy)) + + +class VanillaNN(BaseNN): + prefix = 'vanilla_' + + def train(self, train_iterator, validation_iterator, head, batch_size, epochs, priors=None): + trainer = Trainer(self.net.collect_params(), self.optimizer, dict(learning_rate=self.learning_rate)) + + num_examples = 0 + for epoch in range(epochs): + cumulative_loss = 0 + for i, batch in enumerate(train_iterator): + with mx.autograd.record(): + if self.single_head: + output = self.forward(batch.data[0].as_in_context(self.ctx)) + else: + output = self.forward(batch.data[0].as_in_context(self.ctx))[head] + labels = batch.label[0].as_in_context(self.ctx) + loss = self.loss(output, labels) + loss.backward() + trainer.step(batch_size=batch_size, ignore_stale_grad=True) + cumulative_loss += mx.nd.sum(loss).asscalar() + num_examples += len(labels) + + train_iterator.reset() + validation_iterator.reset() + train_accuracy = self.evaluate_accuracy(train_iterator, head=head) + validation_accuracy = self.evaluate_accuracy(validation_iterator, head=head) + self.print_status(epoch + 1, cumulative_loss / num_examples, train_accuracy, validation_accuracy) + + +class BayesianNN(BaseNN): + prefix = 'bayesian_' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.create_model() + + def create_model(self): + self.model = Model(verbose=self.verbose) + self.model.N = Variable() + self.model.f = MXFusionGluonFunction(self.net, num_outputs=self.num_heads, broadcastable=False) + self.model.x = Variable(shape=(self.model.N, self.network_shape[0])) + + if self.single_head: + self.model.r = self.model.f(self.model.x) + self.model.y = Categorical.define_variable( + log_prob=self.model.r, shape=(self.model.N, 1), num_classes=self.network_shape[-1]) + self.create_prior_variables(self.model.r) + else: + r = self.model.f(self.model.x) + for head, label_shape in enumerate(self.network_shape[-1]): + rh = r[head] if self.num_heads > 1 else r + setattr(self.model, 'r{}'.format(head), rh) + y = Categorical.define_variable(log_prob=rh, shape=(self.model.N, 1), num_classes=label_shape) + setattr(self.model, 'y{}'.format(head), y) + # TODO the statement below could probably be done only for the first head, since they all share the same + # factor parameters + self.create_prior_variables(rh) + + def create_prior_variables(self, r): + for v in r.factor.parameters.values(): + # First check that the variables haven't already been created (in multi-head case) + if getattr(self.model, v.inherited_name + "_mean", None) is not None: + continue + if getattr(self.model, v.inherited_name + "_variance", None) is not None: + continue + + means = Variable(shape=v.shape) + variances = Variable(shape=v.shape) + setattr(self.model, v.inherited_name + "_mean", means) + setattr(self.model, v.inherited_name + "_variance", variances) + v.set_prior(Normal(mean=means, variance=variances)) + + # noinspection PyUnresolvedReferences + def get_net_parameters(self, head): + if self.single_head: + r = self.model.r + else: + r = getattr(self.model, 'r{}'.format(head)) + return r.factor.parameters + + # noinspection PyUnresolvedReferences + def train(self, train_iterator, validation_iterator, head, batch_size, epochs, priors=None): + if self.single_head: + print("Running single-headed inference") + + dummy = mx.nd.flatten(mx.nd.zeros(shape=train_iterator.provide_data[0].shape)).as_in_context(self.ctx) + x_shape = dummy.shape + y_shape = train_iterator.provide_label[0].shape + + if self.verbose: + print("Data shape {}".format(x_shape)) + + # pass some data to initialise the net + self.net(dummy[:1]) + + observed = [self.model.x, self.model.y] + kwargs = {'x': x_shape, 'y': y_shape} + + q = create_Gaussian_meanfield(model=self.model, observed=observed) + alg = StochasticVariationalInference(num_samples=5, model=self.model, posterior=q, observed=observed) + + self.inference = GradIteratorBasedInference(inference_algorithm=alg) + self.inference.initialize(**kwargs) + + for v in self.get_net_parameters(head).values(): + v_name_mean = v.inherited_name + "_mean" + v_name_variance = v.inherited_name + "_variance" + + if priors is None or (v_name_mean not in priors and v_name_variance not in priors): + means = self.prior_mean(shape=v.shape) + variances = self.prior_variance(shape=v.shape) + elif isinstance(priors, ParameterDict): + # This is a maximum likelihood estimate + short_name = v.inherited_name.partition(self.prefix)[-1] + means = priors.get(short_name).data() + variances = self.prior_variance(shape=v.shape) + else: + # Use posteriors from previous round of inference + means = priors[v_name_mean] + variances = priors[v_name_variance] + + mean_prior = getattr(self.model, v_name_mean) + variance_prior = getattr(self.model, v_name_variance) + + # v.set_prior(Normal(mean=mean_prior, variance=variance_prior)) + + self.inference.params[mean_prior] = means + self.inference.params[variance_prior] = variances + + # Indicate that we don't want to perform inference over the priors + self.inference.params.param_dict[mean_prior]._grad_req = 'null' + self.inference.params.param_dict[variance_prior]._grad_req = 'null' + + self.inference.run(max_iter=self.max_iter, learning_rate=self.learning_rate, + verbose=False, callback=self.print_status, data=train_iterator) + + else: + print("Running multi-headed inference for head {}".format(head)) + + for i, batch in enumerate(train_iterator): + if i > 0: + raise NotImplementedError("Currently not supported for more than one batch of data. " + "Please switch to using the MinibatchInferenceLoop") + + data = mx.nd.flatten(batch.data[0]).as_in_context(self.ctx) + labels = mx.nd.expand_dims(batch.label[0], axis=-1).as_in_context(self.ctx) + + if self.verbose: + print("Data shape {}".format(data.shape)) + + # pass some data to initialise the net + self.net(data[:1]) + + if self.single_head: + observed = [self.model.x, self.model.y] + ignored = None + kwargs = dict(y=labels, x=data) + else: + observed = [self.model.x, getattr(self.model, "y{}".format(head))] + y_other = [getattr(self.model, "y{}".format(h)) for h in range(self.num_heads) if h != head] + r_other = [getattr(self.model, "r{}".format(h)) for h in range(self.num_heads) if h != head] + ignored = y_other + r_other + kwargs = {'x': data, 'y{}'.format(head): labels, 'ignored': dict((v.name, v) for v in ignored)} + + q = create_Gaussian_meanfield(model=self.model, ignored=ignored, observed=observed) + alg = StochasticVariationalInference(num_samples=5, model=self.model, posterior=q, observed=observed, + ignored=ignored) + + self.inference = GradBasedInference(inference_algorithm=alg) + self.inference.initialize(**kwargs) + + for v in self.get_net_parameters(head).values(): + v_name_mean = v.inherited_name + "_mean" + v_name_variance = v.inherited_name + "_variance" + + if priors is None or (v_name_mean not in priors and v_name_variance not in priors): + means = self.prior_mean(shape=v.shape) + variances = self.prior_variance(shape=v.shape) + elif isinstance(priors, ParameterDict): + # This is a maximum likelihood estimate + short_name = v.inherited_name.partition(self.prefix)[-1] + means = priors.get(short_name).data() + variances = self.prior_variance(shape=v.shape) + else: + # Use posteriors from previous round of inference + means = priors[v_name_mean] + variances = priors[v_name_variance] + + mean_prior = getattr(self.model, v_name_mean) + variance_prior = getattr(self.model, v_name_variance) + + # v.set_prior(Normal(mean=mean_prior, variance=variance_prior)) + + self.inference.params[mean_prior] = means + self.inference.params[variance_prior] = variances + + # Indicate that we don't want to perform inference over the priors + self.inference.params.param_dict[mean_prior]._grad_req = 'null' + self.inference.params.param_dict[variance_prior]._grad_req = 'null' + + self.inference.run(max_iter=self.max_iter, learning_rate=self.learning_rate, + verbose=False, callback=self.print_status, **kwargs) + + # noinspection PyUnresolvedReferences + @property + def posteriors(self): + q = self.inference.inference_algorithm.posterior + + # TODO: don't convert to numpy arrays + posteriors = dict() + if self.single_head: + for v_name, v in self.model.r.factor.parameters.items(): + posteriors[v.inherited_name + "_mean"] = self.inference.params[q[v.uuid].factor.mean].asnumpy() + posteriors[v.inherited_name + "_variance"] = self.inference.params[q[v.uuid].factor.variance].asnumpy() + else: + for head in range(self.num_heads): + for v in self.get_net_parameters(head).values(): + posteriors[v.inherited_name + "_mean"] = self.inference.params[q[v.uuid].factor.mean].asnumpy() + posteriors[v.inherited_name + "_variance"] = \ + self.inference.params[q[v.uuid].factor.variance].asnumpy() + # print("Head {}, variable {}, shape {}" + # .format(head, v.inherited_name, posteriors[v.inherited_name + '_mean'].shape)) + return posteriors + + # noinspection PyUnresolvedReferences + def prediction_prob(self, test_iter, head): + if self.inference is None: + raise RuntimeError("Model not yet learnt") + + predictions = [] + + for i, batch in enumerate(test_iter): + # if i > 0: + # raise NotImplementedError("Currently not supported for more than one batch of data. " + # "Please switch to using the MinibatchInferenceLoop") + + data = mx.nd.flatten(batch.data[0]).as_in_context(self.ctx) + + # pass some data to initialise the net + self.net(data[:1]) + + r = self.model.r if self.single_head else getattr(self.model, 'r{}'.format(head)) + y = self.model.y if self.single_head else getattr(self.model, 'y{}'.format(head)) + y_other = [getattr(self.model, "y{}".format(h)) for h in range(self.num_heads) if h != head] + r_other = [getattr(self.model, "r{}".format(h)) for h in range(self.num_heads) if h != head] + ignored = y_other + r_other + + if self.verbose: + print("Data shape {}".format(data.shape)) + + if len(ignored) > 0: + # Here we need to re-instantiate the ignored variables into the posterior if they don't exist + model = self.inference.inference_algorithm.model # .clone() + old_posterior = self.inference.inference_algorithm.posterior + new_posterior = old_posterior.clone(model=model) + + # Reattach the missing parts of the graph + new_posterior[r].set_prior(new_posterior[r.factor]) + new_posterior[r].factor.predecessors = [(k, new_posterior[v]) for k, v in r.factor.predecessors] + new_posterior[r].factor.successors = [(k, new_posterior[v]) for k, v in r.factor.successors] + new_posterior[y].set_prior(new_posterior[y.factor]) + new_posterior[y].factor.predecessors = [(k, new_posterior[v]) for k, v in y.factor.predecessors] + new_posterior[y].factor.successors = [(k, new_posterior[v]) for k, v in y.factor.successors] + + # Set the posterior to be the new posterior + self.inference.inference_algorithm._extra_graphs[0] = new_posterior + else: + old_posterior = None + + prediction_inference = VariationalPosteriorForwardSampling( + num_samples=10, observed=[self.model.x], + inherited_inference=self.inference, + target_variables=[r], + ignored=ignored # dict((v.name, v) for v in ignored) + ) + + res = prediction_inference.run(x=mx.nd.array(data)) + + if old_posterior is not None: + # Set the posterior back to the old posterior + self.inference.inference_algorithm._extra_graphs[0] = old_posterior + + predictions.append(res[0].asnumpy()) + return np.concatenate(predictions, axis=1) + + @staticmethod + def prior_mean(shape): + return mx.nd.zeros(shape=shape) + + @staticmethod + def prior_variance(shape): + return mx.nd.ones(shape=shape) * 3 diff --git a/examples/variational_continual_learning/variational_continual_learning.py b/examples/variational_continual_learning/variational_continual_learning.py new file mode 100644 index 0000000..d1040b8 --- /dev/null +++ b/examples/variational_continual_learning/variational_continual_learning.py @@ -0,0 +1,152 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +# ============================================================================== + +import numpy as np +import mxnet as mx + +import matplotlib.pyplot as plt +from datetime import datetime +import fire + +from examples.variational_continual_learning.experiment import Experiment +from examples.variational_continual_learning.mnist import SplitTaskGenerator, PermutedTaskGenerator +from examples.variational_continual_learning.coresets import Random, KCenter, Vanilla + +import logging +logging.getLogger().setLevel(logging.DEBUG) # logging to stdout + +# Set the compute context, GPU is available otherwise CPU +CTX = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu() +mx.context.Context.default_ctx = CTX + + +def set_seeds(seed=42): + mx.random.seed(seed) + np.random.seed(seed) + + +class VCLRunner: + """ + Entry point for variational continual learning examples + """ + @staticmethod + def run(task='split', learning_rate=0.01, optimizer='adam', num_epochs=120, num_tasks=None): + # Load data + data = mx.test_utils.get_mnist() + input_dim = int(np.prod(data['train_data'][0].shape)) # Note the data will get flattened later + verbose = False + + # noinspection PyUnreachableCode + if task.lower() == 'split': + title = "Split MNIST" + tasks = ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9)) + tasks = tasks[:num_tasks] + # num_epochs = 120 + # num_epochs = 1 # 120 + batch_size = None + gen = SplitTaskGenerator + label_shape = 2 + network_shape = (input_dim, 256, 256, (label_shape,)) + single_head = False + coreset_size = 40 + elif task.lower() == 'permuted': + title = "Permuted MNIST" + tasks = range(10) + tasks = tasks[:num_tasks] + # num_epochs = 100 + # num_epochs = 1 + batch_size = 256 + gen = PermutedTaskGenerator + label_shape = 10 + network_shape = (input_dim, 100, 100, label_shape) + single_head = True + coreset_size = 200 + else: + raise ValueError("Unknown task type {}, possibilities are 'split' or 'permuted'".format(task)) + + experiment_parameters = ( + dict( + coreset=Vanilla(), + learning_rate=learning_rate, + optimizer=optimizer, + network_shape=network_shape, + num_epochs=num_epochs, + single_head=single_head), + dict( + coreset=Random(coreset_size=coreset_size), + learning_rate=learning_rate, + optimizer=optimizer, + network_shape=network_shape, + num_epochs=num_epochs, + single_head=single_head), + dict( + coreset=KCenter(coreset_size=coreset_size), + learning_rate=learning_rate, + optimizer=optimizer, + network_shape=network_shape, + num_epochs=num_epochs, + single_head=single_head) + ) + + experiments = [] + + print("Task {}\nLearning rate {}\nOptimizer {}\nnumber of epochs {}\nnumber of tasks {}".format( + title, learning_rate, optimizer, num_epochs, num_tasks + )) + + # Run experiments + for params in experiment_parameters: + print("-" * 50) + print("Running experiment", params['coreset'].__class__.__name__) + print("-" * 50) + set_seeds() + experiment = Experiment(batch_size=batch_size, + data_generator=gen(data, batch_size=batch_size, tasks=tasks), + ctx=CTX, verbose=verbose, + **params) + experiment.run() + print(experiment.overall_accuracy) + experiments.append(experiment) + print("-" * 50) + print() + + VCLRunner.plot(title, experiments, len(tasks)) + + @staticmethod + def plot(title, experiments, num_tasks): + fig = plt.figure(figsize=(num_tasks, 3)) + ax = plt.gca() + + x = range(1, num_tasks + 1) + + for experiment in experiments: + acc = np.nanmean(experiment.overall_accuracy, axis=1) + label = experiment.coreset.__class__.__name__ + plt.plot(x, acc, label=label, marker='o') + ax.set_xticks(x) + ax.set_ylabel('Average accuracy') + ax.set_xlabel('# tasks') + ax.legend() + ax.set_title(title) + + filename = "vcl_{}_{}.pdf".format(title, datetime.now().isoformat()[:-7]) + fig.savefig(filename, bbox_inches='tight') + plt.show() + plt.close() + + +if __name__ == "__main__": + import warnings + warnings.filterwarnings("ignore", category=UserWarning) + fire.Fire(VCLRunner.run) diff --git a/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:14:49.pdf b/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:14:49.pdf new file mode 100644 index 0000000..ad7097f Binary files /dev/null and b/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:14:49.pdf differ diff --git a/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:23:14.pdf b/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:23:14.pdf new file mode 100644 index 0000000..8063e13 Binary files /dev/null and b/examples/variational_continual_learning/vcl_Split MNIST_2019-02-23T00:23:14.pdf differ diff --git a/mxfusion/components/functions/mxfusion_gluon_function.py b/mxfusion/components/functions/mxfusion_gluon_function.py index 021e398..40cdf69 100644 --- a/mxfusion/components/functions/mxfusion_gluon_function.py +++ b/mxfusion/components/functions/mxfusion_gluon_function.py @@ -207,7 +207,8 @@ def _override_block_parameters(self, input_kws): ctx = val.context ctx_list = param._ctx_map[ctx.device_typeid&1] if ctx.device_id >= len(ctx_list) or ctx_list[ctx.device_id] is None: - raise Exception + raise ValueError("Context id {} out of range {}".format( + ctx.device_id, list(map(str, ctx_list)))) dev_id = ctx_list[ctx.device_id] param._data[dev_id] = val else: diff --git a/mxfusion/inference/__init__.py b/mxfusion/inference/__init__.py index e89b8f3..3adf37f 100644 --- a/mxfusion/inference/__init__.py +++ b/mxfusion/inference/__init__.py @@ -41,7 +41,7 @@ from .minibatch_loop import MinibatchInferenceLoop from .meanfield import create_Gaussian_meanfield from .forward_sampling import ForwardSampling, VariationalPosteriorForwardSampling, ForwardSamplingAlgorithm -from .grad_based_inference import GradBasedInference +from .grad_based_inference import GradBasedInference, GradIteratorBasedInference from .variational import StochasticVariationalInference from .inference_parameters import InferenceParameters from .score_function import ScoreFunctionInference, ScoreFunctionRBInference diff --git a/mxfusion/inference/batch_loop.py b/mxfusion/inference/batch_loop.py index e03710f..67d215c 100644 --- a/mxfusion/inference/batch_loop.py +++ b/mxfusion/inference/batch_loop.py @@ -23,7 +23,7 @@ class BatchInferenceLoop(GradLoop): """ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', - learning_rate=1e-3, max_iter=1000, n_prints=10, verbose=False): + learning_rate=1e-3, max_iter=1000, n_prints=10, verbose=False, callback=None): """ :param infr_executor: The MXNet function that computes the training objective. :type infr_executor: MXNet Gluon Block @@ -41,6 +41,10 @@ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', :type n_prints: int :param max_iter: the maximum number of iterations of gradient optimization :type max_iter: int + :param n_prints: number of times to print status + :type n_prints: int + :param callback: Callback function for custom print statements + :type callback: func :param verbose: whether to print per-iteration messages. :type verbose: boolean """ @@ -54,9 +58,15 @@ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', loss, loss_for_gradient = infr_executor(mx.nd.zeros(1, ctx=ctx), *data) loss_for_gradient.backward() if verbose: - print('\rIteration {} loss: {}'.format(i + 1, loss.asscalar()), + print('\rIteration {} loss: {}'.format(i, loss.asscalar()), end='') if i % iter_step == 0 and i > 0: print() + if callback is not None: + callback(i, loss.asscalar()) trainer.step(batch_size=1, ignore_stale_grad=True) - loss = infr_executor(mx.nd.zeros(1, ctx=ctx), *data) + loss, _ = infr_executor(mx.nd.zeros(1, ctx=ctx), *data) + if verbose: + print('\rIteration {} loss: {}'.format(max_iter, loss.asscalar()), end='') + if callback is not None: + callback(max_iter, loss.asscalar()) diff --git a/mxfusion/inference/forward_sampling.py b/mxfusion/inference/forward_sampling.py index 9ddbb22..638b143 100644 --- a/mxfusion/inference/forward_sampling.py +++ b/mxfusion/inference/forward_sampling.py @@ -50,8 +50,11 @@ def compute(self, F, variables): :rtype: mxnet.ndarray.ndarray.NDArray or mxnet.symbol.symbol.Symbol """ samples = self.model.draw_samples( - F=F, variables=variables, targets=self.target_variables, - num_samples=self.num_samples) + F=F, + variables=variables, # dict((k, v) for k, v in variables.items() if k not in self.ignored_variables), + targets=self.target_variables, + num_samples=self.num_samples, + ignored=self.ignored_variables) return samples @@ -80,24 +83,27 @@ class ForwardSampling(TransferInference): :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'} :param context: The MXNet context :type context: {mxnet.cpu or mxnet.gpu} + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ def __init__(self, num_samples, model, observed, var_tie, infr_params, target_variables=None, hybridize=False, constants=None, - dtype=None, context=None): + dtype=None, context=None, ignored=None): if target_variables is not None: target_variables = [v.uuid for v in target_variables if isinstance(v, Variable)] infr = ForwardSamplingAlgorithm( num_samples=num_samples, model=model, observed=observed, - target_variables=target_variables) + target_variables=target_variables, ignored=ignored) super(ForwardSampling, self).__init__( inference_algorithm=infr, var_tie=var_tie, infr_params=infr_params, constants=constants, hybridize=hybridize, dtype=dtype, context=context) -def merge_posterior_into_model(model, posterior, observed): +def merge_posterior_into_model(model, posterior, observed, ignored=None): """ Replace the prior distributions of a model with its variational posterior distributions. @@ -107,9 +113,16 @@ def merge_posterior_into_model(model, posterior, observed): :param posterior: Posterior :param observed: A list of observed variables :type observed: [Variable] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ + ignored = ignored or set() new_model = model.clone() for lv in model.get_latent_variables(observed): + # Test if lv is in ignored + if lv in ignored: + continue v = posterior.extract_distribution_of(posterior[lv]) new_model.replace_subgraph(new_model[v], v) return new_model @@ -135,10 +148,13 @@ class VariationalPosteriorForwardSampling(ForwardSampling): :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'} :param context: The MXNet context :type context: {mxnet.cpu or mxnet.gpu} + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ def __init__(self, num_samples, observed, inherited_inference, target_variables=None, - hybridize=False, constants=None, dtype=None, context=None): + hybridize=False, constants=None, dtype=None, context=None, ignored=None): if not isinstance(inherited_inference.inference_algorithm, (StochasticVariationalInference, MAP)): raise InferenceError('inherited_inference needs to be a subclass of SVIInference or SVIMiniBatchInference.') @@ -147,11 +163,52 @@ def __init__(self, num_samples, observed, q = inherited_inference.inference_algorithm.posterior model_graph = merge_posterior_into_model( - m, q, observed=inherited_inference.observed_variables) + m, q, observed=inherited_inference.observed_variables, ignored=ignored) super(VariationalPosteriorForwardSampling, self).__init__( num_samples=num_samples, model=model_graph, observed=observed, var_tie={}, infr_params=inherited_inference.params, target_variables=target_variables, hybridize=hybridize, - constants=constants, dtype=dtype, context=context) + constants=constants, dtype=dtype, context=context, ignored=ignored) + + +class VariationalPosteriorForwardSampling2(ForwardSampling): + """ + The forward sampling method for variational inference. + + :param num_samples: the number of samples used in estimating the variational lower bound + :type num_samples: int + :param observed: A list of observed variables + :type observed: [Variable] + :param inherited_inference: the inference method of which the model and inference results are taken + :type inherited_inference: SVIInference or SVIMiniBatchInference + :param target_variables: (optional) the target variables to sample + :type target_variables: [Variable] + :param constants: Specify a list of model variables as constants + :type constants: {Variable: mxnet.ndarray} + :param hybridize: Whether to hybridize the MXNet Gluon block of the inference method. + :type hybridize: boolean + :param dtype: data type for internal numerical representation + :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'} + :param context: The MXNet context + :type context: {mxnet.cpu or mxnet.gpu} + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] + """ + def __init__(self, num_samples, observed, + inherited_algorithm, inherited_params, inherited_model, inherited_posterior, target_variables=None, + hybridize=False, constants=None, dtype=None, context=None, ignored=None): + if not isinstance(inherited_algorithm, (StochasticVariationalInference, MAP)): + raise InferenceError('inherited_inference needs to be a subclass of SVIInference or SVIMiniBatchInference.') + + model_graph = merge_posterior_into_model( + inherited_model, inherited_posterior, observed=observed, ignored=ignored) + + super().__init__( + num_samples=num_samples, model=model_graph, + observed=observed, + var_tie={}, infr_params=inherited_params, + target_variables=target_variables, hybridize=hybridize, + constants=constants, dtype=dtype, context=context, ignored=ignored) diff --git a/mxfusion/inference/grad_based_inference.py b/mxfusion/inference/grad_based_inference.py index e6ec9d5..715186e 100644 --- a/mxfusion/inference/grad_based_inference.py +++ b/mxfusion/inference/grad_based_inference.py @@ -15,6 +15,7 @@ from .inference import Inference from .batch_loop import BatchInferenceLoop +from .minibatch_loop import MinibatchInferenceLoop class GradBasedInference(Inference): @@ -63,7 +64,7 @@ def create_executor(self): return infr def run(self, optimizer='adam', learning_rate=1e-3, max_iter=2000, - verbose=False, **kwargs): + verbose=False, callback=None, **kwargs): """ Run the inference method. @@ -75,6 +76,8 @@ def run(self, optimizer='adam', learning_rate=1e-3, max_iter=2000, :type max_iter: int :param verbose: whether to print per-iteration messages. :type verbose: boolean + :param callback: Callback function for custom print statements + :type callback: func :param kwargs: The keyword arguments specify the data for inferences. The key of each argument is the name of the corresponding variable in model definition and the value of the argument is the data in numpy array format. """ @@ -85,4 +88,61 @@ def run(self, optimizer='adam', learning_rate=1e-3, max_iter=2000, return self._grad_loop.run( infr_executor=infr, data=data, param_dict=self.params.param_dict, ctx=self.mxnet_context, optimizer=optimizer, - learning_rate=learning_rate, max_iter=max_iter, verbose=verbose) + learning_rate=learning_rate, max_iter=max_iter, verbose=verbose, callback=callback) + + +class GradIteratorBasedInference(Inference): + """ + An inference method consists of a few components: the applied inference algorithm, the model definition + (optionally a definition of posterior approximation), the inference parameters. + + :param inference_algorithm: The applied inference algorithm + :type inference_algorithm: InferenceAlgorithm + :param grad_loop: The reference to the main loop of gradient optimization + :type grad_loop: GradLoop + :param constants: Specify a list of model variables as constants + :type constants: {Variable: mxnet.ndarray} + :param hybridize: Whether to hybridize the MXNet Gluon block of the inference method. + :type hybridize: boolean + :param dtype: data type for internal numerical representation + :type dtype: {numpy.float64, numpy.float32, 'float64', 'float32'} + :param context: The MXNet context + :type context: {mxnet.cpu or mxnet.gpu} + """ + def __init__(self, inference_algorithm, grad_loop=None, constants=None, + hybridize=False, dtype=None, context=None): + if grad_loop is None: + grad_loop = MinibatchInferenceLoop() + super().__init__( + inference_algorithm=inference_algorithm, constants=constants, + hybridize=hybridize, dtype=dtype, context=context) + self._grad_loop = grad_loop + + def run(self, data, optimizer='adam', learning_rate=1e-3, max_iter=2000, verbose=False, callback=None, **kwargs): + """ + Run the inference method. + + :param optimizer: the choice of optimizer (default: 'adam') + :type optimizer: str + :param learning_rate: the learning rate of the gradient optimizer (default: 0.001) + :type learning_rate: float + :param max_iter: the maximum number of iterations of gradient optimization + :type max_iter: int + :param verbose: whether to print per-iteration messages. + :type verbose: boolean + :param callback: Callback function for custom print statements + :type callback: func + :param kwargs: The keyword arguments specify the data for inferences. The key of each argument is the name of + the corresponding variable in model definition and the value of the argument is the data in numpy array format. + """ + # data = [kwargs[v] for v in self.observed_variable_names] + + if not self._initialized: + raise ValueError("This inference method must be manually initialised, since we don't know the shapes" + "ahead of time.") + + infr = self.create_executor() + return self._grad_loop.run( + infr_executor=infr, data=data, param_dict=self.params.param_dict, + ctx=self.mxnet_context, optimizer=optimizer, + learning_rate=learning_rate, max_iter=max_iter, verbose=verbose, callback=callback) diff --git a/mxfusion/inference/inference_alg.py b/mxfusion/inference/inference_alg.py index 88416bd..fd0a150 100644 --- a/mxfusion/inference/inference_alg.py +++ b/mxfusion/inference/inference_alg.py @@ -108,7 +108,7 @@ def replicate_self(self, model, extra_graphs=None): replicant._observed_names = [v.name for v in observed] return replicant - def __init__(self, model, observed, extra_graphs=None): + def __init__(self, model, observed, extra_graphs=None, ignored=None): """ Initialize the algorithm @@ -118,6 +118,9 @@ def __init__(self, model, observed, extra_graphs=None): :type observed: [Variable] :param extra_graphs: a list of extra FactorGraph used in the inference algorithm. :type extra_graphs: [FactorGraph] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ self._model_graph = model self._extra_graphs = extra_graphs if extra_graphs is not None else [] @@ -126,6 +129,10 @@ def __init__(self, model, observed, extra_graphs=None): self._observed = set(observed) self._observed_uuid = variables_to_UUID(observed) self._observed_names = [v.name for v in observed] + ignored = ignored or [] + self._ignored = set(ignored) + self._ignored_uuid = variables_to_UUID(ignored) + self._ignored_names = [v.name for v in ignored] @property def observed_variables(self): @@ -148,6 +155,28 @@ def observed_variable_names(self): """ return self._observed_names + @property + def ignored_variables(self): + """ + The ignored variables in this inference algorithm. + """ + return self._ignored + + @property + def ignored_variable_UUIDs(self): + """ + The UUIDs of the ignored variables in this inference algorithm. + """ + return self._ignored_uuid + + @property + def ignored_variable_names(self): + """ + The names (if exist) of the ignored variables in this inference algorithm. + """ + return self._ignored_names + + @property def model(self): """ @@ -267,12 +296,15 @@ class SamplingAlgorithm(InferenceAlgorithm): :type target_variables: [UUID] :param extra_graphs: a list of extra FactorGraph used in the inference algorithm. :type extra_graphs: [FactorGraph] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ def __init__(self, model, observed, num_samples=1, target_variables=None, - extra_graphs=None): + extra_graphs=None, ignored=None): super(SamplingAlgorithm, self).__init__( - model=model, observed=observed, extra_graphs=extra_graphs) + model=model, observed=observed, extra_graphs=extra_graphs, ignored=ignored) self.num_samples = num_samples self.target_variables = target_variables diff --git a/mxfusion/inference/meanfield.py b/mxfusion/inference/meanfield.py index 68e2350..02f8337 100644 --- a/mxfusion/inference/meanfield.py +++ b/mxfusion/inference/meanfield.py @@ -21,24 +21,35 @@ from ..common.config import get_default_dtype -def create_Gaussian_meanfield(model, observed, dtype=None): +def create_Gaussian_meanfield(model, observed, ignored=None, dtype=None): """ - Create the Meanfield posterior for Variational Inference. + Create the mean-field posterior for Variational Inference. :param model: the definition of the probabilistic model :type model: Model :param observed: A list of observed variables :type observed: [Variable] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] :returns: the resulting posterior representation + :param dtype: Data type of the random variable (float32 or float64) :rtype: Posterior """ dtype = get_default_dtype() if dtype is None else dtype observed = variables_to_UUID(observed) + ignored = variables_to_UUID(ignored or []) q = Posterior(model) for v in model.variables.values(): - if v.type == VariableType.RANDVAR and v not in observed: + if v.type == VariableType.RANDVAR and v not in observed and v not in ignored: mean = Variable(shape=v.shape) variance = Variable(shape=v.shape, transformation=PositiveTransformation()) - q[v].set_prior(Normal(mean=mean, variance=variance, dtype=dtype)) + prior = Normal(mean=mean, variance=variance, dtype=dtype) + q[v].set_prior(prior) + + # setting a name for the priors so that cloning posteriors works + if not v.name and v.inherited_name: + setattr(q, "{}_prior".format(v.inherited_name), prior) + setattr(q, v.inherited_name, q[v]) return q diff --git a/mxfusion/inference/minibatch_loop.py b/mxfusion/inference/minibatch_loop.py index e196824..911b407 100644 --- a/mxfusion/inference/minibatch_loop.py +++ b/mxfusion/inference/minibatch_loop.py @@ -40,7 +40,7 @@ def __init__(self, batch_size=100, rv_scaling=None): if rv_scaling is not None else rv_scaling def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', - learning_rate=1e-3, max_iter=1000, verbose=False): + learning_rate=1e-3, max_iter=1000, verbose=False, callback=None): """ :param infr_executor: The MXNet function that computes the training objective. :type infr_executor: MXNet Gluon Block @@ -56,12 +56,16 @@ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', :type learning_rate: float :param max_iter: the maximum number of iterations of gradient optimization :type max_iter: int + :param callback: Callback function for custom print statements + :type callback: func :param verbose: whether to print per-iteration messages. :type verbose: boolean """ if isinstance(data, mx.gluon.data.DataLoader): data_loader = data + elif isinstance(data, mx.io.DataIter): + data_loader = data else: data_loader = mx.gluon.data.DataLoader( ArrayDataset(*data), batch_size=self.batch_size, shuffle=True, @@ -75,15 +79,21 @@ def run(self, infr_executor, data, param_dict, ctx, optimizer='adam', n_batches = 0 for i, data_batch in enumerate(data_loader): with mx.autograd.record(): + if isinstance(data_batch, mx.io.DataBatch): + data_batch = (data_batch.data[0], data_batch.label[0]) loss, loss_for_gradient = infr_executor(mx.nd.zeros(1, ctx=ctx), *data_batch) loss_for_gradient.backward() if verbose: print('\repoch {} Iteration {} loss: {}\t\t\t'.format( e + 1, i + 1, loss.asscalar()), end='') + if callback is not None: + callback(i, loss.asscalar()) trainer.step(batch_size=self.batch_size, ignore_stale_grad=True) L_e += loss.asscalar() n_batches += 1 if verbose: print('epoch-loss: {} '.format(L_e / n_batches)) + if callback is not None: + callback(e, L_e) diff --git a/mxfusion/inference/variational.py b/mxfusion/inference/variational.py index eaaa0ab..22c12c7 100644 --- a/mxfusion/inference/variational.py +++ b/mxfusion/inference/variational.py @@ -26,11 +26,14 @@ class VariationalInference(InferenceAlgorithm): :param posterior: Posterior :param observed: A list of observed variables :type observed: [Variable] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ - def __init__(self, model, posterior, observed): + def __init__(self, model, posterior, observed, ignored=None): super(VariationalInference, self).__init__( - model=model, observed=observed, extra_graphs=[posterior]) + model=model, observed=observed, extra_graphs=[posterior], ignored=ignored) @property def posterior(self): @@ -82,10 +85,13 @@ class StochasticVariationalInference(VariationalInference): :param posterior: Posterior :param observed: A list of observed variables :type observed: [Variable] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] """ - def __init__(self, num_samples, model, posterior, observed): + def __init__(self, num_samples, model, posterior, observed, ignored=None): super(StochasticVariationalInference, self).__init__( - model=model, posterior=posterior, observed=observed) + model=model, posterior=posterior, observed=observed, ignored=ignored) self.num_samples = num_samples def compute(self, F, variables): @@ -103,6 +109,12 @@ def compute(self, F, variables): samples = self.posterior.draw_samples( F=F, variables=variables, num_samples=self.num_samples) variables.update(samples) - logL = self.model.log_pdf(F=F, variables=variables) + + if self.ignored_variables: + targets = self.posterior.variables + else: + targets = None + + logL = self.model.log_pdf(F=F, variables=variables, targets=targets) logL = logL - self.posterior.log_pdf(F=F, variables=variables) return -logL, -logL diff --git a/mxfusion/models/factor_graph.py b/mxfusion/models/factor_graph.py index f4e9ff8..6386520 100644 --- a/mxfusion/models/factor_graph.py +++ b/mxfusion/models/factor_graph.py @@ -237,7 +237,7 @@ def log_pdf(self, F, variables, targets=None): "That shouldn't happen.") return logL - def draw_samples(self, F, variables, num_samples=1, targets=None): + def draw_samples(self, F, variables, num_samples=1, targets=None, ignored=None): """ Draw samples from the target variables of the Factor Graph. If the ``targets`` argument is None, draw samples from all the variables that are *not* in the conditional variables. If the ``targets`` argument is given, @@ -251,9 +251,13 @@ def draw_samples(self, F, variables, num_samples=1, targets=None): :type num_samples: int :param targets: a list of Variables to draw samples from. :type targets: [UUID] + :param ignored: A list of ignored variables. + These are variables that are not observed, but also will not be inferred + :type ignored: [Variable] :returns: the samples of the target variables. :rtype: (MXNet NDArray or MXNet Symbol,) or {str(UUID): MXNet NDArray or MXNet Symbol} """ + ignored = ignored or () samples = {} for f in self.ordered_factors: if isinstance(f, FunctionEvaluation): @@ -274,6 +278,8 @@ def draw_samples(self, F, variables, num_samples=1, targets=None): elif any(known): raise InferenceError("Part of the outputs of the distribution " + f.__class__.__name__ + " has been observed!") + if any(v in ignored for (_, v) in f.outputs): + continue outcome_uuid = [v.uuid for _, v in f.outputs] outcome = f.draw_samples( F=F, num_samples=num_samples, variables=variables, always_return_tuple=True) @@ -444,7 +450,7 @@ def _clone(self, new_model, leaves=None): new_leaf = v.replicate(var_map=var_map, replication_function=lambda x: ('recursive', 'recursive')) setattr(new_model, v.name, new_leaf) else: - v.graph = new_model.graph + v.components_graph = new_model.components_graph for v in self.variables.values(): if v.name is not None: setattr(new_model, v.name, new_model[v.uuid]) diff --git a/mxfusion/util/inference.py b/mxfusion/util/inference.py index 02f7601..79437a2 100644 --- a/mxfusion/util/inference.py +++ b/mxfusion/util/inference.py @@ -19,7 +19,10 @@ def broadcast_samples_dict(F, array_dict, num_samples=None): """ - Broadcast the shape of arrays in the provided dictionary. When the num_samples argument is given, all the sample dimesnions (the first dimension) of the arrays in the dictionary will be broadcasted to the size of num_samples. If the num_samples argument is not given, the sample dimensions of the arrays in the dictionary will be broadcasted to the maximum number of the sizes of the sample dimensions. + Broadcast the shape of arrays in the provided dictionary. When the num_samples argument is given, all the sample + dimensions (the first dimension) of the arrays in the dictionary will be broadcasted to the size of num_samples. + If the num_samples argument is not given, the sample dimensions of the arrays in the dictionary will be broadcasted + to the maximum number of the sizes of the sample dimensions. :param F: the execution mode of MXNet. :type F: mxnet.ndarray or mxnet.symbol @@ -65,8 +68,7 @@ def discover_shape_constants(data_shapes, graphs): variables in the model and inference models. :param data_shapes: a dict of shapes of data - :param graphs: a list of factor graphs of which variable shapes are - searched. + :param graphs: a list of factor graphs of which variable shapes are searched. :returns: a dict of constants discovered from data shapes :rtype: {Variable: int} """ @@ -79,11 +81,15 @@ def discover_shape_constants(data_shapes, graphs): for s1, s2 in zip(def_shape, shape): if isinstance(s1, int): if s1 != s2: - raise ModelSpecificationError("Variable ({}) shape mismatch between expected and found! s1 : {} s2 : {}".format(str(variables[var_id]),str(s1), str(s2))) + raise ModelSpecificationError( + "Variable ({}) shape mismatch between expected and found! s1 : {} s2 : {}" + .format(str(variables[var_id]),str(s1), str(s2))) elif isinstance(s1, Variable): shape_constants[s1] = s2 else: - raise ModelSpecificationError("The shape of a Variable should either an integer or a Variable, but encountered {}!".format(str(type(s1)))) + raise ModelSpecificationError( + "The shape of a Variable should either an integer or a Variable, but encountered {}!" + .format(str(type(s1)))) return shape_constants