Skip to content
This repository was archived by the owner on Jun 14, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ac2e47e
Starting VCL notebook
tdiethe Nov 12, 2018
9ce178b
Simplified NN code
tdiethe Nov 12, 2018
de7f7ec
Start of example code for VCL
tdiethe Nov 12, 2018
19eb7b6
Merge branch 'develop' of github.com:amzn/MXFusion into vcl
tdiethe Nov 12, 2018
97a48d8
Refactored code out of the main function into classes
tdiethe Nov 20, 2018
19b79e8
Merge branch 'develop' of github.com:amzn/MXFusion into vcl
tdiethe Nov 20, 2018
2df2d48
Fixed some modelling bugs
tdiethe Nov 20, 2018
3f3170f
Fixed provide_data and provide_label in multiiterator
tdiethe Nov 20, 2018
52267c3
Added Permuted MNIST generator
tdiethe Nov 21, 2018
672a5b2
Created Task class
tdiethe Nov 21, 2018
46b2223
Attempt to fix Sequential model for multi-head
tdiethe Nov 22, 2018
580b447
Separated out the MLP model into a separate class/module
tdiethe Nov 23, 2018
c20a9ed
Bug fix in function evaluation
tdiethe Nov 23, 2018
9c0c023
Merge branch 'develop' of github.com:amzn/MXFusion into vcl
tdiethe Nov 23, 2018
6110fad
Added some extra print statements
tdiethe Nov 23, 2018
e7a3ac6
Changed transpose property to function. This stops the debugger accid…
tdiethe Nov 23, 2018
bb7af81
Merge branch 'develop' of github.com:amzn/MXFusion into vcl
tdiethe Nov 26, 2018
ec39612
Experimenting with multi-head methods
tdiethe Nov 26, 2018
a25646d
Added version of grad based inference for DataLoader objects
tdiethe Dec 18, 2018
094c63a
Merge branch 'develop' of github.com:amzn/MXFusion into examples/vcl
tdiethe Feb 14, 2019
32c66e2
Changing print statements for 3.4 compatability
tdiethe Feb 21, 2019
33a1576
Fixed bug in factor graph
tdiethe Feb 21, 2019
d2cac26
Fixes to ignored parts of the graph
tdiethe Feb 22, 2019
ab266c3
Support for ignored variables. These are variables that are not obser…
tdiethe Feb 22, 2019
8e1878f
tidying
tdiethe Feb 22, 2019
eea6107
Merge branch 'develop' of github.com:amzn/MXFusion into examples/vcl
tdiethe Feb 22, 2019
e8872fa
more useful error message
tdiethe Feb 22, 2019
d15373b
edit to error message
tdiethe Feb 22, 2019
911795e
added figures from run on p3
tdiethe Feb 23, 2019
4fbcbfb
Fix to GPU error
tdiethe Feb 25, 2019
d537377
Added python-fire for simple command line argument parsing (see https…
tdiethe Feb 25, 2019
e220d93
Merge branch 'develop' of github.com:amzn/MXFusion into examples/vcl
tdiethe Feb 26, 2019
526aae3
Fix regressions after merge
tdiethe Feb 26, 2019
56565df
Fix bug in plotting
tdiethe Feb 26, 2019
422a7dc
Fix off-by-one error in printing
tdiethe Feb 26, 2019
17f6841
Fix in the case of ignored==None
tdiethe Feb 26, 2019
d73ebad
Fixing coreset merging
tdiethe Mar 28, 2019
b72e396
Adaptation of minibatch loop to work with mx.io.DataIter and mx.io.Da…
tdiethe Mar 28, 2019
565c61c
Some tidying
tdiethe Mar 28, 2019
1c9fc60
Merge branch 'develop' of github.com:amzn/MXFusion into examples/vcl
tdiethe Mar 28, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 308 additions & 0 deletions examples/notebooks/Variational Continual Learning.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import gzip\n",
"import sys\n",
"\n",
"import mxfusion as mf\n",
"import mxnet as mx\n",
"\n",
"import logging\n",
"logging.getLogger().setLevel(logging.DEBUG) # logging to stdout\n",
"\n",
"# Set the compute context, GPU is available otherwise CPU\n",
"ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SplitMnistGenerator:\n",
"    \"\"\" Yields five binary MNIST tasks: (0 vs 1), (2 vs 3), ... (8 vs 9).\n",
"\n",
"    Each iteration yields a (train_iter, test_iter) pair of mx.io.NDArrayIter\n",
"    objects; batch_size=None means full-batch iterators.\n",
"    \"\"\"\n",
"    def __init__(self, data, batch_size):\n",
"        self.data = data\n",
"        self.batch_size = batch_size\n",
"\n",
"    def __iter__(self):\n",
"        for task in range(5):\n",
"            even, odd = 2 * task, 2 * task + 1\n",
"            tr0 = np.where(self.data['train_label'] == even)[0]\n",
"            tr1 = np.where(self.data['train_label'] == odd)[0]\n",
"            te0 = np.where(self.data['test_label'] == even)[0]\n",
"            te1 = np.where(self.data['test_label'] == odd)[0]\n",
"\n",
"            # Even digit gets label +1, odd digit gets label -1.\n",
"            x_train = np.vstack((self.data['train_data'][tr0], self.data['train_data'][tr1]))\n",
"            y_train = np.vstack((np.ones((tr0.shape[0], 1)), -np.ones((tr1.shape[0], 1))))\n",
"\n",
"            x_test = np.vstack((self.data['test_data'][te0], self.data['test_data'][te1]))\n",
"            y_test = np.vstack((np.ones((te0.shape[0], 1)), -np.ones((te1.shape[0], 1))))\n",
"\n",
"            bs = x_train.shape[0] if self.batch_size is None else self.batch_size\n",
"            train_iter = mx.io.NDArrayIter(x_train, y_train, bs, shuffle=True)\n",
"\n",
"            bs = x_test.shape[0] if self.batch_size is None else self.batch_size\n",
"            test_iter = mx.io.NDArrayIter(x_test, y_test, bs)\n",
"\n",
"            yield train_iter, test_iter\n",
"\n",
"mnist = mx.test_utils.get_mnist()\n",
"in_dim = np.prod(mnist['train_data'][0].shape)\n",
"\n",
"gen = SplitMnistGenerator(mnist, batch_size=None)\n",
"for task_id, (train, test) in enumerate(gen):\n",
"    print(\"Task\", task_id)\n",
"    print(\"Train data shape\", train.data[0][1].shape)\n",
"    print(\"Train label shape\", train.label[0][1].shape)\n",
"    print(\"Test data shape\", test.data[0][1].shape)\n",
"    print(\"Test label shape\", test.label[0][1].shape)\n",
"    print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rand_from_batch(x_coreset, y_coreset, x_train, y_train, coreset_size):\n",
"    \"\"\" Random coreset selection.\n",
"\n",
"    Draws coreset_size rows uniformly without replacement from\n",
"    (x_train, y_train), appends them to the coreset lists and removes\n",
"    them from the training arrays.\n",
"    \"\"\"\n",
"    idx = np.random.choice(x_train.shape[0], coreset_size, False)\n",
"    x_coreset.append(x_train[idx, :])\n",
"    y_coreset.append(y_train[idx, :])\n",
"    x_train = np.delete(x_train, idx, axis=0)\n",
"    y_train = np.delete(y_train, idx, axis=0)\n",
"    return x_coreset, y_coreset, x_train, y_train\n",
"\n",
"def k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):\n",
"    \"\"\" K-center (greedy farthest-point) coreset selection.\n",
"\n",
"    Repeatedly picks the point farthest from the centers chosen so far,\n",
"    then moves the selected rows from the training set to the coreset.\n",
"    \"\"\"\n",
"    dists = np.full(x_train.shape[0], np.inf)\n",
"    current_id = 0\n",
"    dists = update_distance(dists, x_train, current_id)\n",
"    idx = [current_id]\n",
"\n",
"    for _ in range(1, coreset_size):\n",
"        current_id = np.argmax(dists)\n",
"        dists = update_distance(dists, x_train, current_id)\n",
"        idx.append(current_id)\n",
"\n",
"    x_coreset.append(x_train[idx, :])\n",
"    y_coreset.append(y_train[idx, :])\n",
"    x_train = np.delete(x_train, idx, axis=0)\n",
"    y_train = np.delete(y_train, idx, axis=0)\n",
"    return x_coreset, y_coreset, x_train, y_train\n",
"\n",
"def update_distance(dists, x_train, current_id):\n",
"    \"\"\" Clip each entry of dists to the Euclidean distance from that row\n",
"    of x_train to x_train[current_id].\n",
"\n",
"    Vectorised: one C-level pass instead of a Python loop over rows; the\n",
"    reshape flattens per-row trailing dims (e.g. image rows) so the norm\n",
"    matches the original per-row np.linalg.norm.\n",
"    \"\"\"\n",
"    diff = (x_train - x_train[current_id]).reshape(x_train.shape[0], -1)\n",
"    return np.minimum(dists, np.linalg.norm(diff, axis=1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run_vcl(network_shape, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True):\n",
"    \"\"\" Run variational continual learning over the tasks yielded by data_gen.\n",
"\n",
"    Trains a maximum-likelihood model on the first task to initialise the\n",
"    mean-field posterior, then per task optionally moves a coreset aside,\n",
"    trains the MFVI model on the remainder and scores all tasks seen so far.\n",
"    Returns the accumulated accuracy results.\n",
"    \"\"\"\n",
"    x_coresets, y_coresets = [], []\n",
"    x_testsets, y_testsets = [], []\n",
"\n",
"    all_acc = np.array([])\n",
"\n",
"    for task_id, (train, test) in enumerate(data_gen):\n",
"        # Pull the raw arrays out of the NDArrayIter objects: the coreset\n",
"        # selection below operates on numpy arrays, not iterators.\n",
"        # (Was a NameError: x_train/y_train were never defined.)\n",
"        x_train = train.data[0][1].asnumpy()\n",
"        y_train = train.label[0][1].asnumpy()\n",
"        x_testsets.append(test.data[0][1])\n",
"        y_testsets.append(test.label[0][1])\n",
"\n",
"        # Set the readout head to train\n",
"        head = 0 if single_head else task_id\n",
"        # batch_size=None means full-batch training (was commented out -> NameError)\n",
"        bsize = x_train.shape[0] if (batch_size is None) else batch_size\n",
"\n",
"        # Train network with maximum likelihood to initialize first model\n",
"        if task_id == 0:\n",
"            ml_model = VanillaNN(network_shape)\n",
"            ml_model.train(x_train, y_train, task_id, no_epochs, bsize)\n",
"            mf_weights = ml_model.get_weights()\n",
"            mf_variances = None\n",
"            ml_model.close_session()\n",
"\n",
"        # Select coreset if needed\n",
"        if coreset_size > 0:\n",
"            x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)\n",
"\n",
"        # Train on non-coreset data\n",
"        mf_model = MFVINN(network_shape, prev_means=mf_weights, prev_log_variances=mf_variances)\n",
"        mf_model.train(x_train, y_train, head, no_epochs, bsize)\n",
"        mf_weights, mf_variances = mf_model.get_weights()\n",
"\n",
"        # Incorporate coreset data and make prediction.\n",
"        # NOTE(review): was hidden_size, which is undefined in this notebook;\n",
"        # network_shape is the closest in-scope value -- confirm against utils.get_scores.\n",
"        acc = utils.get_scores(mf_model, x_testsets, y_testsets, x_coresets, y_coresets, network_shape, no_epochs, single_head, batch_size)\n",
"        all_acc = utils.concatenate_results(acc, all_acc)\n",
"\n",
"        mf_model.close_session()\n",
"\n",
"    return all_acc"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class BaseNN:\n",
"    \"\"\" Common training/prediction harness for the notebook's networks.\n",
"\n",
"    Subclasses build self.net; train() wraps it in an mx.mod.Module.\n",
"    \"\"\"\n",
"    def __init__(self, network_shape):\n",
"        # Symbol selecting which readout head a batch belongs to.\n",
"        self.task_idx = mx.sym.Variable(name='task_idx', dtype=np.float32)\n",
"        self.net = None\n",
"\n",
"    def train(self, train_iter, val_iter, ctx):\n",
"        \"\"\" Fit self.net on train_iter; return accuracy measured on val_iter. \"\"\"\n",
"        # create a trainable module on compute context\n",
"        self.model = mx.mod.Module(symbol=self.net, context=ctx)\n",
"        self.model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n",
"        init = mx.init.Xavier(factor_type=\"in\", magnitude=2.34)\n",
"        self.model.init_params(initializer=init, force_init=True)\n",
"        self.model.fit(train_iter,  # train data\n",
"                       eval_data=val_iter,  # validation data\n",
"                       optimizer='adam',  # Adam with a fixed learning rate\n",
"                       optimizer_params={'learning_rate': 0.001},\n",
"                       eval_metric='acc',  # report accuracy during training\n",
"                       # was: batch_size (undefined here) -- use the iterator's own batch size\n",
"                       batch_end_callback=mx.callback.Speedometer(train_iter.batch_size, 100),\n",
"                       num_epoch=10)  # train for at most 10 dataset passes\n",
"        # Score on the validation iterator (was: test_iter, undefined here).\n",
"        acc = mx.metric.Accuracy()\n",
"        self.model.score(val_iter, acc)\n",
"        return acc\n",
"\n",
"    def prediction_prob(self, test_iter, task_idx):\n",
"        # TODO(review): task_idx is unused -- multi-head prediction not wired up yet.\n",
"        prob = self.model.predict(test_iter)\n",
"        return prob\n",
"\n",
"class VanillaNN(BaseNN):\n",
"    \"\"\" Plain feed-forward network trained by maximum likelihood. \"\"\"\n",
"    def __init__(self, network_shape, prev_weights=None, learning_rate=0.001):\n",
"        super(VanillaNN, self).__init__(network_shape)\n",
"\n",
"        # Create net: hidden ReLU layers, then a linear classification layer.\n",
"        self.net = mx.gluon.nn.HybridSequential(prefix='vanilla_')\n",
"        with self.net.name_scope():\n",
"            for layer in network_shape[1:-1]:\n",
"                self.net.add(mx.gluon.nn.Dense(layer, activation=\"relu\"))\n",
"            # Last layer for classification\n",
"            self.net.add(mx.gluon.nn.Dense(network_shape[-1], flatten=True, in_units=network_shape[-2]))\n",
"        self.loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()\n",
"        self.net.initialize(mx.init.Xavier(magnitude=2.34))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters\n",
"network_shape = (in_dim, 256, 256, 2)  # two hidden layers; 2 outputs (binary classification)\n",
"batch_size = None  # None => full-batch iterators\n",
"no_epochs = 120\n",
"single_head = False  # separate readout head per task"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run vanilla VCL (no coreset)\n",
"mx.random.seed(42)\n",
"np.random.seed(42)\n",
"\n",
"coreset_size = 0\n",
"data_gen = SplitMnistGenerator(mnist, batch_size)\n",
"vcl_result = run_vcl(network_shape, no_epochs, data_gen, rand_from_batch,\n",
"                     coreset_size, batch_size, single_head)\n",
"print(vcl_result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run random coreset VCL\n",
"mx.random.seed(42)\n",
"np.random.seed(42)\n",
"\n",
"coreset_size = 40\n",
"data_gen = SplitMnistGenerator(mnist, batch_size)\n",
"# Use the run_vcl/rand_from_batch defined above (was: vcl.run_vcl /\n",
"# coreset.rand_from_batch with undefined hidden_size -- neither module exists here).\n",
"rand_vcl_result = run_vcl(network_shape, no_epochs, data_gen,\n",
"                          rand_from_batch, coreset_size, batch_size, single_head)\n",
"print(rand_vcl_result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run k-center coreset VCL (reuses coreset_size from the previous cell)\n",
"mx.random.seed(42)\n",
"np.random.seed(42)\n",
"\n",
"data_gen = SplitMnistGenerator(mnist, batch_size)\n",
"# Use the run_vcl/k_center defined above (was: vcl.run_vcl / coreset.k_center\n",
"# with undefined hidden_size -- neither module exists here).\n",
"kcen_vcl_result = run_vcl(network_shape, no_epochs, data_gen,\n",
"                          k_center, coreset_size, batch_size, single_head)\n",
"print(kcen_vcl_result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot the average accuracy over the tasks seen so far for each method.\n",
"vcl_avg = np.nanmean(vcl_result, 1)\n",
"rand_vcl_avg = np.nanmean(rand_vcl_result, 1)\n",
"kcen_vcl_avg = np.nanmean(kcen_vcl_result, 1)\n",
"# NOTE(review): utils is never imported in this notebook -- confirm the helper module.\n",
"utils.plot('results/split.jpg', vcl_avg, rand_vcl_avg, kcen_vcl_avg)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading