diff --git a/maml++.ipynb b/maml++.ipynb
new file mode 100644
index 000000000..de5920c43
--- /dev/null
+++ b/maml++.ipynb
@@ -0,0 +1,3227 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "maml++.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+    ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_bgT2a-_HIbT",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# mount google drive\n",
+ "from google.colab import drive\n",
+ "drive.mount('/gdrive')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "0y79kypKJ_J3",
+ "colab": {}
+ },
+ "source": [
+ "# copy files from google colab to google drive\n",
+ "!cp -r omniglot_1_8_0.1_64_5_0/ ../gdrive/My\\ Drive/\n",
+ "!cp -r runs/ ../gdrive/My\\ Drive/"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "id": "f5ohNe6LJ_JX",
+ "colab": {}
+ },
+ "source": [
+ "# retrieve files from google drive to google colab\n",
+ "!cp -r ../gdrive/My\\ Drive/omniglot_1_8_0.1_64_5_0/ ./omniglot_1_8_0.1_64_5_0/\n",
+ "!cp -r ../gdrive/My\\ Drive/runs/ ./runs/ "
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "avyMhUz2yTg4",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# clone repository, for dataset and json files\n",
+ "!git clone https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch.git\n",
+ "!mv HowToTrainYourMAMLPytorch/* ./"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ZfXN4qJM9zSr",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "!rm -r HowToTrainYourMAMLPytorch/"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lX1pF70a-KZG",
+ "colab_type": "code",
+ "outputId": "bf311efe-ffaa-4b08-d4a9-b5c7191da8f2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 246
+ }
+ },
+ "source": [
+    "# download ngrok, used to tunnel the tensorboard server for dynamic visualization\n",
+ "!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip\n",
+ "!unzip ngrok-stable-linux-amd64.zip"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "--2019-06-02 02:16:38-- https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip\n",
+ "Resolving bin.equinox.io (bin.equinox.io)... 52.73.94.166, 52.7.169.168, 34.232.40.183, ...\n",
+ "Connecting to bin.equinox.io (bin.equinox.io)|52.73.94.166|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 16648024 (16M) [application/octet-stream]\n",
+ "Saving to: ‘ngrok-stable-linux-amd64.zip’\n",
+ "\n",
+ "ngrok-stable-linux- 100%[===================>] 15.88M 42.3MB/s in 0.4s \n",
+ "\n",
+ "2019-06-02 02:16:39 (42.3 MB/s) - ‘ngrok-stable-linux-amd64.zip’ saved [16648024/16648024]\n",
+ "\n",
+ "Archive: ngrok-stable-linux-amd64.zip\n",
+ " inflating: ngrok \n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nw5CGKbw-KYC",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# logs are saved in ./runs directory\n",
+ "LOG_DIR = './runs'\n",
+ "get_ipython().system_raw(\n",
+ " 'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'\n",
+ " .format(LOG_DIR)\n",
+ ")"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "EoqoTvFz-KVe",
+ "colab_type": "code",
+ "outputId": "d7254665-941b-4ba0-cf8a-813f8c5de3c2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "get_ipython().system_raw('./ngrok http 6006 &')\n",
+ "! curl -s http://localhost:4040/api/tunnels | python3 -c \\\n",
+ " \"import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])\"\n"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "http://e43be063.ngrok.io\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "aTf0w978-RNd",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# for using TPU\n",
+ "# !pip install \\\n",
+ "# http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl \\\n",
+ "# http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "dlYcoGLm-RMF",
+ "colab_type": "code",
+ "outputId": "3c771328-b72f-41ca-a166-510602dfecf0",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1005
+ }
+ },
+ "source": [
+    "# to use tensorboard with PyTorch through ngrok (dynamic visualization), we need to upgrade tb-nightly\n",
+ "%%shell\n",
+ "pip install --upgrade tb-nightly"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Collecting tensorflow==1.14.0rc0\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/6f/0c/355160095d16e6fe06a06474471b4e7560218b152db88a06f0190e7401c9/tensorflow-1.14.0rc0-cp36-cp36m-manylinux1_x86_64.whl (109.2MB)\n",
+ "\u001b[K |████████████████████████████████| 109.2MB 205kB/s \n",
+ "\u001b[?25hRequirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.7.1)\n",
+ "Requirement already satisfied: tensorboard<1.14.0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.13.1)\n",
+ "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.15.0)\n",
+ "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.0.9)\n",
+ "Collecting tf-estimator-nightly<1.14.0.dev2019042302,>=1.14.0.dev2019042301 (from tensorflow==1.14.0rc0)\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/6a/7d343697ad80f824f9464e63fb910983174767f099ba6ed2c49cc6873f96/tf_estimator_nightly-1.14.0.dev2019042301-py2.py3-none-any.whl (480kB)\n",
+ "\u001b[K |████████████████████████████████| 481kB 34.6MB/s \n",
+ "\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.12.0)\n",
+ "Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.2.2)\n",
+ "Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.8.0)\n",
+ "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.1.0)\n",
+ "Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.0.7)\n",
+ "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (3.7.1)\n",
+ "Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.16.4)\n",
+ "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.33.4)\n",
+ "Collecting wrapt>=1.11.1 (from tensorflow==1.14.0rc0)\n",
+ " Downloading https://files.pythonhosted.org/packages/67/b2/0f71ca90b0ade7fad27e3d20327c996c6252a2ffe88f50a95bba7434eda9/wrapt-1.11.1.tar.gz\n",
+ "Collecting google-pasta>=0.1.6 (from tensorflow==1.14.0rc0)\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d0/33/376510eb8d6246f3c30545f416b2263eee461e40940c2a4413c711bdf62d/google_pasta-0.1.7-py3-none-any.whl (52kB)\n",
+ "\u001b[K |████████████████████████████████| 61kB 16.3MB/s \n",
+ "\u001b[?25hRequirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow==1.14.0rc0) (3.1.1)\n",
+ "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow==1.14.0rc0) (0.15.4)\n",
+ "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow==1.14.0rc0) (2.8.0)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow==1.14.0rc0) (41.0.1)\n",
+ "Building wheels for collected packages: wrapt\n",
+ " Building wheel for wrapt (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Stored in directory: /root/.cache/pip/wheels/89/67/41/63cbf0f6ac0a6156588b9587be4db5565f8c6d8ccef98202fc\n",
+ "Successfully built wrapt\n",
+ "\u001b[31mERROR: thinc 6.12.1 has requirement wrapt<1.11.0,>=1.10.0, but you'll have wrapt 1.11.1 which is incompatible.\u001b[0m\n",
+ "Installing collected packages: tf-estimator-nightly, wrapt, google-pasta, tensorflow\n",
+ " Found existing installation: wrapt 1.10.11\n",
+ " Uninstalling wrapt-1.10.11:\n",
+ " Successfully uninstalled wrapt-1.10.11\n",
+ " Found existing installation: tensorflow 1.13.1\n",
+ " Uninstalling tensorflow-1.13.1:\n",
+ " Successfully uninstalled tensorflow-1.13.1\n",
+ "Successfully installed google-pasta-0.1.7 tensorflow-1.14.0rc0 tf-estimator-nightly-1.14.0.dev2019042301 wrapt-1.11.1\n",
+ "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (0.16.0)\n",
+ "Collecting tb-nightly\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/3e/85/f4cfd5ba1d5f67702e05a785d4fa2732c1ba58c66038c7280bb427c6430b/tb_nightly-1.14.0a20190601-py3-none-any.whl (3.1MB)\n",
+ "\u001b[K |████████████████████████████████| 3.1MB 3.4MB/s \n",
+ "\u001b[?25hRequirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.33.4)\n",
+ "Requirement already satisfied, skipping upgrade: numpy>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.16.4)\n",
+ "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.12.0)\n",
+ "Requirement already satisfied, skipping upgrade: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.15.0)\n",
+ "Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.7.1)\n",
+ "Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (41.0.1)\n",
+ "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.15.4)\n",
+ "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (3.1.1)\n",
+ "Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (3.7.1)\n",
+ "Installing collected packages: tb-nightly\n",
+ "Successfully installed tb-nightly-1.14.0a20190601\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 15
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "CXSGKa4l-RKj",
+ "colab_type": "code",
+ "outputId": "78cc8153-1aab-4643-9e7c-32410f7f65fc",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "from __future__ import print_function, unicode_literals, division\n",
+ "from IPython.core.debugger import set_trace\n",
+ "from IPython.display import HTML, Math\n",
+ "from pprint import pprint\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "import torch.autograd as autograd\n",
+ "from torchvision import datasets\n",
+ "from torch.nn.utils.weight_norm import WeightNorm\n",
+ "from torch.optim.lr_scheduler import StepLR\n",
+ "from torch.optim import Adam\n",
+    "from torchvision.transforms import Compose, RandomHorizontalFlip, RandomResizedCrop, ToTensor, Normalize, \\\n",
+    "    CenterCrop, Resize, ColorJitter, ToPILImage, RandomCrop, RandomRotation\n",
+ "import torchvision.models as models\n",
+ "from torch.autograd import Variable\n",
+ "from torch.utils.data import DataLoader, Dataset, sampler\n",
+ "from torchvision.utils import make_grid, save_image\n",
+ "from torch.utils.tensorboard import SummaryWriter\n",
+ "import torch.utils.checkpoint as cp\n",
+ "\n",
+ "# for using TPU\n",
+    "# import torch_xla\n",
+ "# import torch_xla_py.utils as xu\n",
+ "# import torch_xla_py.xla_model as xm\n",
+ "\n",
+ "import json, glob, time, math, os, datetime, string, random, re, warnings, itertools, \\\n",
+ " logging, numbers, csv, pickle, tqdm\n",
+ "from copy import copy\n",
+ "\n",
+ "import numpy as np\n",
+ "from PIL import Image, ImageEnhance, ImageFile\n",
+ "from collections import OrderedDict\n",
+ "import concurrent.futures\n",
+ "\n",
+ "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
+ "\n",
+ "print(\"CUDA available: \", torch.cuda.is_available())\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "\n",
+ "os.environ['DATASET_DIR'] = './datasets'\n",
+ "\n",
+ "# for writing into tensorboard\n",
+ "writer = SummaryWriter()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "CUDA available: True\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "d8LgqZ6qy1L7",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class rotate_image(object):\n",
+ "\n",
+ " def __init__(self, k, channels):\n",
+ " self.k = k\n",
+ " self.channels = channels\n",
+ "\n",
+ " def __call__(self, image):\n",
+ " if self.channels == 1 and len(image.shape) == 3:\n",
+ " image = image[:, :, 0]\n",
+ " image = np.expand_dims(image, axis=2)\n",
+ "\n",
+ " elif self.channels == 1 and len(image.shape) == 4:\n",
+ " image = image[:, :, :, 0]\n",
+ " image = np.expand_dims(image, axis=3)\n",
+ "\n",
+ " image = np.rot90(image, k=self.k).copy()\n",
+ " return image"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "GZ6A6o8oJLil",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class torch_rotate_image(object):\n",
+ "\n",
+ " def __init__(self, k, channels):\n",
+ " self.k = k\n",
+ " self.channels = channels\n",
+ "\n",
+    "    def __call__(self, image):\n",
+    "        # note: RandomRotation(degrees=d) samples an angle uniformly from [-d, d],\n",
+    "        # so unlike the deterministic np.rot90 in rotate_image this is a random rotation\n",
+    "        rotate = RandomRotation(degrees=self.k * 90)\n",
+ " if image.shape[-1] == 1:\n",
+ " image = image[:, :, 0]\n",
+ " image = Image.fromarray(image)\n",
+ " image = rotate(image)\n",
+ " image = np.array(image)\n",
+ " if len(image.shape) == 2:\n",
+ " image = np.expand_dims(image, axis=2)\n",
+ " return image"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Fcg1Bxs3ythc",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def augment_image(image, k, channels, augment_bool, args, dataset_name):\n",
+ " transform_train, transform_evaluation = get_transforms_for_dataset(dataset_name=dataset_name,\n",
+ " args=args, k=k)\n",
+ " if len(image.shape) > 3:\n",
+ " images = [item for item in image]\n",
+ " output_images = []\n",
+ " for image in images:\n",
+ " if augment_bool is True:\n",
+ " for transform_current in transform_train:\n",
+ " image = transform_current(image)\n",
+ " else:\n",
+ " for transform_current in transform_evaluation:\n",
+ " image = transform_current(image)\n",
+ " output_images.append(image)\n",
+ " image = torch.stack(output_images)\n",
+ " else:\n",
+ " if augment_bool is True:\n",
+ " # meanstd transformation\n",
+ " for transform_current in transform_train:\n",
+ " image = transform_current(image)\n",
+ " else:\n",
+ " for transform_current in transform_evaluation:\n",
+ " image = transform_current(image)\n",
+ " return image"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "9NVCa1zGywce",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def get_transforms_for_dataset(dataset_name, args, k):\n",
+ " if \"cifar10\" in dataset_name or \"cifar100\" in dataset_name:\n",
+ " transform_train = [\n",
+ " RandomCrop(32, padding=4),\n",
+ " RandomHorizontalFlip(),\n",
+ " ToTensor(),\n",
+ " Normalize(args.classification_mean, args.classification_std)]\n",
+ "\n",
+ " transform_evaluate = [\n",
+ " ToTensor(),\n",
+ " Normalize(args.classification_mean, args.classification_std)]\n",
+ "\n",
+ " elif 'omniglot' in dataset_name:\n",
+ "\n",
+ " transform_train = [rotate_image(k=k, channels=args.image_channels), ToTensor()]\n",
+ " transform_evaluate = [ToTensor()]\n",
+ "\n",
+ "\n",
+ " elif 'imagenet' in dataset_name:\n",
+ "\n",
+ " transform_train = [Compose([\n",
+ "\n",
+ " ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])]\n",
+ "\n",
+ " transform_evaluate = [Compose([\n",
+ "\n",
+ " ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])]\n",
+ "\n",
+ " return transform_train, transform_evaluate"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
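+  {
+   "cell_type": "code",
+   "metadata": {
+    "colab_type": "code",
+    "colab": {}
+   },
+   "source": [
+    "# Quick sanity check (a sketch added for illustration, not part of the original pipeline):\n",
+    "# the omniglot branch above returns [rotate_image, ToTensor], so a (h, w, 1) float array\n",
+    "# should come out as a rotated (1, h, w) tensor. `dummy_args` is a hypothetical stand-in\n",
+    "# for the args Bunch; only the `image_channels` field is read on this path.\n",
+    "from types import SimpleNamespace\n",
+    "\n",
+    "dummy_args = SimpleNamespace(image_channels=1)\n",
+    "train_t, eval_t = get_transforms_for_dataset('omniglot', args=dummy_args, k=1)\n",
+    "img = np.random.rand(28, 28, 1).astype(np.float32)\n",
+    "for t in train_t:\n",
+    "    img = t(img)\n",
+    "print(type(img), img.shape)  # expected: a torch tensor of shape (1, 28, 28)"
+   ],
+   "execution_count": 0,
+   "outputs": []
+  },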
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "8E3bStPExiaz",
+ "colab_type": "code",
+ "outputId": "79a03442-1fd5-4c6e-bef5-9b478671579a",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 132
+ }
+ },
+ "source": [
+ "class FewShotLearningDatasetParallel(Dataset):\n",
+ " def __init__(self, args):\n",
+ " \"\"\"\n",
+ " A data provider class inheriting from Pytorch's Dataset class. It takes care of creating task sets for\n",
+ " our few-shot learning model training and evaluation\n",
+ " :param args: Arguments in the form of a Bunch object. Includes all hyperparameters necessary for the\n",
+    "        data-provider. For transparency and readability, we explicitly assign every argument the data\n",
+    "        provider requires to self.<attribute_name>, so the reader knows exactly what it depends on.\n",
+ " \"\"\"\n",
+ " self.data_path = args.dataset_path\n",
+ " self.dataset_name = args.dataset_name\n",
+ " self.data_loaded_in_memory = False\n",
+ " self.image_height, self.image_width, self.image_channel = args.image_height, args.image_width, args.image_channels\n",
+ " self.args = args\n",
+ " self.indexes_of_folders_indicating_class = args.indexes_of_folders_indicating_class\n",
+ " self.reverse_channels = args.reverse_channels\n",
+ " self.labels_as_int = args.labels_as_int\n",
+ " self.train_val_test_split = args.train_val_test_split\n",
+ " self.current_set_name = \"train\"\n",
+ " self.num_target_samples = args.num_target_samples\n",
+ " self.reset_stored_filepaths = args.reset_stored_filepaths\n",
+ " val_rng = np.random.RandomState(seed=args.val_seed)\n",
+ " val_seed = val_rng.randint(1, 999999)\n",
+ " train_rng = np.random.RandomState(seed=args.train_seed)\n",
+ " train_seed = train_rng.randint(1, 999999)\n",
+ " test_rng = np.random.RandomState(seed=args.val_seed)\n",
+ " test_seed = test_rng.randint(1, 999999)\n",
+ " args.val_seed = val_seed\n",
+ " args.train_seed = train_seed\n",
+ " args.test_seed = test_seed\n",
+    "        self.init_seed = {\"train\": args.train_seed, \"val\": args.val_seed, 'test': args.test_seed}\n",
+    "        self.seed = {\"train\": args.train_seed, \"val\": args.val_seed, 'test': args.test_seed}\n",
+ " self.num_of_gpus = args.num_of_gpus\n",
+ " self.batch_size = args.batch_size\n",
+ "\n",
+ " self.train_index = 0\n",
+ " self.val_index = 0\n",
+ " self.test_index = 0\n",
+ "\n",
+ " self.augment_images = False\n",
+ " self.num_samples_per_class = args.num_samples_per_class\n",
+ " self.num_classes_per_set = args.num_classes_per_set\n",
+ "\n",
+ " self.rng = np.random.RandomState(seed=self.seed['val'])\n",
+ " self.datasets = self.load_dataset()\n",
+ "\n",
+ " self.indexes = {\"train\": 0, \"val\": 0, 'test': 0}\n",
+ " self.dataset_size_dict = {\n",
+ " \"train\": {key: len(self.datasets['train'][key]) for key in list(self.datasets['train'].keys())},\n",
+ " \"val\": {key: len(self.datasets['val'][key]) for key in list(self.datasets['val'].keys())},\n",
+ " 'test': {key: len(self.datasets['test'][key]) for key in list(self.datasets['test'].keys())}}\n",
+ " self.label_set = self.get_label_set()\n",
+ " self.data_length = {name: np.sum([len(self.datasets[name][key])\n",
+ " for key in self.datasets[name]]) for name in self.datasets.keys()}\n",
+ "\n",
+ " print(\"data\", self.data_length)\n",
+ " self.observed_seed_set = None\n",
+ "\n",
+ " def load_dataset(self):\n",
+ " \"\"\"\n",
+ " Loads a dataset's dictionary files and splits the data according to the train_val_test_split variable stored\n",
+ " in the args object.\n",
+ " :return: Three sets, the training set, validation set and test sets (referred to as the meta-train,\n",
+ " meta-val and meta-test in the paper)\n",
+ " \"\"\"\n",
+ " rng = np.random.RandomState(seed=self.seed['val'])\n",
+ "\n",
+ " if self.args.sets_are_pre_split == True:\n",
+ " data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths()\n",
+ " dataset_splits = dict()\n",
+ " for key, value in data_image_paths.items():\n",
+ " key = self.get_label_from_index(index=key)\n",
+ " bits = key.split(\"/\")\n",
+ " set_name = bits[0]\n",
+ " class_label = bits[1]\n",
+ " if set_name not in dataset_splits:\n",
+ " dataset_splits[set_name] = {class_label: value}\n",
+ " else:\n",
+ " dataset_splits[set_name][class_label] = value\n",
+ " else:\n",
+ " data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths()\n",
+ " total_label_types = len(data_image_paths)\n",
+ " num_classes_idx = np.arange(len(data_image_paths.keys()), dtype=np.int32)\n",
+ " rng.shuffle(num_classes_idx)\n",
+ " keys = list(data_image_paths.keys())\n",
+ " values = list(data_image_paths.values())\n",
+ " new_keys = [keys[idx] for idx in num_classes_idx]\n",
+ " new_values = [values[idx] for idx in num_classes_idx]\n",
+ " data_image_paths = dict(zip(new_keys, new_values))\n",
+ " # data_image_paths = self.shuffle(data_image_paths)\n",
+ " x_train_id, x_val_id, x_test_id = int(self.train_val_test_split[0] * total_label_types), \\\n",
+ " int(np.sum(self.train_val_test_split[:2]) * total_label_types), \\\n",
+ " int(total_label_types)\n",
+ " print(x_train_id, x_val_id, x_test_id)\n",
+ " x_train_classes = (class_key for class_key in list(data_image_paths.keys())[:x_train_id])\n",
+ " x_val_classes = (class_key for class_key in list(data_image_paths.keys())[x_train_id:x_val_id])\n",
+ " x_test_classes = (class_key for class_key in list(data_image_paths.keys())[x_val_id:x_test_id])\n",
+ " x_train, x_val, x_test = {class_key: data_image_paths[class_key] for class_key in x_train_classes}, \\\n",
+ " {class_key: data_image_paths[class_key] for class_key in x_val_classes}, \\\n",
+ " {class_key: data_image_paths[class_key] for class_key in x_test_classes},\n",
+ " dataset_splits = {\"train\": x_train, \"val\":x_val , \"test\": x_test}\n",
+ "\n",
+ " if self.args.load_into_memory is True:\n",
+ "\n",
+ " print(\"Loading data into RAM\")\n",
+ " x_loaded = {\"train\": [], \"val\": [], \"test\": []}\n",
+ "\n",
+ " for set_key, set_value in dataset_splits.items():\n",
+ " print(\"Currently loading into memory the {} set\".format(set_key))\n",
+ " x_loaded[set_key] = {key: np.zeros(len(value), ) for key, value in set_value.items()}\n",
+ " # for class_key, class_value in set_value.items():\n",
+ " with tqdm.tqdm(total=len(set_value)) as pbar_memory_load:\n",
+ " with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:\n",
+ " # Process the list of files, but split the work across the process pool to use all CPUs!\n",
+ " for (class_label, class_images_loaded) in executor.map(self.load_parallel_batch, (set_value.items())):\n",
+ " x_loaded[set_key][class_label] = class_images_loaded\n",
+ " pbar_memory_load.update(1)\n",
+ "\n",
+ " dataset_splits = x_loaded\n",
+ " self.data_loaded_in_memory = True\n",
+ "\n",
+ " return dataset_splits\n",
+ "\n",
+ " def load_datapaths(self):\n",
+ " \"\"\"\n",
+ " If saved json dictionaries of the data are available, then this method loads the dictionaries such that the\n",
+ " data is ready to be read. If the json dictionaries do not exist, then this method calls get_data_paths()\n",
+ " which will build the json dictionary containing the class to filepath samples, and then store them.\n",
+ " :return: data_image_paths: dict containing class to filepath list pairs.\n",
+ " index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable\n",
+ " string-names of the class\n",
+ " label_to_index: dictionary containing human understandable string mapped to numerical indexes\n",
+ " \"\"\"\n",
+ " dataset_dir = os.environ['DATASET_DIR']\n",
+ " data_path_file = \"{}/{}.json\".format(dataset_dir, self.dataset_name)\n",
+ " print(dataset_dir)\n",
+ " print(data_path_file)\n",
+ " self.index_to_label_name_dict_file = \"{}/map_to_label_name_{}.json\".format(dataset_dir, self.dataset_name)\n",
+ " self.label_name_to_map_dict_file = \"{}/label_name_to_map_{}.json\".format(dataset_dir, self.dataset_name)\n",
+ "\n",
+ " if not os.path.exists(data_path_file):\n",
+ " self.reset_stored_filepaths = True\n",
+ "\n",
+ " if self.reset_stored_filepaths == True:\n",
+ " if os.path.exists(data_path_file):\n",
+ " os.remove(data_path_file)\n",
+ " self.reset_stored_filepaths = False\n",
+ "\n",
+ " try:\n",
+ " data_image_paths = self.load_from_json(filename=data_path_file)\n",
+ " label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file)\n",
+ " index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file)\n",
+ " print(data_image_paths, index_to_label_name_dict_file, label_to_index)\n",
+ " return data_image_paths, index_to_label_name_dict_file, label_to_index\n",
+ " except:\n",
+ " print(\"Mapped data paths can't be found, remapping paths..\")\n",
+ " data_image_paths, code_to_label_name, label_name_to_code = self.get_data_paths()\n",
+ " self.save_to_json(dict_to_store=data_image_paths, filename=data_path_file)\n",
+ " self.save_to_json(dict_to_store=code_to_label_name, filename=self.index_to_label_name_dict_file)\n",
+ " self.save_to_json(dict_to_store=label_name_to_code, filename=self.label_name_to_map_dict_file)\n",
+ " return self.load_datapaths()\n",
+ "\n",
+ " def save_to_json(self, filename, dict_to_store):\n",
+ " with open(os.path.abspath(filename), 'w') as f:\n",
+ " json.dump(dict_to_store, fp=f)\n",
+ "\n",
+ " def load_from_json(self, filename):\n",
+ " with open(filename, mode=\"r\") as f:\n",
+ " load_dict = json.load(fp=f)\n",
+ "\n",
+ " return load_dict\n",
+ "\n",
+ " def load_test_image(self, filepath):\n",
+ " \"\"\"\n",
+ " Tests whether a target filepath contains an uncorrupted image. If image is corrupted, attempt to fix.\n",
+ " :param filepath: Filepath of image to be tested\n",
+ " :return: Return filepath of image if image exists and is uncorrupted (or attempt to fix has succeeded),\n",
+ " else return None\n",
+ " \"\"\"\n",
+ " image = None\n",
+ " try:\n",
+ " image = Image.open(filepath)\n",
+ " except RuntimeWarning:\n",
+ " os.system(\"convert {} -strip {}\".format(filepath, filepath))\n",
+ " print(\"converting\")\n",
+ " image = Image.open(filepath)\n",
+ " except:\n",
+ " print(\"Broken image\")\n",
+ "\n",
+ " if image is not None:\n",
+ " return filepath\n",
+ " else:\n",
+ " return None\n",
+ "\n",
+ " def get_data_paths(self):\n",
+ " \"\"\"\n",
+ " Method that scans the dataset directory and generates class to image-filepath list dictionaries.\n",
+ " :return: data_image_paths: dict containing class to filepath list pairs.\n",
+ " index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable\n",
+ " string-names of the class\n",
+ " label_to_index: dictionary containing human understandable string mapped to numerical indexes\n",
+ " \"\"\"\n",
+ " print(\"Get images from\", self.data_path)\n",
+ " data_image_path_list_raw = []\n",
+ " labels = set()\n",
+ " for subdir, dir, files in os.walk(self.data_path):\n",
+ " for file in files:\n",
+ " if (\".jpeg\") in file.lower() or (\".png\") in file.lower() or (\".jpg\") in file.lower():\n",
+ " filepath = os.path.abspath(os.path.join(subdir, file))\n",
+ " label = self.get_label_from_path(filepath)\n",
+ " data_image_path_list_raw.append(filepath)\n",
+ " labels.add(label)\n",
+ "\n",
+ " labels = sorted(labels)\n",
+ " idx_to_label_name = {idx: label for idx, label in enumerate(labels)}\n",
+ " label_name_to_idx = {label: idx for idx, label in enumerate(labels)}\n",
+ " data_image_path_dict = {idx: [] for idx in list(idx_to_label_name.keys())}\n",
+ " with tqdm.tqdm(total=len(data_image_path_list_raw)) as pbar_error:\n",
+ " with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:\n",
+ " # Process the list of files, but split the work across the process pool to use all CPUs!\n",
+ " for image_file in executor.map(self.load_test_image, (data_image_path_list_raw)):\n",
+ " pbar_error.update(1)\n",
+ " if image_file is not None:\n",
+ " label = self.get_label_from_path(image_file)\n",
+ " data_image_path_dict[label_name_to_idx[label]].append(image_file)\n",
+ "\n",
+ " return data_image_path_dict, idx_to_label_name, label_name_to_idx\n",
+ "\n",
+ " def get_label_set(self):\n",
+ " \"\"\"\n",
+ " Generates a set containing all class numerical indexes\n",
+ " :return: A set containing all class numerical indexes\n",
+ " \"\"\"\n",
+ " index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file)\n",
+ " return set(list(index_to_label_name_dict_file.keys()))\n",
+ "\n",
+ " def get_index_from_label(self, label):\n",
+ " \"\"\"\n",
+ " Given a class's (human understandable) string, returns the numerical index of that class\n",
+ " :param label: A string of a human understandable class contained in the dataset\n",
+ " :return: An int containing the numerical index of the given class-string\n",
+ " \"\"\"\n",
+ " label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file)\n",
+ " return label_to_index[label]\n",
+ "\n",
+ " def get_label_from_index(self, index):\n",
+ " \"\"\"\n",
+ " Given an index return the human understandable label mapping to it.\n",
+ " :param index: A numerical index (int)\n",
+ " :return: A human understandable label (str)\n",
+ " \"\"\"\n",
+ " index_to_label_name = self.load_from_json(filename=self.index_to_label_name_dict_file)\n",
+ " return index_to_label_name[index]\n",
+ "\n",
+ " def get_label_from_path(self, filepath):\n",
+ " \"\"\"\n",
+ " Given a path of an image generate the human understandable label for that image.\n",
+ " :param filepath: The image's filepath\n",
+ " :return: A human understandable label.\n",
+ " \"\"\"\n",
+ " label_bits = filepath.split(\"/\")\n",
+ " label = \"/\".join([label_bits[idx] for idx in self.indexes_of_folders_indicating_class])\n",
+ " if self.labels_as_int:\n",
+ " label = int(label)\n",
+ " return label\n",
+ "\n",
+ " def load_image(self, image_path, channels):\n",
+ " \"\"\"\n",
+ " Given an image filepath and the number of channels to keep, load an image and keep the specified channels\n",
+ " :param image_path: The image's filepath\n",
+ " :param channels: The number of channels to keep\n",
+ " :return: An image array of shape (h, w, channels), whose values range between 0.0 and 1.0.\n",
+ " \"\"\"\n",
+ " if not self.data_loaded_in_memory:\n",
+ " image = Image.open(image_path)\n",
+ " if 'omniglot' in self.dataset_name:\n",
+ " image = image.resize((self.image_height, self.image_width), resample=Image.LANCZOS)\n",
+ " image = np.array(image, np.float32)\n",
+ " if channels == 1:\n",
+ " image = np.expand_dims(image, axis=2)\n",
+ " else:\n",
+ " image = image.resize((self.image_height, self.image_width)).convert('RGB')\n",
+ " image = np.array(image, np.float32)\n",
+ " image = image / 255.0\n",
+ " else:\n",
+ " image = image_path\n",
+ "\n",
+ " return image\n",
+ "\n",
+ " def load_batch(self, batch_image_paths):\n",
+ " \"\"\"\n",
+ " Load a batch of images, given a list of filepaths\n",
+ " :param batch_image_paths: A list of filepaths\n",
+ " :return: A numpy array of images of shape batch, height, width, channels\n",
+ " \"\"\"\n",
+ " image_batch = []\n",
+ "\n",
+ " if self.data_loaded_in_memory:\n",
+ " for image_path in batch_image_paths:\n",
+ " image_batch.append(image_path)\n",
+ " image_batch = np.array(image_batch, dtype=np.float32)\n",
+ " #print(image_batch.shape)\n",
+ " else:\n",
+ " image_batch = [self.load_image(image_path=image_path, channels=self.image_channel)\n",
+ " for image_path in batch_image_paths]\n",
+ " image_batch = np.array(image_batch, dtype=np.float32)\n",
+ " image_batch = self.preprocess_data(image_batch)\n",
+ "\n",
+ " return image_batch\n",
+ "\n",
+ " def load_parallel_batch(self, inputs):\n",
+ " \"\"\"\n",
+ " Load a batch of images, given a list of filepaths\n",
+ " :param batch_image_paths: A list of filepaths\n",
+ " :return: A numpy array of images of shape batch, height, width, channels\n",
+ " \"\"\"\n",
+ " class_label, batch_image_paths = inputs\n",
+ " image_batch = []\n",
+ "\n",
+ " if self.data_loaded_in_memory:\n",
+ " for image_path in batch_image_paths:\n",
+ " image_batch.append(np.copy(image_path))\n",
+ " image_batch = np.array(image_batch, dtype=np.float32)\n",
+ " else:\n",
+ " #with tqdm.tqdm(total=1) as load_pbar:\n",
+ " image_batch = [self.load_image(image_path=image_path, channels=self.image_channel)\n",
+ " for image_path in batch_image_paths]\n",
+ " #load_pbar.update(1)\n",
+ "\n",
+ " image_batch = np.array(image_batch, dtype=np.float32)\n",
+ " image_batch = self.preprocess_data(image_batch)\n",
+ "\n",
+ " return class_label, image_batch\n",
+ "\n",
+ " def preprocess_data(self, x):\n",
+ " \"\"\"\n",
+ " Preprocesses data such that their shapes match the specified structures\n",
+ " :param x: A data batch to preprocess\n",
+ " :return: A preprocessed data batch\n",
+ " \"\"\"\n",
+ " x_shape = x.shape\n",
+ " x = np.reshape(x, (-1, x_shape[-3], x_shape[-2], x_shape[-1]))\n",
+ " if self.reverse_channels is True:\n",
+ " reverse_photos = np.ones(shape=x.shape)\n",
+ " for channel in range(x.shape[-1]):\n",
+ " reverse_photos[:, :, :, x.shape[-1] - 1 - channel] = x[:, :, :, channel]\n",
+ " x = reverse_photos\n",
+ " x = x.reshape(x_shape)\n",
+ " return x\n",
+ "\n",
+ " def reconstruct_original(self, x):\n",
+ " \"\"\"\n",
+ " Applies the reverse operations that preprocess_data() applies such that the data returns to their original form\n",
+ " :param x: A batch of data to reconstruct\n",
+ " :return: A reconstructed batch of data\n",
+ " \"\"\"\n",
+ " x = x * 255.0\n",
+ " return x\n",
+ "\n",
+ " def shuffle(self, x, rng):\n",
+ " \"\"\"\n",
+    "        Shuffles the data batch along its first axis\n",
+    "        :param x: A data batch\n",
+    "        :param rng: The random number generator to shuffle with\n",
+    "        :return: A shuffled data batch\n",
+ " \"\"\"\n",
+ " indices = np.arange(len(x))\n",
+ " rng.shuffle(indices)\n",
+ " x = x[indices]\n",
+ " return x\n",
+ "\n",
+    "    def get_set(self, dataset_name, seed, step, augment_images=False):\n",
+ " \"\"\"\n",
+ " Generates a task-set to be used for training or evaluation\n",
+    "        :param dataset_name: The name of the set to use, e.g. \"train\", \"val\" etc.\n",
+    "        :param seed: The seed used to sample the task.\n",
+    "        :param step: The current step, used when logging sampled images to tensorboard.\n",
+    "        :param augment_images: Whether to augment the sampled images.\n",
+    "        :return: A task-set containing an image and label support set, and an image and label target set.\n",
+ " \"\"\"\n",
+ " #seed = seed % self.args.total_unique_tasks\n",
+ " rng = np.random.RandomState(seed)\n",
+ " selected_classes = rng.choice(list(self.dataset_size_dict[dataset_name].keys()),\n",
+ " size=self.num_classes_per_set, replace=False)\n",
+ " rng.shuffle(selected_classes)\n",
+ " k_list = rng.randint(0, 4, size=self.num_classes_per_set)\n",
+ " k_dict = {selected_class: k_item for (selected_class, k_item) in zip(selected_classes, k_list)}\n",
+ " episode_labels = [i for i in range(self.num_classes_per_set)]\n",
+ " class_to_episode_label = {selected_class: episode_label for (selected_class, episode_label) in\n",
+ " zip(selected_classes, episode_labels)}\n",
+ "\n",
+ " x_images = []\n",
+ " y_labels = []\n",
+ "\n",
+ " for class_entry in selected_classes:\n",
+ " choose_samples_list = rng.choice(self.dataset_size_dict[dataset_name][class_entry],\n",
+ " size=self.num_samples_per_class + self.num_target_samples, replace=False)\n",
+ " class_image_samples = []\n",
+ " class_labels = []\n",
+ " for sample in choose_samples_list:\n",
+ " choose_samples = self.datasets[dataset_name][class_entry][sample]\n",
+ " x_class_data = self.load_batch([choose_samples])[0]\n",
+ " k = k_dict[class_entry]\n",
+ " x_class_data = augment_image(image=x_class_data, k=k,\n",
+ " channels=self.image_channel, augment_bool=augment_images,\n",
+ " dataset_name=self.dataset_name, args=self.args)\n",
+ " class_image_samples.append(x_class_data)\n",
+ " class_labels.append(int(class_to_episode_label[class_entry]))\n",
+ " class_image_samples = torch.stack(class_image_samples)\n",
+ " x_images.append(class_image_samples)\n",
+ " y_labels.append(class_labels)\n",
+ "\n",
+ " x_images = torch.stack(x_images)\n",
+ " y_labels = np.array(y_labels, dtype=np.float32)\n",
+ "\n",
+ " support_set_images = x_images[:, :self.num_samples_per_class]\n",
+ " support_set_labels = y_labels[:, :self.num_samples_per_class]\n",
+ " target_set_images = x_images[:, self.num_samples_per_class:]\n",
+ " target_set_labels = y_labels[:, self.num_samples_per_class:]\n",
+    "        # add_images expects (n, c, h, w), so flatten the (classes, samples) dimensions\n",
+    "        writer.add_images('support_set_images', support_set_images.reshape(-1, *support_set_images.shape[2:]), step)\n",
+    "        writer.add_images('target_set_images', target_set_images.reshape(-1, *target_set_images.shape[2:]), step)\n",
+ " \n",
+ " return support_set_images, target_set_images, support_set_labels, target_set_labels, seed\n",
+ "\n",
+ " def __len__(self):\n",
+ " total_samples = self.data_length[self.current_set_name]\n",
+ " return total_samples\n",
+ "\n",
+ " def length(self, set_name):\n",
+ " self.switch_set(set_name=set_name)\n",
+ " return len(self)\n",
+ "\n",
+ " def set_augmentation(self, augment_images):\n",
+ " self.augment_images = augment_images\n",
+ "\n",
+ " def switch_set(self, set_name, current_iter=None):\n",
+ " self.current_set_name = set_name\n",
+ " if set_name == \"train\":\n",
+ " self.update_seed(dataset_name=set_name, seed=self.init_seed[set_name] + current_iter)\n",
+ "\n",
+ " def update_seed(self, dataset_name, seed=100):\n",
+ " self.seed[dataset_name] = seed\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " \n",
+    "        support_set_images, target_set_image, support_set_labels, target_set_label, seed = \\\n",
+    "            self.get_set(self.current_set_name, seed=self.seed[self.current_set_name] + idx,\n",
+    "                         step=idx, augment_images=self.augment_images)\n",
+ " \n",
+ " return support_set_images, target_set_image, support_set_labels, target_set_label, seed\n",
+ "\n",
+ " def reset_seed(self):\n",
+ " self.seed = self.init_seed"
+ ],
+ "execution_count": 0,
+   "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "PzzTQac1xOJW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaLearningSystemDataLoader(object):\n",
+ " def __init__(self, args, current_iter=0):\n",
+ " \"\"\"\n",
+ " Initializes a meta learning system dataloader. The data loader uses the Pytorch DataLoader class to parallelize\n",
+ " batch sampling and preprocessing.\n",
+ " :param args: An arguments NamedTuple containing all the required arguments.\n",
+    "        :param current_iter: Current iter of the experiment. Is used to make sure the data loader continues\n",
+    "        where it left off previously.\n",
+ " \"\"\"\n",
+ " self.num_of_gpus = args.num_of_gpus\n",
+ " self.batch_size = args.batch_size\n",
+ " self.samples_per_iter = args.samples_per_iter\n",
+ " self.num_workers = args.num_dataprovider_workers\n",
+ " self.total_train_iters_produced = 0\n",
+ " self.dataset = FewShotLearningDatasetParallel(args=args)\n",
+ " self.batches_per_iter = args.samples_per_iter\n",
+ " self.full_data_length = self.dataset.data_length\n",
+ " self.continue_from_iter(current_iter=current_iter)\n",
+ " self.args = args\n",
+ "\n",
+ " def get_dataloader(self):\n",
+ " \"\"\"\n",
+ " Returns a data loader with the correct set (train, val or test), continuing from the current iter.\n",
+ " :return:\n",
+ " \"\"\"\n",
+ " return DataLoader(self.dataset, batch_size=(self.num_of_gpus * self.batch_size * self.samples_per_iter),\n",
+ " shuffle=False, num_workers=self.num_workers, drop_last=True)\n",
+ "\n",
+ " def continue_from_iter(self, current_iter):\n",
+ " \"\"\"\n",
+ " Makes sure the data provider is aware of where we are in terms of training iterations in the experiment.\n",
+ " :param current_iter:\n",
+ " \"\"\"\n",
+ " self.total_train_iters_produced += (current_iter * (self.num_of_gpus * self.batch_size * self.samples_per_iter))\n",
+ "\n",
+ " def get_train_batches(self, total_batches=-1, augment_images=False):\n",
+ " \"\"\"\n",
+ " Returns a training batches data_loader\n",
+ " :param total_batches: The number of batches we want the data loader to sample\n",
+ " :param augment_images: Whether we want the images to be augmented.\n",
+ " \"\"\"\n",
+ " if total_batches == -1:\n",
+ " self.dataset.data_length = self.full_data_length\n",
+ " else:\n",
+ " self.dataset.data_length[\"train\"] = total_batches * self.dataset.batch_size\n",
+ " self.dataset.switch_set(set_name=\"train\", current_iter=self.total_train_iters_produced)\n",
+ " self.dataset.set_augmentation(augment_images=augment_images)\n",
+ " self.total_train_iters_produced += (self.num_of_gpus * self.batch_size * self.samples_per_iter)\n",
+ " for sample_id, sample_batched in enumerate(self.get_dataloader()):\n",
+ " yield sample_batched\n",
+ "\n",
+ "\n",
+ " def get_val_batches(self, total_batches=-1, augment_images=False):\n",
+ " \"\"\"\n",
+ " Returns a validation batches data_loader\n",
+ " :param total_batches: The number of batches we want the data loader to sample\n",
+ " :param augment_images: Whether we want the images to be augmented.\n",
+ " \"\"\"\n",
+ " if total_batches == -1:\n",
+ " self.dataset.data_length = self.full_data_length\n",
+ " else:\n",
+ " self.dataset.data_length['val'] = total_batches * self.dataset.batch_size\n",
+ " self.dataset.switch_set(set_name=\"val\")\n",
+ " self.dataset.set_augmentation(augment_images=augment_images)\n",
+ " for sample_id, sample_batched in enumerate(self.get_dataloader()):\n",
+ " yield sample_batched\n",
+ "\n",
+ "\n",
+ " def get_test_batches(self, total_batches=-1, augment_images=False):\n",
+ " \"\"\"\n",
+ " Returns a testing batches data_loader\n",
+ " :param total_batches: The number of batches we want the data loader to sample\n",
+ " :param augment_images: Whether we want the images to be augmented.\n",
+ " \"\"\"\n",
+ " if total_batches == -1:\n",
+ " self.dataset.data_length = self.full_data_length\n",
+ " else:\n",
+ " self.dataset.data_length['test'] = total_batches * self.dataset.batch_size\n",
+ " self.dataset.switch_set(set_name='test')\n",
+ " self.dataset.set_augmentation(augment_images=augment_images)\n",
+ " for sample_id, sample_batched in enumerate(self.get_dataloader()):\n",
+ " yield sample_batched"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "HCshYy8PvXly",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def extract_top_level_dict(current_dict):\n",
+ " \"\"\"\n",
+    "    Builds a one-level-nested dictionary out of a flat dictionary whose keys are dot-separated parameter paths.\n",
+    "    Useful for dynamically passing external params to composite layers.\n",
+    "    :param current_dict: A flat dict mapping dot-separated parameter names (with any \"layer_dict.\" prefix\n",
+    "    stripped) to values, e.g. \"conv0.weight\".\n",
+    "    :return: A dict mapping each top-level name either to a value or to a sub-dict of remaining-path -> value pairs.\n",
+ " \"\"\"\n",
+ " output_dict = dict()\n",
+ " for key in current_dict.keys():\n",
+ " name = key.replace(\"layer_dict.\", \"\")\n",
+ " top_level = name.split(\".\")[0]\n",
+ " sub_level = \".\".join(name.split(\".\")[1:])\n",
+ "\n",
+ " if top_level not in output_dict:\n",
+ " if sub_level == \"\":\n",
+ " output_dict[top_level] = current_dict[key]\n",
+ " else:\n",
+ " output_dict[top_level] = {sub_level: current_dict[key]}\n",
+ " else:\n",
+ " new_item = {key: value for key, value in output_dict[top_level].items()}\n",
+ " new_item[sub_level] = current_dict[key]\n",
+ " output_dict[top_level] = new_item\n",
+ "\n",
+ " #print(current_dict.keys(), output_dict.keys())\n",
+ " return output_dict\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
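+  {
+   "cell_type": "code",
+   "metadata": {
+    "colab_type": "code",
+    "colab": {}
+   },
+   "source": [
+    "# Minimal illustration (a sketch, not from the original code) of what\n",
+    "# extract_top_level_dict() does: it strips the \"layer_dict.\" prefix and nests\n",
+    "# one level of a flat, dot-separated parameter dictionary. The parameter names\n",
+    "# below are made up for the example.\n",
+    "flat = {'layer_dict.conv0.weight': 0, 'layer_dict.conv0.bias': 1, 'linear.weight': 2}\n",
+    "pprint(extract_top_level_dict(current_dict=flat))\n",
+    "# expected: {'conv0': {'bias': 1, 'weight': 0}, 'linear': {'weight': 2}}"
+   ],
+   "execution_count": 0,
+   "outputs": []
+  },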
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "brbh1ISiqIwN",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaConv2dLayer(nn.Module):\n",
+ " def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_bias, groups=1, dilation_rate=1):\n",
+ " \"\"\"\n",
+    "        A MetaConv2D layer. Behaves like a standard Conv2D layer, but can additionally\n",
+    "        receive a parameter dictionary at the forward pass which allows the convolution to use external\n",
+ " weights instead of the internal ones stored in the conv layer. Useful for inner loop optimization in the meta\n",
+ " learning setting.\n",
+ " :param in_channels: Number of input channels\n",
+ " :param out_channels: Number of output channels\n",
+ " :param kernel_size: Convolutional kernel size\n",
+ " :param stride: Convolutional stride\n",
+ " :param padding: Convolution padding\n",
+ " :param use_bias: Boolean indicating whether to use a bias or not.\n",
+ " \"\"\"\n",
+ " super(MetaConv2dLayer, self).__init__()\n",
+ " num_filters = out_channels\n",
+ " self.stride = int(stride)\n",
+ " self.padding = int(padding)\n",
+ " self.dilation_rate = int(dilation_rate)\n",
+ " self.use_bias = use_bias\n",
+ " self.groups = int(groups)\n",
+ " self.weight = nn.Parameter(torch.empty(num_filters, in_channels, kernel_size, kernel_size))\n",
+ " nn.init.xavier_uniform_(self.weight)\n",
+ "\n",
+ " if self.use_bias:\n",
+ " self.bias = nn.Parameter(torch.zeros(num_filters))\n",
+ "\n",
+ " def forward(self, x, params=None):\n",
+ " \"\"\"\n",
+ " Applies a conv2D forward pass. If params are not None will use the passed params as the conv weights and biases\n",
+ " :param x: Input image batch.\n",
+ " :param params: If none, then conv layer will use the stored self.weights and self.bias, if they are not none\n",
+ " then the conv layer will use the passed params as its parameters.\n",
+ " :return: The output of a convolutional function.\n",
+ " \"\"\"\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ " if self.use_bias:\n",
+ " (weight, bias) = params[\"weight\"], params[\"bias\"]\n",
+ " else:\n",
+ " (weight) = params[\"weight\"]\n",
+ " bias = None\n",
+ " else:\n",
+ " #print(\"No inner loop params\")\n",
+ " if self.use_bias:\n",
+ " weight, bias = self.weight, self.bias\n",
+ " else:\n",
+ " weight = self.weight\n",
+ " bias = None\n",
+ "\n",
+ " out = F.conv2d(input=x, weight=weight, bias=bias, stride=self.stride,\n",
+ " padding=self.padding, dilation=self.dilation_rate, groups=self.groups)\n",
+ " return out\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
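+  {
+   "cell_type": "code",
+   "metadata": {
+    "colab_type": "code",
+    "colab": {}
+   },
+   "source": [
+    "# Small sketch (added for illustration) of the inner-loop trick MetaConv2dLayer\n",
+    "# supports: the same layer can run with its stored weights or with an external\n",
+    "# parameter dictionary. In MAML++ the external 'fast weights' come from gradient\n",
+    "# steps; here we just perturb the stored ones to show the mechanism.\n",
+    "meta_conv = MetaConv2dLayer(in_channels=3, out_channels=8, kernel_size=3,\n",
+    "                            stride=1, padding=1, use_bias=True)\n",
+    "x_demo = torch.randn(2, 3, 28, 28)\n",
+    "out_internal = meta_conv(x_demo)  # uses self.weight / self.bias\n",
+    "fast_weights = {'weight': meta_conv.weight - 0.01 * torch.randn_like(meta_conv.weight),\n",
+    "                'bias': meta_conv.bias.clone()}\n",
+    "out_external = meta_conv(x_demo, params=fast_weights)  # uses the passed params\n",
+    "print(out_internal.shape, out_external.shape)  # both (2, 8, 28, 28)"
+   ],
+   "execution_count": 0,
+   "outputs": []
+  },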
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "5_yy9Vluuu1E",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaBatchNormLayer(nn.Module):\n",
+ " def __init__(self, num_features, device, args, eps=1e-5, momentum=0.1, affine=True,\n",
+ " track_running_stats=True, meta_batch_norm=True, no_learnable_params=False,\n",
+ " use_per_step_bn_statistics=False):\n",
+ " \"\"\"\n",
+    "        A MetaBatchNorm layer. Behaves like a standard BatchNorm layer, but can additionally\n",
+    "        receive a parameter dictionary at the forward pass which allows the layer to use external\n",
+    "        weights instead of the internal ones it stores. Useful for inner loop optimization in the meta\n",
+    "        learning setting. Also able to store per-step running statistics and per-step beta and gamma.\n",
+ " :param num_features:\n",
+ " :param device:\n",
+ " :param args:\n",
+ " :param eps:\n",
+ " :param momentum:\n",
+ " :param affine:\n",
+ " :param track_running_stats:\n",
+ " :param meta_batch_norm:\n",
+ " :param no_learnable_params:\n",
+ " :param use_per_step_bn_statistics:\n",
+ " \"\"\"\n",
+ " super(MetaBatchNormLayer, self).__init__()\n",
+ " self.num_features = num_features\n",
+ " self.eps = eps\n",
+ "\n",
+ " self.affine = affine\n",
+ " self.track_running_stats = track_running_stats\n",
+ " self.meta_batch_norm = meta_batch_norm\n",
+ " self.num_features = num_features\n",
+ " self.device = device\n",
+ " self.use_per_step_bn_statistics = use_per_step_bn_statistics\n",
+ " self.args = args\n",
+ " self.learnable_gamma = self.args.learnable_bn_gamma\n",
+ " self.learnable_beta = self.args.learnable_bn_beta\n",
+ "\n",
+ " if use_per_step_bn_statistics:\n",
+ " self.running_mean = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features),\n",
+ " requires_grad=False)\n",
+ " self.running_var = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features),\n",
+ " requires_grad=False)\n",
+ " self.bias = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features),\n",
+ " requires_grad=self.learnable_beta)\n",
+ " self.weight = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features),\n",
+ " requires_grad=self.learnable_gamma)\n",
+ " else:\n",
+ " self.running_mean = nn.Parameter(torch.zeros(num_features), requires_grad=False)\n",
+    "            self.running_var = nn.Parameter(torch.ones(num_features), requires_grad=False)\n",
+ " self.bias = nn.Parameter(torch.zeros(num_features),\n",
+ " requires_grad=self.learnable_beta)\n",
+ " self.weight = nn.Parameter(torch.ones(num_features),\n",
+ " requires_grad=self.learnable_gamma)\n",
+ "\n",
+ " if self.args.enable_inner_loop_optimizable_bn_params:\n",
+ " self.bias = nn.Parameter(torch.zeros(num_features),\n",
+ " requires_grad=self.learnable_beta)\n",
+ " self.weight = nn.Parameter(torch.ones(num_features),\n",
+ " requires_grad=self.learnable_gamma)\n",
+ "\n",
+ " self.backup_running_mean = torch.zeros(self.running_mean.shape)\n",
+ " self.backup_running_var = torch.ones(self.running_var.shape)\n",
+ "\n",
+ " self.momentum = momentum\n",
+ "\n",
+ " def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False):\n",
+ " \"\"\"\n",
+    "        Forward propagates by applying a batch norm function. If params are none then internal params are used.\n",
+ " Otherwise passed params will be used to execute the function.\n",
+    "        :param input: input data batch; can be of any size.\n",
+ " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n",
+ " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n",
+ " :param params: A dictionary containing 'weight' and 'bias'.\n",
+ " :param training: Whether this is currently the training or evaluation phase.\n",
+ " :param backup_running_statistics: Whether to backup the running statistics. This is used\n",
+ " at evaluation time, when after the pass is complete we want to throw away the collected validation stats.\n",
+ " :return: The result of the batch norm operation.\n",
+ " \"\"\"\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ " (weight, bias) = params[\"weight\"], params[\"bias\"]\n",
+ " #print(num_step, params['weight'])\n",
+ " else:\n",
+ " #print(num_step, \"no params\")\n",
+ " weight, bias = self.weight, self.bias\n",
+ "\n",
+ " if self.use_per_step_bn_statistics:\n",
+ " running_mean = self.running_mean[num_step]\n",
+ " running_var = self.running_var[num_step]\n",
+ " if params is None:\n",
+ " if not self.args.enable_inner_loop_optimizable_bn_params:\n",
+ " bias = self.bias[num_step]\n",
+ " weight = self.weight[num_step]\n",
+ " else:\n",
+ " running_mean = None\n",
+ " running_var = None\n",
+ "\n",
+ "\n",
+ " if backup_running_statistics and self.use_per_step_bn_statistics:\n",
+ " self.backup_running_mean.data = copy(self.running_mean.data)\n",
+ " self.backup_running_var.data = copy(self.running_var.data)\n",
+ "\n",
+ " momentum = self.momentum\n",
+ "\n",
+ " output = F.batch_norm(input, running_mean, running_var, weight, bias,\n",
+ " training=True, momentum=momentum, eps=self.eps)\n",
+ "\n",
+ " return output\n",
+ "\n",
+ " def restore_backup_stats(self):\n",
+ " \"\"\"\n",
+ " Resets batch statistics to their backup values which are collected after each forward pass.\n",
+ " \"\"\"\n",
+ " if self.use_per_step_bn_statistics:\n",
+ " self.running_mean = nn.Parameter(self.backup_running_mean.to(device=self.device), requires_grad=False)\n",
+ " self.running_var = nn.Parameter(self.backup_running_var.to(device=self.device), requires_grad=False)\n",
+ "\n",
+ " def extra_repr(self):\n",
+ " return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \\\n",
+ " 'track_running_stats={track_running_stats}'.format(**self.__dict__)\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
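+  {
+   "cell_type": "code",
+   "metadata": {
+    "colab_type": "code",
+    "colab": {}
+   },
+   "source": [
+    "# Hedged smoke test (not in the original notebook) for the per-step batch norm\n",
+    "# statistics: with use_per_step_bn_statistics=True the layer keeps one row of\n",
+    "# (running_mean, running_var, weight, bias) per inner-loop step and indexes it\n",
+    "# with num_step. The SimpleNamespace below fills in only the args fields this\n",
+    "# class actually reads.\n",
+    "from types import SimpleNamespace\n",
+    "\n",
+    "bn_args = SimpleNamespace(learnable_bn_gamma=True, learnable_bn_beta=True,\n",
+    "                          enable_inner_loop_optimizable_bn_params=False,\n",
+    "                          number_of_training_steps_per_iter=5)\n",
+    "meta_bn = MetaBatchNormLayer(num_features=8, device=torch.device('cpu'), args=bn_args,\n",
+    "                             use_per_step_bn_statistics=True)\n",
+    "h_demo = torch.randn(4, 8, 14, 14)\n",
+    "print(meta_bn(h_demo, num_step=0).shape)   # (4, 8, 14, 14)\n",
+    "print(meta_bn.running_mean.shape)          # (5, 8): one row per inner-loop step"
+   ],
+   "execution_count": 0,
+   "outputs": []
+  },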
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LsbTVhQ9vply",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaLayerNormLayer(nn.Module):\n",
+ " def __init__(self, input_feature_shape, eps=1e-5, elementwise_affine=True):\n",
+ " \"\"\"\n",
+ " A MetaLayerNorm layer. A layer that applies the same functionality as a layer norm layer with the added\n",
+ " capability of being able to receive params at inference time to use instead of the internal ones. As well as\n",
+ " being able to use its own internal weights.\n",
+ " :param input_feature_shape: The input shape without the batch dimension, e.g. c, h, w\n",
+    "        :param eps: Epsilon added to the variance for numerical stability (protects against division by zero)\n",
+ " :param elementwise_affine: Whether to learn a multiplicative interaction parameter 'w' in addition to\n",
+ " the biases.\n",
+ " \"\"\"\n",
+ " super(MetaLayerNormLayer, self).__init__()\n",
+ " if isinstance(input_feature_shape, numbers.Integral):\n",
+ " input_feature_shape = (input_feature_shape,)\n",
+ " self.normalized_shape = torch.Size(input_feature_shape)\n",
+ " self.eps = eps\n",
+ " self.elementwise_affine = elementwise_affine\n",
+ " if self.elementwise_affine:\n",
+ " self.weight = nn.Parameter(torch.Tensor(*input_feature_shape), requires_grad=False)\n",
+ " self.bias = nn.Parameter(torch.Tensor(*input_feature_shape))\n",
+ " else:\n",
+ " self.register_parameter('weight', None)\n",
+ " self.register_parameter('bias', None)\n",
+ " self.reset_parameters()\n",
+ "\n",
+ " def reset_parameters(self):\n",
+ " \"\"\"\n",
+ " Reset parameters to their initialization values.\n",
+ " \"\"\"\n",
+ " if self.elementwise_affine:\n",
+ " self.weight.data.fill_(1)\n",
+ " self.bias.data.zero_()\n",
+ "\n",
+ " def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False):\n",
+ " \"\"\"\n",
+ " Forward propagates by applying a layer norm function. If params are none then internal params are used.\n",
+ " Otherwise passed params will be used to execute the function.\n",
+    "        :param input: input data batch; can be of any size.\n",
+ " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n",
+ " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n",
+ " :param params: A dictionary containing 'weight' and 'bias'.\n",
+ " :param training: Whether this is currently the training or evaluation phase.\n",
+ " :param backup_running_statistics: Whether to backup the running statistics. This is used\n",
+ " at evaluation time, when after the pass is complete we want to throw away the collected validation stats.\n",
+ " :return: The result of the batch norm operation.\n",
+ " \"\"\"\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ " bias = params[\"bias\"]\n",
+ " else:\n",
+ " bias = self.bias\n",
+ " #print('no inner loop params', self)\n",
+ "\n",
+ " return F.layer_norm(\n",
+ " input, self.normalized_shape, self.weight, bias, self.eps)\n",
+ "\n",
+ " def restore_backup_stats(self):\n",
+ " pass\n",
+ "\n",
+ " def extra_repr(self):\n",
+ " return '{normalized_shape}, eps={eps}, ' \\\n",
+ " 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "-Up-rbPipql9",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaConvNormLayerReLU(nn.Module):\n",
+ " def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True,\n",
+ " meta_layer=True, no_bn_learnable_params=False, device=None):\n",
+ " \"\"\"\n",
+ " Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order.\n",
+ " :param args: A named tuple containing the system's hyperparameters.\n",
+ " :param device: The device to run the layer on.\n",
+ " :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm'\n",
+ " :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm,\n",
+ " meta-conv etc.\n",
+ " :param input_shape: The image input shape in the form (b, c, h, w)\n",
+ " :param num_filters: number of filters for convolutional layer\n",
+ " :param kernel_size: the kernel size of the convolutional layer\n",
+ " :param stride: the stride of the convolutional layer\n",
+ " :param padding: the bias of the convolutional layer\n",
+ " :param use_bias: whether the convolutional layer utilizes a bias\n",
+ " \"\"\"\n",
+ " super(MetaConvNormLayerReLU, self).__init__()\n",
+ " self.normalization = normalization\n",
+ " self.use_per_step_bn_statistics = args.per_step_bn_statistics\n",
+ " self.input_shape = input_shape\n",
+ " self.args = args\n",
+ " self.num_filters = num_filters\n",
+ " self.kernel_size = kernel_size\n",
+ " self.stride = stride\n",
+ " self.padding = padding\n",
+ " self.use_bias = use_bias\n",
+ " self.meta_layer = meta_layer\n",
+ " self.no_bn_learnable_params = no_bn_learnable_params\n",
+ " self.device = device\n",
+ " self.layer_dict = nn.ModuleDict()\n",
+ " self.build_block()\n",
+ "\n",
+ " def build_block(self):\n",
+ "\n",
+ " x = torch.zeros(self.input_shape)\n",
+ "\n",
+ " out = x\n",
+ "\n",
+ " self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters,\n",
+ " kernel_size=self.kernel_size,\n",
+ " stride=self.stride, padding=self.padding, use_bias=self.use_bias)\n",
+ "\n",
+ "\n",
+ "\n",
+ " out = self.conv(out)\n",
+ "\n",
+ " if self.normalization:\n",
+ " if self.args.norm_layer == \"batch_norm\":\n",
+ " self.norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True,\n",
+ " meta_batch_norm=self.meta_layer,\n",
+ " no_learnable_params=self.no_bn_learnable_params,\n",
+ " device=self.device,\n",
+ " use_per_step_bn_statistics=self.use_per_step_bn_statistics,\n",
+ " args=self.args)\n",
+ " elif self.args.norm_layer == \"layer_norm\":\n",
+ " self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:])\n",
+ "\n",
+ " out = self.norm_layer(out, num_step=0)\n",
+ "\n",
+ " out = F.leaky_relu(out)\n",
+ "\n",
+ " print(out.shape)\n",
+ "\n",
+ " def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n",
+ " \"\"\"\n",
+ " Forward propagates by applying the function. If params are none then internal params are used.\n",
+ " Otherwise passed params will be used to execute the function.\n",
+ " :param input: input data batch, size either can be any.\n",
+ " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n",
+ " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n",
+ " :param params: A dictionary containing 'weight' and 'bias'.\n",
+ " :param training: Whether this is currently the training or evaluation phase.\n",
+ " :param backup_running_statistics: Whether to backup the running statistics. This is used\n",
+ " at evaluation time, when after the pass is complete we want to throw away the collected validation stats.\n",
+ " :return: The result of the batch norm operation.\n",
+ " \"\"\"\n",
+ " batch_norm_params = None\n",
+ " conv_params = None\n",
+ " activation_function_pre_params = None\n",
+ "\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ "\n",
+ " if self.normalization:\n",
+ " if 'norm_layer' in params:\n",
+ " batch_norm_params = params['norm_layer']\n",
+ "\n",
+ " if 'activation_function_pre' in params:\n",
+ " activation_function_pre_params = params['activation_function_pre']\n",
+ "\n",
+ " conv_params = params['conv']\n",
+ "\n",
+ " out = x\n",
+ "\n",
+ "\n",
+ " out = self.conv(out, params=conv_params)\n",
+ "\n",
+ " if self.normalization:\n",
+ " out = self.norm_layer.forward(out, num_step=num_step,\n",
+ " params=batch_norm_params, training=training,\n",
+ " backup_running_statistics=backup_running_statistics)\n",
+ "\n",
+ " out = F.leaky_relu(out)\n",
+ "\n",
+ " return out\n",
+ "\n",
+ " def restore_backup_stats(self):\n",
+ " \"\"\"\n",
+ " Restore stored statistics from the backup, replacing the current ones.\n",
+ " \"\"\"\n",
+ " if self.normalization:\n",
+ " self.norm_layer.restore_backup_stats()\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
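+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# A minimal sketch (not part of the original pipeline) exercising one conv block\n",
+ "# on its own. toy_args is a hypothetical stand-in for the experiment's args\n",
+ "# namedtuple; with norm_layer='layer_norm' only the two fields below are read,\n",
+ "# and the MetaConv2dLayer defined earlier in this notebook is assumed.\n",
+ "from types import SimpleNamespace\n",
+ "\n",
+ "toy_args = SimpleNamespace(per_step_bn_statistics=False, norm_layer=\"layer_norm\")\n",
+ "block = MetaConvNormLayerReLU(input_shape=(2, 3, 28, 28), num_filters=8,\n",
+ "                              kernel_size=3, stride=1, padding=1, use_bias=True,\n",
+ "                              args=toy_args, normalization=True)\n",
+ "out = block(torch.zeros(2, 3, 28, 28), num_step=0)\n",
+ "print(out.shape)  # expected: torch.Size([2, 8, 28, 28])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },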
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qujCuA12uLby",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaNormLayerConvReLU(nn.Module):\n",
+ " def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True,\n",
+ " meta_layer=True, no_bn_learnable_params=False, device=None):\n",
+ " \"\"\"\n",
+ " Initializes a BatchNorm->Conv->ReLU layer which applies those operation in that order.\n",
+ " :param args: A named tuple containing the system's hyperparameters.\n",
+ " :param device: The device to run the layer on.\n",
+ " :param normalization: The type of normalization to use 'batch_norm' or 'layer_norm'\n",
+ " :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm,\n",
+ " meta-conv etc.\n",
+ " :param input_shape: The image input shape in the form (b, c, h, w)\n",
+ " :param num_filters: number of filters for convolutional layer\n",
+ " :param kernel_size: the kernel size of the convolutional layer\n",
+ " :param stride: the stride of the convolutional layer\n",
+ " :param padding: the bias of the convolutional layer\n",
+ " :param use_bias: whether the convolutional layer utilizes a bias\n",
+ " \"\"\"\n",
+ " super(MetaNormLayerConvReLU, self).__init__()\n",
+ " self.normalization = normalization\n",
+ " self.use_per_step_bn_statistics = args.per_step_bn_statistics\n",
+ " self.input_shape = input_shape\n",
+ " self.args = args\n",
+ " self.num_filters = num_filters\n",
+ " self.kernel_size = kernel_size\n",
+ " self.stride = stride\n",
+ " self.padding = padding\n",
+ " self.use_bias = use_bias\n",
+ " self.meta_layer = meta_layer\n",
+ " self.no_bn_learnable_params = no_bn_learnable_params\n",
+ " self.device = device\n",
+ " self.layer_dict = nn.ModuleDict()\n",
+ " self.build_block()\n",
+ "\n",
+ " def build_block(self):\n",
+ "\n",
+ " x = torch.zeros(self.input_shape)\n",
+ "\n",
+ " out = x\n",
+ " if self.normalization:\n",
+ " if self.args.norm_layer == \"batch_norm\":\n",
+ " self.norm_layer = MetaBatchNormLayer(self.input_shape[1], track_running_stats=True,\n",
+ " meta_batch_norm=self.meta_layer,\n",
+ " no_learnable_params=self.no_bn_learnable_params,\n",
+ " device=self.device,\n",
+ " use_per_step_bn_statistics=self.use_per_step_bn_statistics,\n",
+ " args=self.args)\n",
+ " elif self.args.norm_layer == \"layer_norm\":\n",
+ " self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:])\n",
+ "\n",
+ " out = self.norm_layer.forward(out, num_step=0)\n",
+ " self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters,\n",
+ " kernel_size=self.kernel_size,\n",
+ " stride=self.stride, padding=self.padding, use_bias=self.use_bias)\n",
+ "\n",
+ "\n",
+ " self.layer_dict['activation_function_pre'] = nn.LeakyReLU()\n",
+ "\n",
+ "\n",
+ " out = self.layer_dict['activation_function_pre'].forward(self.conv.forward(out))\n",
+ " print(out.shape)\n",
+ "\n",
+ " def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n",
+ " \"\"\"\n",
+ " Forward propagates by applying the function. If params are none then internal params are used.\n",
+ " Otherwise passed params will be used to execute the function.\n",
+ " :param input: input data batch, size either can be any.\n",
+ " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n",
+ " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n",
+ " :param params: A dictionary containing 'weight' and 'bias'.\n",
+ " :param training: Whether this is currently the training or evaluation phase.\n",
+ " :param backup_running_statistics: Whether to backup the running statistics. This is used\n",
+ " at evaluation time, when after the pass is complete we want to throw away the collected validation stats.\n",
+ " :return: The result of the batch norm operation.\n",
+ " \"\"\"\n",
+ " batch_norm_params = None\n",
+ "\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ "\n",
+ " if self.normalization:\n",
+ " if 'norm_layer' in params:\n",
+ " batch_norm_params = params['norm_layer']\n",
+ "\n",
+ " conv_params = params['conv']\n",
+ " else:\n",
+ " conv_params = None\n",
+ " #print('no inner loop params', self)\n",
+ "\n",
+ " out = x\n",
+ "\n",
+ " if self.normalization:\n",
+ " out = self.norm_layer.forward(out, num_step=num_step,\n",
+ " params=batch_norm_params, training=training,\n",
+ " backup_running_statistics=backup_running_statistics)\n",
+ "\n",
+ " out = self.conv.forward(out, params=conv_params)\n",
+ " out = self.layer_dict['activation_function_pre'].forward(out)\n",
+ "\n",
+ " return out\n",
+ "\n",
+ " def restore_backup_stats(self):\n",
+ " \"\"\"\n",
+ " Restore stored statistics from the backup, replacing the current ones.\n",
+ " \"\"\"\n",
+ " if self.normalization:\n",
+ " self.norm_layer.restore_backup_stats()\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LLwmEEAgv5EO",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MetaLinearLayer(nn.Module):\n",
+ " def __init__(self, input_shape, num_filters, use_bias):\n",
+ " \"\"\"\n",
+ " A MetaLinear layer. Applies the same functionality of a standard linearlayer with the added functionality of\n",
+ " being able to receive a parameter dictionary at the forward pass which allows the convolution to use external\n",
+ " weights instead of the internal ones stored in the linear layer. Useful for inner loop optimization in the meta\n",
+ " learning setting.\n",
+ " :param input_shape: The shape of the input data, in the form (b, f)\n",
+ " :param num_filters: Number of output filters\n",
+ " :param use_bias: Whether to use biases or not.\n",
+ " \"\"\"\n",
+ " super(MetaLinearLayer, self).__init__()\n",
+ " b, c = input_shape\n",
+ "\n",
+ " self.use_bias = use_bias\n",
+ " self.weights = nn.Parameter(torch.ones(num_filters, c))\n",
+ " nn.init.xavier_uniform_(self.weights)\n",
+ " if self.use_bias:\n",
+ " self.bias = nn.Parameter(torch.zeros(num_filters))\n",
+ "\n",
+ " def forward(self, x, params=None):\n",
+ " \"\"\"\n",
+ " Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used.\n",
+ " Otherwise passed params will be used to execute the function.\n",
+ " :param x: Input data batch, in the form (b, f)\n",
+ " :param params: A dictionary containing 'weights' and 'bias'. If params are none then internal params are used.\n",
+ " Otherwise the external are used.\n",
+ " :return: The result of the linear function.\n",
+ " \"\"\"\n",
+ " if params is not None:\n",
+ " params = extract_top_level_dict(current_dict=params)\n",
+ " if self.use_bias:\n",
+ " (weight, bias) = params[\"weights\"], params[\"bias\"]\n",
+ " else:\n",
+ " (weight) = params[\"weights\"]\n",
+ " bias = None\n",
+ " else:\n",
+ " pass\n",
+ " #print('no inner loop params', self)\n",
+ "\n",
+ " if self.use_bias:\n",
+ " weight, bias = self.weights, self.bias\n",
+ " else:\n",
+ " weight = self.weights\n",
+ " bias = None\n",
+ " # print(x.shape)\n",
+ " out = F.linear(input=x, weight=weight, bias=bias)\n",
+ " return out"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
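+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# A minimal sketch (not part of the original pipeline) of the external-parameter\n",
+ "# mechanism: the same forward pass can run with the layer's internal weights or\n",
+ "# with a dictionary of \"fast weights\", as produced by an inner-loop update. The\n",
+ "# fake update below is purely illustrative; extract_top_level_dict (defined\n",
+ "# earlier in this notebook) is assumed to route the dictionary keys.\n",
+ "layer = MetaLinearLayer(input_shape=(4, 16), num_filters=5, use_bias=True)\n",
+ "x = torch.randn(4, 16)\n",
+ "out_internal = layer(x)  # uses the layer's internal weights\n",
+ "fast_weights = {name: p - 0.01 * torch.ones_like(p)  # pretend inner-loop step\n",
+ "                for name, p in layer.named_parameters()}\n",
+ "out_fast = layer(x, params=fast_weights)  # uses the passed weights instead\n",
+ "print(out_internal.shape, out_fast.shape)  # both torch.Size([4, 5])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },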
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "FoWzgA2qtpT9",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class VGGReLUNormNetwork(nn.Module):\n",
+ " def __init__(self, im_shape, num_output_classes, args, device, meta_classifier=True):\n",
+ " \"\"\"\n",
+ " Builds a multilayer convolutional network. It also provides functionality for passing external parameters\n",
+ " to be\n",
+ " used at inference time. Enables inner loop optimization readily.\n",
+ " :param im_shape: The input image batch shape.\n",
+ " :param num_output_classes: The number of output classes of the network.\n",
+ " :param args: A named tuple containing the system's hyperparameters.\n",
+ " :param device: The device to run this on.\n",
+ " :param meta_classifier: A flag indicating whether the system's meta-learning (inner-loop) functionalities\n",
+ " should\n",
+ " be enabled.\n",
+ " \"\"\"\n",
+ " super(VGGReLUNormNetwork, self).__init__()\n",
+ " b, c, self.h, self.w = im_shape\n",
+ " self.device = device\n",
+ " self.total_layers = 0\n",
+ " self.args = args\n",
+ " self.upscale_shapes = []\n",
+ " self.cnn_filters = args.cnn_num_filters\n",
+ " self.input_shape = list(im_shape)\n",
+ " self.num_stages = args.num_stages\n",
+ " self.num_output_classes = num_output_classes\n",
+ "\n",
+ " if args.max_pooling:\n",
+ " print(\"Using max pooling\")\n",
+ " self.conv_stride = 1\n",
+ " else:\n",
+ " print(\"Using strided convolutions\")\n",
+ " self.conv_stride = 2\n",
+ " self.meta_classifier = meta_classifier\n",
+ "\n",
+ " self.build_network()\n",
+ " print(\"meta network params\")\n",
+ " for name, param in self.named_parameters():\n",
+ " print(name, param.shape)\n",
+ "\n",
+ " def build_network(self):\n",
+ " \"\"\"\n",
+ " Builds the network before inference is required by creating some dummy inputs with the same input as the\n",
+ " self.im_shape tuple. Then passes that through the network and dynamically computes input shapes and\n",
+ " sets output shapes for each layer.\n",
+ " \"\"\"\n",
+ " x = torch.zeros(self.input_shape)\n",
+ " out = x\n",
+ " self.layer_dict = nn.ModuleDict()\n",
+ " self.upscale_shapes.append(x.shape)\n",
+ "\n",
+ " for i in range(self.num_stages):\n",
+ " self.layer_dict['conv{}'.format(i)] = MetaConvNormLayerReLU(input_shape=out.shape,\n",
+ " num_filters=self.cnn_filters,\n",
+ " kernel_size=3, stride=self.conv_stride,\n",
+ " padding=self.args.conv_padding,\n",
+ " use_bias=True, args=self.args,\n",
+ " normalization=True,\n",
+ " meta_layer=self.meta_classifier,\n",
+ " no_bn_learnable_params=False,\n",
+ " device=self.device)\n",
+ " out = self.layer_dict['conv{}'.format(i)](out, training=True, num_step=0)\n",
+ "\n",
+ " if self.args.max_pooling:\n",
+ " out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0)\n",
+ "\n",
+ "\n",
+ " if not self.args.max_pooling:\n",
+ " out = F.avg_pool2d(out, out.shape[2])\n",
+ "\n",
+ " self.encoder_features_shape = list(out.shape)\n",
+ " out = out.view(out.shape[0], -1)\n",
+ "\n",
+ " self.layer_dict['linear'] = MetaLinearLayer(input_shape=(out.shape[0], np.prod(out.shape[1:])),\n",
+ " num_filters=self.num_output_classes, use_bias=True)\n",
+ "\n",
+ " out = self.layer_dict['linear'](out)\n",
+ " print(\"VGGNetwork build\", out.shape)\n",
+ "\n",
+ " def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n",
+ " \"\"\"\n",
+ " Forward propages through the network. If any params are passed then they are used instead of stored params.\n",
+ " :param x: Input image batch.\n",
+ " :param num_step: The current inner loop step number\n",
+ " :param params: If params are None then internal parameters are used. If params are a dictionary with keys the\n",
+ " same as the layer names then they will be used instead.\n",
+ " :param training: Whether this is training (True) or eval time.\n",
+ " :param backup_running_statistics: Whether to backup the running statistics in their backup store. Which is\n",
+ " then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics)\n",
+ " :return: Logits of shape b, num_output_classes.\n",
+ " \"\"\"\n",
+ " param_dict = dict()\n",
+ "\n",
+ " if params is not None:\n",
+ " param_dict = extract_top_level_dict(current_dict=params)\n",
+ "\n",
+ " # print('top network', param_dict.keys())\n",
+ " for name, param in self.layer_dict.named_parameters():\n",
+ " path_bits = name.split(\".\")\n",
+ " layer_name = path_bits[0]\n",
+ " if layer_name not in param_dict:\n",
+ " param_dict[layer_name] = None\n",
+ "\n",
+ " out = x\n",
+ "\n",
+ " for i in range(self.num_stages):\n",
+ " out = self.layer_dict['conv{}'.format(i)](out, params=param_dict['conv{}'.format(i)], training=training,\n",
+ " backup_running_statistics=backup_running_statistics,\n",
+ " num_step=num_step)\n",
+ " if self.args.max_pooling:\n",
+ " out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0)\n",
+ "\n",
+ " if not self.args.max_pooling:\n",
+ " out = F.avg_pool2d(out, out.shape[2])\n",
+ "\n",
+ " out = out.view(out.size(0), -1)\n",
+ " out = self.layer_dict['linear'](out, param_dict['linear'])\n",
+ "\n",
+ " return out\n",
+ "\n",
+ " def zero_grad(self, params=None):\n",
+ " if params is None:\n",
+ " for param in self.parameters():\n",
+ " if param.requires_grad == True:\n",
+ " if param.grad is not None:\n",
+ " if torch.sum(param.grad) > 0:\n",
+ " print(param.grad)\n",
+ " param.grad.zero_()\n",
+ " else:\n",
+ " for name, param in params.items():\n",
+ " if param.requires_grad == True:\n",
+ " if param.grad is not None:\n",
+ " if torch.sum(param.grad) > 0:\n",
+ " print(param.grad)\n",
+ " param.grad.zero_()\n",
+ " params[name].grad = None\n",
+ "\n",
+ " def restore_backup_stats(self):\n",
+ " \"\"\"\n",
+ " Reset stored batch statistics from the stored backup.\n",
+ " \"\"\"\n",
+ " for i in range(self.num_stages):\n",
+ " self.layer_dict['conv{}'.format(i)].restore_backup_stats()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
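+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# A minimal sketch (hypothetical toy_args standing in for the experiment's args\n",
+ "# namedtuple) that builds the base network on CPU and runs a dummy batch; with\n",
+ "# norm_layer='layer_norm' only the fields below are consulted.\n",
+ "from types import SimpleNamespace\n",
+ "\n",
+ "toy_args = SimpleNamespace(cnn_num_filters=8, num_stages=4, max_pooling=True,\n",
+ "                           conv_padding=1, norm_layer=\"layer_norm\",\n",
+ "                           per_step_bn_statistics=False)\n",
+ "net = VGGReLUNormNetwork(im_shape=(2, 1, 28, 28), num_output_classes=5,\n",
+ "                         args=toy_args, device=torch.device(\"cpu\"),\n",
+ "                         meta_classifier=True)\n",
+ "logits = net(torch.zeros(2, 1, 28, 28), num_step=0)\n",
+ "print(logits.shape)  # expected: torch.Size([2, 5])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },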
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "-P0Tewg7iNRF",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class GradientDescentLearningRule(nn.Module):\n",
+ " \"\"\"Simple (stochastic) gradient descent learning rule.\n",
+ " For a scalar error function `E(p[0], p_[1] ... )` of some set of\n",
+ " potentially multidimensional parameters this attempts to find a local\n",
+ " minimum of the loss function by applying updates to each parameter of the\n",
+ " form\n",
+ " p[i] := p[i] - learning_rate * dE/dp[i]\n",
+ " With `learning_rate` a positive scaling parameter.\n",
+ " The error function used in successive applications of these updates may be\n",
+ " a stochastic estimator of the true error function (e.g. when the error with\n",
+ " respect to only a subset of data-points is calculated) in which case this\n",
+ " will correspond to a stochastic gradient descent learning rule.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, device, learning_rate=1e-3):\n",
+ " \"\"\"Creates a new learning rule object.\n",
+ " Args:\n",
+ " learning_rate: A postive scalar to scale gradient updates to the\n",
+ " parameters by. This needs to be carefully set - if too large\n",
+ " the learning dynamic will be unstable and may diverge, while\n",
+ " if set too small learning will proceed very slowly.\n",
+ " \"\"\"\n",
+ " super(GradientDescentLearningRule, self).__init__()\n",
+ " assert learning_rate > 0., 'learning_rate should be positive.'\n",
+ " self.learning_rate = torch.ones(1) * learning_rate\n",
+ " self.learning_rate.to(device)\n",
+ "\n",
+ " def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9):\n",
+ " \"\"\"Applies a single gradient descent update to all parameters.\n",
+ " All parameter updates are performed using in-place operations and so\n",
+ " nothing is returned.\n",
+ " Args:\n",
+ " grads_wrt_params: A list of gradients of the scalar loss function\n",
+ " with respect to each of the parameters passed to `initialise`\n",
+ " previously, with this list expected to be in the same order.\n",
+ " \"\"\"\n",
+ " updated_names_weights_dict = dict()\n",
+ " for key in names_weights_dict.keys():\n",
+ " updated_names_weights_dict[key] = names_weights_dict[key] - self.learning_rate * \\\n",
+ " names_grads_wrt_params_dict[\n",
+ " key]\n",
+ "\n",
+ " return updated_names_weights_dict\n",
+ "\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "t4ogOl5vqpqM",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class LSLRGradientDescentLearningRule(nn.Module):\n",
+ " \"\"\"Simple (stochastic) gradient descent learning rule.\n",
+ " For a scalar error function `E(p[0], p_[1] ... )` of some set of\n",
+ " potentially multidimensional parameters this attempts to find a local\n",
+ " minimum of the loss function by applying updates to each parameter of the\n",
+ " form\n",
+ " p[i] := p[i] - learning_rate * dE/dp[i]\n",
+ " With `learning_rate` a positive scaling parameter.\n",
+ " The error function used in successive applications of these updates may be\n",
+ " a stochastic estimator of the true error function (e.g. when the error with\n",
+ " respect to only a subset of data-points is calculated) in which case this\n",
+ " will correspond to a stochastic gradient descent learning rule.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, init_learning_rate=1e-3):\n",
+ " \"\"\"Creates a new learning rule object.\n",
+ " Args:\n",
+ " init_learning_rate: A postive scalar to scale gradient updates to the\n",
+ " parameters by. This needs to be carefully set - if too large\n",
+ " the learning dynamic will be unstable and may diverge, while\n",
+ " if set too small learning will proceed very slowly.\n",
+ " \"\"\"\n",
+ " super(LSLRGradientDescentLearningRule, self).__init__()\n",
+ " print(init_learning_rate)\n",
+ " assert init_learning_rate > 0., 'learning_rate should be positive.'\n",
+ "\n",
+ " self.init_learning_rate = torch.ones(1) * init_learning_rate\n",
+ " self.init_learning_rate.to(device)\n",
+ " self.total_num_inner_loop_steps = total_num_inner_loop_steps\n",
+ " self.use_learnable_learning_rates = use_learnable_learning_rates\n",
+ "\n",
+ " def initialise(self, names_weights_dict):\n",
+ " self.names_learning_rates_dict = nn.ParameterDict()\n",
+ " for idx, (key, param) in enumerate(names_weights_dict.items()):\n",
+ " self.names_learning_rates_dict[key.replace(\".\", \"-\")] = nn.Parameter(\n",
+ " data=torch.ones(self.total_num_inner_loop_steps + 1) * self.init_learning_rate,\n",
+ " requires_grad=self.use_learnable_learning_rates)\n",
+ "\n",
+ " def reset(self):\n",
+ "\n",
+ " # for key, param in self.names_learning_rates_dict.items():\n",
+ " # param.fill_(self.init_learning_rate)\n",
+ " pass\n",
+ "\n",
+ " def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.1):\n",
+ " \"\"\"Applies a single gradient descent update to all parameters.\n",
+ " All parameter updates are performed using in-place operations and so\n",
+ " nothing is returned.\n",
+ " Args:\n",
+ " grads_wrt_params: A list of gradients of the scalar loss function\n",
+ " with respect to each of the parameters passed to `initialise`\n",
+ " previously, with this list expected to be in the same order.\n",
+ " \"\"\"\n",
+ " updated_names_weights_dict = dict()\n",
+ " for key in names_grads_wrt_params_dict.keys():\n",
+ " updated_names_weights_dict[key] = names_weights_dict[key] - \\\n",
+ " self.names_learning_rates_dict[key.replace(\".\", \"-\")][num_step] \\\n",
+ " * names_grads_wrt_params_dict[\n",
+ " key]\n",
+ "\n",
+ " return updated_names_weights_dict"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
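+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# A minimal sketch (hypothetical parameter name and shapes) of the LSLR rule:\n",
+ "# every named parameter gets its own learnable learning rate for each\n",
+ "# inner-loop step, and update_params indexes that vector with the step number.\n",
+ "lslr = LSLRGradientDescentLearningRule(device=torch.device(\"cpu\"),\n",
+ "                                       total_num_inner_loop_steps=5,\n",
+ "                                       use_learnable_learning_rates=True,\n",
+ "                                       init_learning_rate=0.1)\n",
+ "weights = {\"layer_dict.linear.weights\": torch.randn(5, 8, requires_grad=True)}\n",
+ "lslr.initialise(names_weights_dict=weights)\n",
+ "grads = {k: torch.ones_like(v) for k, v in weights.items()}\n",
+ "updated = lslr.update_params(names_weights_dict=weights,\n",
+ "                             names_grads_wrt_params_dict=grads, num_step=0)\n",
+ "print(updated[\"layer_dict.linear.weights\"].shape)  # torch.Size([5, 8])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },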
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nJtmTvZ8nqYN",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def set_torch_seed(seed):\n",
+ " \"\"\"\n",
+ " Sets the pytorch seeds for current experiment run\n",
+ " :param seed: The seed (int)\n",
+ " :return: A random number generator to use\n",
+ " \"\"\"\n",
+ " rng = np.random.RandomState(seed=seed)\n",
+ " torch_seed = rng.randint(0, 999999)\n",
+ " torch.manual_seed(seed=torch_seed)\n",
+ "\n",
+ " return rng"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "iS8_p22sp3hW",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class MAMLFewShotClassifier(nn.Module):\n",
+ " def __init__(self, im_shape, device, args):\n",
+ " \"\"\"\n",
+ " Initializes a MAML few shot learning system\n",
+ " :param im_shape: The images input size, in batch, c, h, w shape\n",
+ " :param device: The device to use to use the model on.\n",
+ " :param args: A namedtuple of arguments specifying various hyperparameters.\n",
+ " \"\"\"\n",
+ " super(MAMLFewShotClassifier, self).__init__()\n",
+ " self.args = args\n",
+ " self.device = device\n",
+ " self.batch_size = args.batch_size\n",
+ " self.use_cuda = args.use_cuda\n",
+ " self.im_shape = im_shape\n",
+ " self.current_epoch = 0\n",
+ "\n",
+ " self.rng = set_torch_seed(seed=args.seed)\n",
+ " self.classifier = VGGReLUNormNetwork(im_shape=self.im_shape, num_output_classes=self.args.\n",
+ " num_classes_per_set,\n",
+ " args=args, device=device, meta_classifier=True).to(device=self.device)\n",
+ " self.task_learning_rate = args.task_learning_rate\n",
+ "\n",
+ " self.inner_loop_optimizer = LSLRGradientDescentLearningRule(device=device,\n",
+ " init_learning_rate=self.task_learning_rate,\n",
+ " total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter,\n",
+ " use_learnable_learning_rates=self.args.learnable_per_layer_per_step_inner_loop_learning_rate)\n",
+ " self.inner_loop_optimizer.initialise(\n",
+ " names_weights_dict=self.get_inner_loop_parameter_dict(params=self.classifier.named_parameters()))\n",
+ "\n",
+ " print(\"Inner Loop parameters\")\n",
+ " for key, value in self.inner_loop_optimizer.named_parameters():\n",
+ " print(key, value.shape)\n",
+ "\n",
+ " self.use_cuda = args.use_cuda\n",
+ " self.device = device\n",
+ " self.args = args\n",
+ " self.to(device)\n",
+ " print(\"Outer Loop parameters\")\n",
+ " for name, param in self.named_parameters():\n",
+ " if param.requires_grad:\n",
+ " print(name, param.shape, param.device, param.requires_grad)\n",
+ "\n",
+ "\n",
+ " self.optimizer = optim.Adam(self.trainable_parameters(), lr=args.meta_learning_rate, amsgrad=False)\n",
+ " self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.args.total_epochs,\n",
+ " eta_min=self.args.min_learning_rate)\n",
+ "\n",
+ " def get_per_step_loss_importance_vector(self):\n",
+ " \"\"\"\n",
+ " Generates a tensor of dimensionality (num_inner_loop_steps) indicating the importance of each step's target\n",
+ " loss towards the optimization loss.\n",
+ " :return: A tensor to be used to compute the weighted average of the loss, useful for\n",
+ " the MSL (Multi Step Loss) mechanism.\n",
+ " \"\"\"\n",
+ " loss_weights = np.ones(shape=(self.args.number_of_training_steps_per_iter)) * (\n",
+ " 1.0 / self.args.number_of_training_steps_per_iter)\n",
+ " decay_rate = 1.0 / self.args.number_of_training_steps_per_iter / self.args.multi_step_loss_num_epochs\n",
+ " min_value_for_non_final_losses = 0.03 / self.args.number_of_training_steps_per_iter\n",
+ " for i in range(len(loss_weights) - 1):\n",
+ " curr_value = np.maximum(loss_weights[i] - (self.current_epoch * decay_rate), min_value_for_non_final_losses)\n",
+ " loss_weights[i] = curr_value\n",
+ "\n",
+ " curr_value = np.minimum(\n",
+ " loss_weights[-1] + (self.current_epoch * (self.args.number_of_training_steps_per_iter - 1) * decay_rate),\n",
+ " 1.0 - ((self.args.number_of_training_steps_per_iter - 1) * min_value_for_non_final_losses))\n",
+ " loss_weights[-1] = curr_value\n",
+ " loss_weights = torch.Tensor(loss_weights).to(device=self.device)\n",
+ " return loss_weights\n",
+ "\n",
+ " def get_inner_loop_parameter_dict(self, params):\n",
+ " \"\"\"\n",
+ " Returns a dictionary with the parameters to use for inner loop updates.\n",
+ " :param params: A dictionary of the network's parameters.\n",
+ " :return: A dictionary of the parameters to use for the inner loop optimization process.\n",
+ " \"\"\"\n",
+ " param_dict = dict()\n",
+ " for name, param in params:\n",
+ " if param.requires_grad:\n",
+ " if self.args.enable_inner_loop_optimizable_bn_params:\n",
+ " param_dict[name] = param.to(device=self.device)\n",
+ " else:\n",
+ " if \"norm_layer\" not in name:\n",
+ " param_dict[name] = param.to(device=self.device)\n",
+ "\n",
+ " return param_dict\n",
+ "\n",
+ " def apply_inner_loop_update(self, loss, names_weights_copy, use_second_order, current_step_idx):\n",
+ " \"\"\"\n",
+ " Applies an inner loop update given current step's loss, the weights to update, a flag indicating whether to use\n",
+ " second order derivatives and the current step's index.\n",
+ " :param loss: Current step's loss with respect to the support set.\n",
+ " :param names_weights_copy: A dictionary with names to parameters to update.\n",
+ " :param use_second_order: A boolean flag of whether to use second order derivatives.\n",
+ " :param current_step_idx: Current step's index.\n",
+ " :return: A dictionary with the updated weights (name, param)\n",
+ " \"\"\"\n",
+ " self.classifier.zero_grad(names_weights_copy)\n",
+ "\n",
+ " grads = torch.autograd.grad(loss, names_weights_copy.values(),\n",
+ " create_graph=use_second_order)\n",
+ " names_grads_wrt_params = dict(zip(names_weights_copy.keys(), grads))\n",
+ "\n",
+ " names_weights_copy = self.inner_loop_optimizer.update_params(names_weights_dict=names_weights_copy,\n",
+ " names_grads_wrt_params_dict=names_grads_wrt_params,\n",
+ " num_step=current_step_idx)\n",
+ "\n",
+ " return names_weights_copy\n",
+ "\n",
+ " def get_across_task_loss_metrics(self, total_losses, total_accuracies):\n",
+ " losses = dict()\n",
+ "\n",
+ " losses['loss'] = torch.mean(torch.stack(total_losses))\n",
+ " losses['accuracy'] = np.mean(total_accuracies)\n",
+ "\n",
+ " return losses\n",
+ "\n",
+ " def forward(self, data_batch, epoch, use_second_order, use_multi_step_loss_optimization, num_steps, training_phase):\n",
+ " \"\"\"\n",
+ " Runs a forward outer loop pass on the batch of tasks using the MAML/++ framework.\n",
+ " :param data_batch: A data batch containing the support and target sets.\n",
+ " :param epoch: Current epoch's index\n",
+ " :param use_second_order: A boolean saying whether to use second order derivatives.\n",
+ " :param use_multi_step_loss_optimization: Whether to optimize on the outer loop using just the last step's\n",
+ " target loss (True) or whether to use multi step loss which improves the stability of the system (False)\n",
+ " :param num_steps: Number of inner loop steps.\n",
+ " :param training_phase: Whether this is a training phase (True) or an evaluation phase (False)\n",
+ " :return: A dictionary with the collected losses of the current outer forward propagation.\n",
+ " \"\"\"\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n",
+ "\n",
+ " [b, ncs, spc] = y_support_set.shape\n",
+ "\n",
+ " self.num_classes_per_set = ncs\n",
+ "\n",
+ " total_losses = []\n",
+ " total_accuracies = []\n",
+ " per_task_target_preds = [[] for i in range(len(x_target_set))]\n",
+ " self.classifier.zero_grad()\n",
+ " for task_id, (x_support_set_task, y_support_set_task, x_target_set_task, y_target_set_task) in \\\n",
+ " enumerate(zip(x_support_set,\n",
+ " y_support_set,\n",
+ " x_target_set,\n",
+ " y_target_set)):\n",
+ " task_losses = []\n",
+ " task_accuracies = []\n",
+ " per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector()\n",
+ " names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())\n",
+ "\n",
+ " n, s, c, h, w = x_target_set_task.shape\n",
+ "\n",
+ " x_support_set_task = x_support_set_task.view(-1, c, h, w)\n",
+ " y_support_set_task = y_support_set_task.view(-1)\n",
+ " x_target_set_task = x_target_set_task.view(-1, c, h, w)\n",
+ " y_target_set_task = y_target_set_task.view(-1)\n",
+ "\n",
+ " for num_step in range(num_steps):\n",
+ "\n",
+ " support_loss, support_preds = self.net_forward(x=x_support_set_task,\n",
+ " y=y_support_set_task,\n",
+ " weights=names_weights_copy,\n",
+ " backup_running_statistics=\n",
+ " True if (num_step == 0) else False,\n",
+ " training=True, num_step=num_step)\n",
+ "\n",
+ " names_weights_copy = self.apply_inner_loop_update(loss=support_loss,\n",
+ " names_weights_copy=names_weights_copy,\n",
+ " use_second_order=use_second_order,\n",
+ " current_step_idx=num_step)\n",
+ "\n",
+ " if use_multi_step_loss_optimization and training_phase and epoch < self.args.multi_step_loss_num_epochs:\n",
+ " target_loss, target_preds = self.net_forward(x=x_target_set_task,\n",
+ " y=y_target_set_task, weights=names_weights_copy,\n",
+ " backup_running_statistics=False, training=True,\n",
+ " num_step=num_step)\n",
+ "\n",
+ " task_losses.append(per_step_loss_importance_vectors[num_step] * target_loss)\n",
+ " else:\n",
+ " if num_step == (self.args.number_of_training_steps_per_iter - 1):\n",
+ " target_loss, target_preds = self.net_forward(x=x_target_set_task,\n",
+ " y=y_target_set_task, weights=names_weights_copy,\n",
+ " backup_running_statistics=False, training=True,\n",
+ " num_step=num_step)\n",
+ " task_losses.append(target_loss)\n",
+ "\n",
+ " per_task_target_preds[task_id] = target_preds.detach().cpu().numpy()\n",
+ " _, predicted = torch.max(target_preds.data, 1)\n",
+ "\n",
+ " accuracy = predicted.float().eq(y_target_set_task.data.float()).cpu().float()\n",
+ " task_losses = torch.sum(torch.stack(task_losses))\n",
+ " total_losses.append(task_losses)\n",
+ " total_accuracies.extend(accuracy)\n",
+ "\n",
+ " if not training_phase:\n",
+ " self.classifier.restore_backup_stats()\n",
+ "\n",
+ " losses = self.get_across_task_loss_metrics(total_losses=total_losses,\n",
+ " total_accuracies=total_accuracies)\n",
+ "\n",
+ " for idx, item in enumerate(per_step_loss_importance_vectors):\n",
+ " losses['loss_importance_vector_{}'.format(idx)] = item.detach().cpu().numpy()\n",
+ "\n",
+ " return losses, per_task_target_preds\n",
+ "\n",
+ " def net_forward(self, x, y, weights, backup_running_statistics, training, num_step):\n",
+ " \"\"\"\n",
+ " A base model forward pass on some data points x. Using the parameters in the weights dictionary. Also requires\n",
+ " boolean flags indicating whether to reset the running statistics at the end of the run (if at evaluation phase).\n",
+ " A flag indicating whether this is the training session and an int indicating the current step's number in the\n",
+ " inner loop.\n",
+ " :param x: A data batch of shape b, c, h, w\n",
+ " :param y: A data targets batch of shape b, n_classes\n",
+ " :param weights: A dictionary containing the weights to pass to the network.\n",
+ " :param backup_running_statistics: A flag indicating whether to reset the batch norm running statistics to their\n",
+ " previous values after the run (only for evaluation)\n",
+ " :param training: A flag indicating whether the current process phase is a training or evaluation.\n",
+ " :param num_step: An integer indicating the number of the step in the inner loop.\n",
+ " :return: the crossentropy losses with respect to the given y, the predictions of the base model.\n",
+ " \"\"\"\n",
+ " preds = self.classifier.forward(x=x, params=weights,\n",
+ " training=training,\n",
+ " backup_running_statistics=backup_running_statistics, num_step=num_step)\n",
+ "\n",
+ " loss = F.cross_entropy(input=preds, target=y)\n",
+ "\n",
+ " return loss, preds\n",
+ "\n",
+ " def trainable_parameters(self):\n",
+ " \"\"\"\n",
+ " Returns an iterator over the trainable parameters of the model.\n",
+ " \"\"\"\n",
+ " for param in self.parameters():\n",
+ " if param.requires_grad:\n",
+ " yield param\n",
+ "\n",
+ " def train_forward_prop(self, data_batch, epoch):\n",
+ " \"\"\"\n",
+ " Runs an outer loop forward prop using the meta-model and base-model.\n",
+ " :param data_batch: A data batch containing the support set and the target set input, output pairs.\n",
+ " :param epoch: The index of the currrent epoch.\n",
+ " :return: A dictionary of losses for the current step.\n",
+ " \"\"\"\n",
+ " losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch,\n",
+ " use_second_order=self.args.second_order and\n",
+ " epoch > self.args.first_order_to_second_order_epoch,\n",
+ " use_multi_step_loss_optimization=self.args.use_multi_step_loss_optimization,\n",
+ " num_steps=self.args.number_of_training_steps_per_iter,\n",
+ " training_phase=True)\n",
+ " return losses, per_task_target_preds\n",
+ "\n",
+ " def evaluation_forward_prop(self, data_batch, epoch):\n",
+ " \"\"\"\n",
+ " Runs an outer loop evaluation forward prop using the meta-model and base-model.\n",
+ " :param data_batch: A data batch containing the support set and the target set input, output pairs.\n",
+ " :param epoch: The index of the currrent epoch.\n",
+ " :return: A dictionary of losses for the current step.\n",
+ " \"\"\"\n",
+ " losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, use_second_order=False,\n",
+ " use_multi_step_loss_optimization=True,\n",
+ " num_steps=self.args.number_of_evaluation_steps_per_iter,\n",
+ " training_phase=False)\n",
+ "\n",
+ " return losses, per_task_target_preds\n",
+ "\n",
+ " def meta_update(self, loss):\n",
+ " \"\"\"\n",
+ " Applies an outer loop update on the meta-parameters of the model.\n",
+ " :param loss: The current crossentropy loss.\n",
+ " \"\"\"\n",
+ " self.optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " if 'imagenet' in self.args.dataset_name:\n",
+ " for name, param in self.classifier.named_parameters():\n",
+ " if param.requires_grad:\n",
+ " param.grad.data.clamp_(-10, 10) # not sure if this is necessary, more experiments are needed\n",
+ " self.optimizer.step()\n",
+ "\n",
+ " def run_train_iter(self, data_batch, epoch):\n",
+ " \"\"\"\n",
+ " Runs an outer loop update step on the meta-model's parameters.\n",
+ " :param data_batch: input data batch containing the support set and target set input, output pairs\n",
+ " :param epoch: the index of the current epoch\n",
+ " :return: The losses of the ran iteration.\n",
+ " \"\"\"\n",
+ " epoch = int(epoch)\n",
+ " self.scheduler.step(epoch=epoch)\n",
+ " if self.current_epoch != epoch:\n",
+ " self.current_epoch = epoch\n",
+ "\n",
+ " if not self.training:\n",
+ " self.train()\n",
+ "\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n",
+ "\n",
+ " x_support_set = torch.Tensor(x_support_set).float().to(device=self.device)\n",
+ " x_target_set = torch.Tensor(x_target_set).float().to(device=self.device)\n",
+ " y_support_set = torch.Tensor(y_support_set).long().to(device=self.device)\n",
+ " y_target_set = torch.Tensor(y_target_set).long().to(device=self.device)\n",
+ "\n",
+ " data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n",
+ "\n",
+ " losses, per_task_target_preds = self.train_forward_prop(data_batch=data_batch, epoch=epoch)\n",
+ "\n",
+ " self.meta_update(loss=losses['loss'])\n",
+ " losses['learning_rate'] = self.scheduler.get_lr()[0]\n",
+ " self.optimizer.zero_grad()\n",
+ " self.zero_grad()\n",
+ "\n",
+ " return losses, per_task_target_preds\n",
+ "\n",
+ " def run_validation_iter(self, data_batch):\n",
+ " \"\"\"\n",
+ " Runs an outer loop evaluation step on the meta-model's parameters.\n",
+ " :param data_batch: input data batch containing the support set and target set input, output pairs\n",
+ " :param epoch: the index of the current epoch\n",
+ " :return: The losses of the ran iteration.\n",
+ " \"\"\"\n",
+ "\n",
+ " if self.training:\n",
+ " self.eval()\n",
+ "\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n",
+ "\n",
+ " x_support_set = torch.Tensor(x_support_set).float().to(device=self.device)\n",
+ " x_target_set = torch.Tensor(x_target_set).float().to(device=self.device)\n",
+ " y_support_set = torch.Tensor(y_support_set).long().to(device=self.device)\n",
+ " y_target_set = torch.Tensor(y_target_set).long().to(device=self.device)\n",
+ "\n",
+ " data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n",
+ "\n",
+ " losses, per_task_target_preds = self.evaluation_forward_prop(data_batch=data_batch, epoch=self.current_epoch)\n",
+ "\n",
+ " # losses['loss'].backward() # uncomment if you get the weird memory error\n",
+ " # self.zero_grad()\n",
+ " # self.optimizer.zero_grad()\n",
+ "\n",
+ " return losses, per_task_target_preds\n",
+ "\n",
+ " def save_model(self, model_save_dir, state):\n",
+ " \"\"\"\n",
+ " Save the network parameter state and experiment state dictionary.\n",
+ " :param model_save_dir: The directory to store the state at.\n",
+ " :param state: The state containing the experiment state and the network. It's in the form of a dictionary\n",
+ " object.\n",
+ " \"\"\"\n",
+ " state['network'] = self.state_dict()\n",
+ " torch.save(state, f=model_save_dir)\n",
+ "\n",
+ " def load_model(self, model_save_dir, model_name, model_idx):\n",
+ " \"\"\"\n",
+ " Load checkpoint and return the state dictionary containing the network state params and experiment state.\n",
+ " :param model_save_dir: The directory from which to load the files.\n",
+ " :param model_name: The model_name to be loaded from the direcotry.\n",
+ " :param model_idx: The index of the model (i.e. epoch number or 'latest' for the latest saved model of the current\n",
+ " experiment)\n",
+ " :return: A dictionary containing the experiment state and the saved model parameters.\n",
+ " \"\"\"\n",
+ " filepath = os.path.join(model_save_dir, \"{}_{}\".format(model_name, model_idx))\n",
+ " state = torch.load(filepath)\n",
+ " state_dict_loaded = state['network']\n",
+ " self.load_state_dict(state_dict=state_dict_loaded)\n",
+ " return state"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "UmUi50wLrYr4",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def save_to_json(filename, dict_to_store):\n",
+ " with open(os.path.abspath(filename), 'w') as f:\n",
+ " json.dump(dict_to_store, fp=f)\n",
+ "\n",
+ "def load_from_json(filename):\n",
+ " with open(filename, mode=\"r\") as f:\n",
+ " load_dict = json.load(fp=f)\n",
+ "\n",
+ " return load_dict\n",
+ "\n",
+ "def save_statistics(experiment_name, line_to_add, filename=\"summary_statistics.csv\", create=False):\n",
+ " summary_filename = \"{}/{}\".format(experiment_name, filename)\n",
+ " if create:\n",
+ " with open(summary_filename, 'w') as f:\n",
+ " writer = csv.writer(f)\n",
+ " writer.writerow(line_to_add)\n",
+ " else:\n",
+ " with open(summary_filename, 'a') as f:\n",
+ " writer = csv.writer(f)\n",
+ " writer.writerow(line_to_add)\n",
+ "\n",
+ " return summary_filename\n",
+ "\n",
+ "def load_statistics(experiment_name, filename=\"summary_statistics.csv\"):\n",
+ " data_dict = dict()\n",
+ " summary_filename = \"{}/{}\".format(experiment_name, filename)\n",
+ " with open(summary_filename, 'r') as f:\n",
+ " lines = f.readlines()\n",
+ " data_labels = lines[0].replace(\"\\n\", \"\").split(\",\")\n",
+ " del lines[0]\n",
+ "\n",
+ " for label in data_labels:\n",
+ " data_dict[label] = []\n",
+ "\n",
+ " for line in lines:\n",
+ " data = line.replace(\"\\n\", \"\").split(\",\")\n",
+ " for key, item in zip(data_labels, data):\n",
+ " data_dict[key].append(item)\n",
+ " return data_dict\n",
+ "\n",
+ "\n",
+ "def build_experiment_folder(experiment_name):\n",
+ " experiment_path = os.path.abspath(experiment_name)\n",
+ " saved_models_filepath = \"{}/{}\".format(experiment_path, \"saved_models\")\n",
+ " logs_filepath = \"{}/{}\".format(experiment_path, \"logs\")\n",
+ " samples_filepath = \"{}/{}\".format(experiment_path, \"visual_outputs\")\n",
+ "\n",
+ " if not os.path.exists(experiment_path):\n",
+ " os.makedirs(experiment_path)\n",
+ " if not os.path.exists(logs_filepath):\n",
+ " os.makedirs(logs_filepath)\n",
+ " if not os.path.exists(samples_filepath):\n",
+ " os.makedirs(samples_filepath)\n",
+ " if not os.path.exists(saved_models_filepath):\n",
+ " os.makedirs(saved_models_filepath)\n",
+ "\n",
+ " outputs = (saved_models_filepath, logs_filepath, samples_filepath)\n",
+ " outputs = (os.path.abspath(item) for item in outputs)\n",
+ " return outputs\n",
+ "\n",
+ "def get_best_validation_model_statistics(experiment_name, filename=\"summary_statistics.csv\"):\n",
+ " \"\"\"\n",
+ " Returns the best val epoch and val accuracy from a log csv file\n",
+ " :param log_dir: The log directory the file is saved in\n",
+ " :param statistics_file_name: The log file name\n",
+ " :return: The best validation accuracy and the epoch at which it is produced\n",
+ " \"\"\"\n",
+ " log_file_dict = load_statistics(filename=filename, experiment_name=experiment_name)\n",
+ " d_val_loss = np.array(log_file_dict['total_d_val_loss_mean'], dtype=np.float32)\n",
+ " best_d_val_loss = np.min(d_val_loss)\n",
+ " best_d_val_epoch = np.argmin(d_val_loss)\n",
+ "\n",
+ " return best_d_val_loss, best_d_val_epoch\n",
+ "\n",
+ "def create_json_experiment_log(experiment_log_dir, args, log_name=\"experiment_log.json\"):\n",
+ " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n",
+ "\n",
+ " experiment_summary_dict = dict()\n",
+ "\n",
+ " for key, value in vars(args).items():\n",
+ " experiment_summary_dict[key] = value\n",
+ "\n",
+ " experiment_summary_dict[\"epoch_stats\"] = dict()\n",
+ " timestamp = datetime.datetime.now().timestamp()\n",
+ " experiment_summary_dict[\"experiment_status\"] = [(timestamp, \"initialization\")]\n",
+ " experiment_summary_dict[\"experiment_initialization_time\"] = timestamp\n",
+ " with open(os.path.abspath(summary_filename), 'w') as f:\n",
+ " json.dump(experiment_summary_dict, fp=f)\n",
+ "\n",
+ "def update_json_experiment_log_dict(key, value, experiment_log_dir, log_name=\"experiment_log.json\"):\n",
+ " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n",
+ " with open(summary_filename) as f:\n",
+ " summary_dict = json.load(fp=f)\n",
+ "\n",
+ " summary_dict[key].append(value)\n",
+ "\n",
+ " with open(summary_filename, 'w') as f:\n",
+ " json.dump(summary_dict, fp=f)\n",
+ "\n",
+ "def change_json_log_experiment_status(experiment_status, experiment_log_dir, log_name=\"experiment_log.json\"):\n",
+ " timestamp = datetime.datetime.now().timestamp()\n",
+ " experiment_status = (timestamp, experiment_status)\n",
+ " update_json_experiment_log_dict(key=\"experiment_status\", value=experiment_status,\n",
+ " experiment_log_dir=experiment_log_dir, log_name=log_name)\n",
+ "\n",
+ "def update_json_experiment_log_epoch_stats(epoch_stats, experiment_log_dir, log_name=\"experiment_log.json\"):\n",
+ " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n",
+ " with open(summary_filename) as f:\n",
+ " summary_dict = json.load(fp=f)\n",
+ "\n",
+ " epoch_stats_dict = summary_dict[\"epoch_stats\"]\n",
+ "\n",
+ " for key in epoch_stats.keys():\n",
+ " entry = float(epoch_stats[key])\n",
+ " if key in epoch_stats_dict:\n",
+ " epoch_stats_dict[key].append(entry)\n",
+ " else:\n",
+ " epoch_stats_dict[key] = [entry]\n",
+ "\n",
+ " summary_dict['epoch_stats'] = epoch_stats_dict\n",
+ "\n",
+ " with open(summary_filename, 'w') as f:\n",
+ " json.dump(summary_dict, fp=f)\n",
+ " return summary_filename"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
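+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# A minimal sketch (hypothetical experiment name) of the bookkeeping helpers:\n",
+ "# create the folder layout, write a csv header, append one row, read it back\n",
+ "# (values come back as strings keyed by the header labels).\n",
+ "saved_models, logs, samples = build_experiment_folder(experiment_name=\"demo_experiment\")\n",
+ "save_statistics(logs, [\"epoch\", \"train_loss_mean\"], create=True)\n",
+ "save_statistics(logs, [0, 1.234])\n",
+ "print(load_statistics(logs))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },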
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "xCz8qhr0w2g3",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "class ExperimentBuilder(object):\n",
+ " def __init__(self, args, data, model, device):\n",
+ " \"\"\"\n",
+ " Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system\n",
+ " (model) and a device (e.g. gpu/cpu/n)\n",
+ " :param args: A namedtuple containing all experiment hyperparameters\n",
+ " :param data: A data provider of instance MetaLearningSystemDataLoader\n",
+ " :param model: A meta learning system instance\n",
+ " :param device: Device/s to use for the experiment\n",
+ " \"\"\"\n",
+ " self.args, self.device = args, device\n",
+ "\n",
+ " self.model = model\n",
+ " self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(\n",
+ " experiment_name=self.args.experiment_name)\n",
+ "\n",
+ " self.total_losses = dict()\n",
+ " self.state = dict()\n",
+ " self.state['best_val_acc'] = 0.\n",
+ " self.state['best_val_iter'] = 0\n",
+ " self.state['current_iter'] = 0\n",
+ " self.state['current_iter'] = 0\n",
+ " self.start_epoch = 0\n",
+ " self.max_models_to_save = self.args.max_models_to_save\n",
+ " self.create_summary_csv = False\n",
+ "\n",
+ " if self.args.continue_from_epoch == 'from_scratch':\n",
+ " self.create_summary_csv = True\n",
+ "\n",
+ " elif self.args.continue_from_epoch == 'latest':\n",
+ " checkpoint = os.path.join(self.saved_models_filepath, \"train_model_latest\")\n",
+ " print(\"attempting to find existing checkpoint\", )\n",
+ " if os.path.exists(checkpoint):\n",
+ " self.state = \\\n",
+ " self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+ " model_idx='latest')\n",
+ " self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+ "\n",
+ " else:\n",
+ " self.args.continue_from_epoch = 'from_scratch'\n",
+ " self.create_summary_csv = True\n",
+ " elif int(self.args.continue_from_epoch) >= 0:\n",
+ " self.state = \\\n",
+ " self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+ " model_idx=self.args.continue_from_epoch)\n",
+ " self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+ "\n",
+ " self.data = data(args=args, current_iter=self.state['current_iter'])\n",
+ "\n",
+ " print(\"train_seed {}, val_seed: {}, at start time\".format(self.data.dataset.seed[\"train\"],\n",
+ " self.data.dataset.seed[\"val\"]))\n",
+ " self.total_epochs_before_pause = self.args.total_epochs_before_pause\n",
+ " self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch)\n",
+ " self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+ " self.augment_flag = True if 'omniglot' in self.args.dataset_name.lower() else False\n",
+ " self.start_time = time.time()\n",
+ " self.epochs_done_in_this_run = 0\n",
+ " print(self.state['current_iter'], int(self.args.total_iter_per_epoch * self.args.total_epochs))\n",
+ "\n",
+ " def build_summary_dict(self, total_losses, phase, summary_losses=None):\n",
+ " \"\"\"\n",
+ " Builds/Updates a summary dict directly from the metric dict of the current iteration.\n",
+ " :param total_losses: Current dict with total losses (not aggregations) from experiment\n",
+ " :param phase: Current training phase\n",
+ " :param summary_losses: Current summarised (aggregated/summarised) losses stats means, stdv etc.\n",
+ " :return: A new summary dict with the updated summary statistics information.\n",
+ " \"\"\"\n",
+ " if summary_losses is None:\n",
+ " summary_losses = dict()\n",
+ "\n",
+ " for key in total_losses:\n",
+ " summary_losses[\"{}_{}_mean\".format(phase, key)] = np.mean(total_losses[key])\n",
+ " summary_losses[\"{}_{}_std\".format(phase, key)] = np.std(total_losses[key])\n",
+ "\n",
+ " return summary_losses\n",
+ "\n",
+ " def build_loss_summary_string(self, summary_losses):\n",
+ " \"\"\"\n",
+ " Builds a progress bar summary string given current summary losses dictionary\n",
+ " :param summary_losses: Current summary statistics\n",
+ " :return: A summary string ready to be shown to humans.\n",
+ " \"\"\"\n",
+ " output_update = \"\"\n",
+ " for key, value in zip(list(summary_losses.keys()), list(summary_losses.values())):\n",
+ " if \"loss\" in key or \"accuracy\" in key:\n",
+ " value = float(value)\n",
+ " output_update += \"{}: {:.4f}, \".format(key, value)\n",
+ "\n",
+ " return output_update\n",
+ "\n",
+ " def merge_two_dicts(self, first_dict, second_dict):\n",
+ " \"\"\"Given two dicts, merge them into a new dict as a shallow copy.\"\"\"\n",
+ " z = first_dict.copy()\n",
+ " z.update(second_dict)\n",
+ " return z\n",
+ "\n",
+ " def train_iteration(self, train_sample, sample_idx, epoch_idx, total_losses, current_iter, pbar_train):\n",
+ " \"\"\"\n",
+ " Runs a training iteration, updates the progress bar and returns the total and current epoch train losses.\n",
+ " :param train_sample: A sample from the data provider\n",
+ " :param sample_idx: The index of the incoming sample, in relation to the current training run.\n",
+ " :param epoch_idx: The epoch index.\n",
+ " :param total_losses: The current total losses dictionary to be updated.\n",
+ " :param current_iter: The current training iteration in relation to the whole experiment.\n",
+ " :param pbar_train: The progress bar of the training.\n",
+ " :return: Updates total_losses, train_losses, current_iter\n",
+ " \"\"\"\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set, seed = train_sample\n",
+ " data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n",
+ "\n",
+ " if sample_idx == 0:\n",
+ " print(\"shape of data\", x_support_set.shape, x_target_set.shape, y_support_set.shape,\n",
+ " y_target_set.shape)\n",
+ "\n",
+ " losses, _ = self.model.run_train_iter(data_batch=data_batch, epoch=epoch_idx)\n",
+ "\n",
+ " for key, value in zip(list(losses.keys()), list(losses.values())):\n",
+ " if key not in total_losses:\n",
+ " total_losses[key] = [float(value)]\n",
+ " else:\n",
+ " total_losses[key].append(float(value))\n",
+ "\n",
+ " train_losses = self.build_summary_dict(total_losses=total_losses, phase=\"train\")\n",
+ " train_output_update = self.build_loss_summary_string(losses)\n",
+ "\n",
+ " pbar_train.update(1)\n",
+ " pbar_train.set_description(\"training phase {} -> {}\".format(self.epoch, train_output_update))\n",
+ "\n",
+ " current_iter += 1\n",
+ "\n",
+ " return train_losses, total_losses, current_iter\n",
+ "\n",
+ " def evaluation_iteration(self, val_sample, total_losses, pbar_val, phase):\n",
+ " \"\"\"\n",
+ " Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses.\n",
+ " :param val_sample: A sample from the data provider\n",
+ " :param total_losses: The current total losses dictionary to be updated.\n",
+ " :param pbar_val: The progress bar of the val stage.\n",
+ " :return: The updated val_losses, total_losses\n",
+ " \"\"\"\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample\n",
+ " data_batch = (\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set)\n",
+ "\n",
+ " losses, _ = self.model.run_validation_iter(data_batch=data_batch)\n",
+ " for key, value in zip(list(losses.keys()), list(losses.values())):\n",
+ " if key not in total_losses:\n",
+ " total_losses[key] = [float(value)]\n",
+ " else:\n",
+ " total_losses[key].append(float(value))\n",
+ "\n",
+ " val_losses = self.build_summary_dict(total_losses=total_losses, phase=phase)\n",
+ " val_output_update = self.build_loss_summary_string(losses)\n",
+ "\n",
+ " pbar_val.update(1)\n",
+ " pbar_val.set_description(\n",
+ " \"val_phase {} -> {}\".format(self.epoch, val_output_update))\n",
+ "\n",
+ " return val_losses, total_losses\n",
+ "\n",
+ " def test_evaluation_iteration(self, val_sample, model_idx, sample_idx, per_model_per_batch_preds, pbar_test):\n",
+ " \"\"\"\n",
+ " Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses.\n",
+ " :param val_sample: A sample from the data provider\n",
+ " :param total_losses: The current total losses dictionary to be updated.\n",
+ " :param pbar_test: The progress bar of the val stage.\n",
+ " :return: The updated val_losses, total_losses\n",
+ " \"\"\"\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample\n",
+ " data_batch = (\n",
+ " x_support_set, x_target_set, y_support_set, y_target_set)\n",
+ "\n",
+ " losses, per_task_preds = self.model.run_validation_iter(data_batch=data_batch)\n",
+ "\n",
+ " per_model_per_batch_preds[model_idx].extend(list(per_task_preds))\n",
+ "\n",
+ " test_output_update = self.build_loss_summary_string(losses)\n",
+ "\n",
+ " pbar_test.update(1)\n",
+ " pbar_test.set_description(\n",
+ " \"test_phase {} -> {}\".format(self.epoch, test_output_update))\n",
+ "\n",
+ " return per_model_per_batch_preds\n",
+ "\n",
+ " def save_models(self, model, epoch, state):\n",
+ " \"\"\"\n",
+ " Saves two separate instances of the current model. One to be kept for history and reloading later and another\n",
+ " one marked as \"latest\" to be used by the system for the next epoch training. Useful when the training/val\n",
+ " process is interrupted or stopped. Leads to fault tolerant training and validation systems that can continue\n",
+ " from where they left off before.\n",
+ " :param model: Current meta learning model of any instance within the few_shot_learning_system.py\n",
+ " :param epoch: Current epoch\n",
+ " :param state: Current model and experiment state dict.\n",
+ " \"\"\"\n",
+ " model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, \"train_model_{}\".format(int(epoch))),\n",
+ " state=state)\n",
+ "\n",
+ " model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, \"train_model_latest\"),\n",
+ " state=state)\n",
+ "\n",
+ " print(\"saved models to\", self.saved_models_filepath)\n",
+ "\n",
+ " def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses, state, step):\n",
+ " \"\"\"\n",
+ " Given current epochs start_time, train losses, val losses and whether to create a new stats csv file, pack stats\n",
+ " and save into a statistics csv file. Return a new start time for the new epoch.\n",
+ " :param start_time: The start time of the current epoch\n",
+ " :param create_summary_csv: A boolean variable indicating whether to create a new statistics file or\n",
+ " append results to existing one\n",
+ " :param train_losses: A dictionary with the current train losses\n",
+ " :param val_losses: A dictionary with the currrent val loss\n",
+ " :return: The current time, to be used for the next epoch.\n",
+ " \"\"\"\n",
+ " epoch_summary_losses = self.merge_two_dicts(first_dict=train_losses, second_dict=val_losses)\n",
+ "\n",
+ " if 'per_epoch_statistics' not in state:\n",
+ " state['per_epoch_statistics'] = dict()\n",
+ "\n",
+ " for key, value in epoch_summary_losses.items():\n",
+ "\n",
+ " if key not in state['per_epoch_statistics']:\n",
+ " state['per_epoch_statistics'][key] = [value]\n",
+ " else:\n",
+ " state['per_epoch_statistics'][key].append(value)\n",
+ "\n",
+ " epoch_summary_string = self.build_loss_summary_string(epoch_summary_losses)\n",
+ " epoch_summary_losses[\"epoch\"] = self.epoch\n",
+ " epoch_summary_losses['epoch_run_time'] = time.time() - start_time\n",
+ "\n",
+ " if create_summary_csv:\n",
+ " self.summary_statistics_filepath = save_statistics(self.logs_filepath, list(epoch_summary_losses.keys()),\n",
+ " create=True)\n",
+ " self.create_summary_csv = False\n",
+ "\n",
+ " start_time = time.time()\n",
+ " \n",
+ " writer.add_scalar('epoch_summary_losses', epoch_summary_losses[\"epoch\"], step)\n",
+ " \n",
+ " print(\"epoch {} -> {}\".format(epoch_summary_losses[\"epoch\"], epoch_summary_string))\n",
+ "\n",
+ " self.summary_statistics_filepath = save_statistics(self.logs_filepath,\n",
+ " list(epoch_summary_losses.values()))\n",
+ " return start_time, state\n",
+ "\n",
+ " def evaluated_test_set_using_the_best_models(self, top_n_models):\n",
+ " per_epoch_statistics = self.state['per_epoch_statistics']\n",
+ " val_acc = np.copy(per_epoch_statistics['val_accuracy_mean'])\n",
+ " val_idx = np.array([i for i in range(len(val_acc))])\n",
+ " sorted_idx = np.argsort(val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]\n",
+ "\n",
+ " sorted_val_acc = val_acc[sorted_idx]\n",
+ " val_idx = val_idx[sorted_idx]\n",
+ " print(sorted_idx)\n",
+ " print(sorted_val_acc)\n",
+ "\n",
+ " top_n_idx = val_idx[:top_n_models]\n",
+ " per_model_per_batch_preds = [[] for i in range(top_n_models)]\n",
+ " per_model_per_batch_targets = [[] for i in range(top_n_models)]\n",
+ " test_losses = [dict() for i in range(top_n_models)]\n",
+ " for idx, model_idx in enumerate(top_n_idx):\n",
+ " self.state = \\\n",
+ " self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+ " model_idx=model_idx + 1)\n",
+ " with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_test:\n",
+ " for sample_idx, test_sample in enumerate(\n",
+ " self.data.get_test_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size),\n",
+ " augment_images=False)):\n",
+ " #print(test_sample[4])\n",
+ " per_model_per_batch_targets[idx].extend(np.array(test_sample[3]))\n",
+ " per_model_per_batch_preds = self.test_evaluation_iteration(val_sample=test_sample,\n",
+ " sample_idx=sample_idx,\n",
+ " model_idx=idx,\n",
+ " per_model_per_batch_preds=per_model_per_batch_preds,\n",
+ " pbar_test=pbar_test)\n",
+ " # for i in range(top_n_models):\n",
+ " # print(\"test assertion\", 0)\n",
+ " # print(per_model_per_batch_targets[0], per_model_per_batch_targets[i])\n",
+ " # assert np.equal(np.array(per_model_per_batch_targets[0]), np.array(per_model_per_batch_targets[i]))\n",
+ "\n",
+ " per_batch_preds = np.mean(per_model_per_batch_preds, axis=0)\n",
+ " #print(per_batch_preds.shape)\n",
+ " per_batch_max = np.argmax(per_batch_preds, axis=2)\n",
+ " per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(per_batch_max.shape)\n",
+ " #print(per_batch_max)\n",
+ " accuracy = np.mean(np.equal(per_batch_targets, per_batch_max))\n",
+ " accuracy_std = np.std(np.equal(per_batch_targets, per_batch_max))\n",
+ "\n",
+ " test_losses = {\"test_accuracy_mean\": accuracy, \"test_accuracy_std\": accuracy_std}\n",
+ "\n",
+ " _ = save_statistics(self.logs_filepath,\n",
+ " list(test_losses.keys()),\n",
+ " create=True, filename=\"test_summary.csv\")\n",
+ "\n",
+ " summary_statistics_filepath = save_statistics(self.logs_filepath,\n",
+ " list(test_losses.values()),\n",
+ " create=False, filename=\"test_summary.csv\")\n",
+ " print(test_losses)\n",
+ " print(\"saved test performance at\", summary_statistics_filepath)\n",
+ "\n",
+ " def run_experiment(self):\n",
+ " \"\"\"\n",
+ " Runs a full training experiment with evaluations of the model on the val set at every epoch. Furthermore,\n",
+ " will return the test set evaluation results on the best performing validation model.\n",
+ " \"\"\"\n",
+ " with tqdm.tqdm(initial=self.state['current_iter'],\n",
+ " total=int(self.args.total_iter_per_epoch * self.args.total_epochs)) as pbar_train:\n",
+ "\n",
+ " while (self.state['current_iter'] < (self.args.total_epochs * self.args.total_iter_per_epoch)) and (self.args.evaluate_on_test_set_only == False):\n",
+ "\n",
+ " for train_sample_idx, train_sample in enumerate(\n",
+ " self.data.get_train_batches(total_batches=int(self.args.total_iter_per_epoch *\n",
+ " self.args.total_epochs) - self.state[\n",
+ " 'current_iter'],\n",
+ " augment_images=self.augment_flag)):\n",
+ " # print(self.state['current_iter'], (self.args.total_epochs * self.args.total_iter_per_epoch))\n",
+ " train_losses, total_losses, self.state['current_iter'] = self.train_iteration(\n",
+ " train_sample=train_sample,\n",
+ " total_losses=self.total_losses,\n",
+ " epoch_idx=(self.state['current_iter'] /\n",
+ " self.args.total_iter_per_epoch),\n",
+ " pbar_train=pbar_train,\n",
+ " current_iter=self.state['current_iter'],\n",
+ " sample_idx=self.state['current_iter'])\n",
+ "\n",
+ " if self.state['current_iter'] % self.args.total_iter_per_epoch == 0:\n",
+ "\n",
+ " total_losses = dict()\n",
+ " val_losses = dict()\n",
+ " with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_val:\n",
+ " for _, val_sample in enumerate(\n",
+ " self.data.get_val_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size),\n",
+ " augment_images=False)):\n",
+ " val_losses, total_losses = self.evaluation_iteration(val_sample=val_sample,\n",
+ " total_losses=total_losses,\n",
+ " pbar_val=pbar_val, phase='val')\n",
+ "\n",
+ " if val_losses[\"val_accuracy_mean\"] > self.state['best_val_acc']:\n",
+ " print(\"Best validation accuracy\", val_losses[\"val_accuracy_mean\"])\n",
+ " writer.add_scalar(\"best validation accuracy\", val_losses[\"val_accuracy_mean\"],\n",
+ " self.epoch)\n",
+ " self.state['best_val_acc'] = val_losses[\"val_accuracy_mean\"]\n",
+ " self.state['best_val_iter'] = self.state['current_iter']\n",
+ " self.state['best_epoch'] = int(\n",
+ " self.state['best_val_iter'] / self.args.total_iter_per_epoch)\n",
+ "\n",
+ "\n",
+ " self.epoch += 1\n",
+ " self.state = self.merge_two_dicts(first_dict=self.merge_two_dicts(first_dict=self.state,\n",
+ " second_dict=train_losses),\n",
+ " second_dict=val_losses)\n",
+ "\n",
+ " self.save_models(model=self.model, epoch=self.epoch, state=self.state)\n",
+ "\n",
+ " self.start_time, self.state = self.pack_and_save_metrics(start_time=self.start_time,\n",
+ " create_summary_csv=self.create_summary_csv,\n",
+ " train_losses=train_losses,\n",
+ " val_losses=val_losses,\n",
+ " state=self.state,\n",
+ " step=self.epoch)\n",
+ "\n",
+ " self.total_losses = dict()\n",
+ "\n",
+ " self.epochs_done_in_this_run += 1\n",
+ "\n",
+ " save_to_json(filename=os.path.join(self.logs_filepath, \"summary_statistics.json\"),\n",
+ " dict_to_store=self.state['per_epoch_statistics'])\n",
+ "\n",
+ " if self.epochs_done_in_this_run >= self.total_epochs_before_pause:\n",
+ " print(\"train_seed {}, val_seed: {}, at pause time\".format(self.data.dataset.seed[\"train\"],\n",
+ " self.data.dataset.seed[\"val\"]))\n",
+ " sys.exit()\n",
+ " self.evaluated_test_set_using_the_best_models(top_n_models=5)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
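+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The test phase above ensembles the top validation checkpoints by averaging their per-task prediction scores before taking the argmax. A minimal, self-contained numpy sketch of that voting step, using hypothetical toy arrays rather than the experiment's real predictions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Hedged sketch of the ensemble step in evaluated_test_set_using_the_best_models:\n",
+ "# average each model's class scores, argmax over classes, then score against targets.\n",
+ "import numpy as np\n",
+ "\n",
+ "n_models, n_batches, n_tasks, n_classes = 3, 4, 8, 5  # toy sizes, not the real config\n",
+ "rng = np.random.RandomState(0)\n",
+ "per_model_preds = rng.rand(n_models, n_batches, n_tasks, n_classes)  # stand-in scores\n",
+ "targets = rng.randint(0, n_classes, size=(n_batches, n_tasks))       # stand-in labels\n",
+ "\n",
+ "per_batch_preds = np.mean(per_model_preds, axis=0)  # average over the model axis\n",
+ "per_batch_max = np.argmax(per_batch_preds, axis=2)  # predicted class per task\n",
+ "accuracy = np.mean(np.equal(targets, per_batch_max))\n",
+ "accuracy_std = np.std(np.equal(targets, per_batch_max))\n",
+ "print(accuracy, accuracy_std)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },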
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ClIPkzMEXvKS",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "from torch import cuda\n",
+ "\n",
+ "\n",
+ "def get_args():\n",
+ " import argparse\n",
+ " import os\n",
+ " import torch\n",
+ " import json\n",
+ " parser = argparse.ArgumentParser(description='Welcome to the MAML++ training and inference system')\n",
+ "\n",
+ " parser.add_argument('--batch_size', nargs=\"?\", type=int, default=32, help='Batch_size for experiment')\n",
+ " parser.add_argument('--image_height', nargs=\"?\", type=int, default=28)\n",
+ " parser.add_argument('--image_width', nargs=\"?\", type=int, default=28)\n",
+ " parser.add_argument('--image_channels', nargs=\"?\", type=int, default=1)\n",
+ " parser.add_argument('--reset_stored_filepaths', type=str, default=\"False\")\n",
+ " parser.add_argument('--reverse_channels', type=str, default=\"False\")\n",
+ " parser.add_argument('--num_of_gpus', type=int, default=1)\n",
+ " parser.add_argument('--indexes_of_folders_indicating_class', nargs='+', default=[-2, -3])\n",
+ " parser.add_argument('--train_val_test_split', nargs='+', default=[0.73982737361, 0.26, 0.13008631319])\n",
+ " parser.add_argument('--samples_per_iter', nargs=\"?\", type=int, default=1)\n",
+ " parser.add_argument('--labels_as_int', type=str, default=\"False\")\n",
+ " parser.add_argument('--seed', type=int, default=104)\n",
+ "\n",
+ " parser.add_argument('--gpu_to_use', type=int)\n",
+ " parser.add_argument('--num_dataprovider_workers', nargs=\"?\", type=int, default=4)\n",
+ " parser.add_argument('--max_models_to_save', nargs=\"?\", type=int, default=5)\n",
+ " parser.add_argument('--dataset_name', type=str, default=\"omniglot_dataset\")\n",
+ " parser.add_argument('--dataset_path', type=str, default=\"datasets/omniglot_dataset\")\n",
+ " parser.add_argument('--reset_stored_paths', type=str, default=\"False\")\n",
+ " parser.add_argument('--experiment_name', nargs=\"?\", type=str, )\n",
+ " parser.add_argument('--architecture_name', nargs=\"?\", type=str)\n",
+ " parser.add_argument('--continue_from_epoch', nargs=\"?\", type=str, default='latest', help='Continue from checkpoint of epoch')\n",
+ " parser.add_argument('--dropout_rate_value', type=float, default=0.3, help='Dropout_rate_value')\n",
+ " parser.add_argument('--num_target_samples', type=int, default=15, help='Dropout_rate_value')\n",
+ " parser.add_argument('--second_order', type=str, default=\"False\", help='Dropout_rate_value')\n",
+ " parser.add_argument('--total_epochs', type=int, default=200, help='Number of epochs per experiment')\n",
+ " parser.add_argument('--total_iter_per_epoch', type=int, default=500, help='Number of iters per epoch')\n",
+ " parser.add_argument('--min_learning_rate', type=float, default=0.00001, help='Min learning rate')\n",
+ " parser.add_argument('--meta_learning_rate', type=float, default=0.001, help='Learning rate of overall MAML system')\n",
+ " parser.add_argument('--meta_opt_bn', type=str, default=\"False\")\n",
+ " parser.add_argument('--task_learning_rate', type=float, default=0.1, help='Learning rate per task gradient step')\n",
+ "\n",
+ " parser.add_argument('--norm_layer', type=str, default=\"batch_norm\")\n",
+ " parser.add_argument('--max_pooling', type=str, default=\"False\")\n",
+ " parser.add_argument('--per_step_bn_statistics', type=str, default=\"False\")\n",
+ " parser.add_argument('--num_classes_per_set', type=int, default=20, help='Number of classes to sample per set')\n",
+ " parser.add_argument('--cnn_num_blocks', type=int, default=4, help='Number of classes to sample per set')\n",
+ " parser.add_argument('--number_of_training_steps_per_iter', type=int, default=1, help='Number of classes to sample per set')\n",
+ " parser.add_argument('--number_of_evaluation_steps_per_iter', type=int, default=1, help='Number of classes to sample per set')\n",
+ " parser.add_argument('--cnn_num_filters', type=int, default=64, help='Number of classes to sample per set')\n",
+ " parser.add_argument('--cnn_blocks_per_stage', type=int, default=1,\n",
+ " help='Number of classes to sample per set')\n",
+ " parser.add_argument('--num_samples_per_class', type=int, default=1, help='Number of samples per set to sample')\n",
+ " parser.add_argument('--name_of_args_json_file', type=str, default=\"./experiment_config/omniglot_maml++-omniglot_1_8_0.1_64_5_0.json\")\n",
+ " # parser.add_argument('--num_stages', default=4)\n",
+ " # parser.add_argument('--conv_padding', default=True)\n",
+ " \n",
+ " args = parser.parse_args('')\n",
+ " args_dict = vars(args)\n",
+ " if args.name_of_args_json_file is not \"None\":\n",
+ " args_dict = extract_args_from_json(args.name_of_args_json_file, args_dict)\n",
+ "\n",
+ " for key in list(args_dict.keys()):\n",
+ "\n",
+ " if str(args_dict[key]).lower() == \"true\":\n",
+ " args_dict[key] = True\n",
+ " elif str(args_dict[key]).lower() == \"false\":\n",
+ " args_dict[key] = False\n",
+ " if key == \"dataset_path\":\n",
+ " args_dict[key] = os.path.join(os.environ['DATASET_DIR'], args_dict[key])\n",
+ " print(key, os.path.join(os.environ['DATASET_DIR'], args_dict[key]))\n",
+ "\n",
+ " print(key, args_dict[key], type(args_dict[key]))\n",
+ "\n",
+ " args = Bunch(args_dict)\n",
+ "\n",
+ "\n",
+ " args.use_cuda = torch.cuda.is_available()\n",
+ "\n",
+ " if args.gpu_to_use == -1:\n",
+ " args.use_cuda = False\n",
+ "\n",
+ " if args.use_cuda:\n",
+ " os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu_to_use)\n",
+ " device = cuda.current_device()\n",
+ " else:\n",
+ " device = torch.device('cpu')\n",
+ "\n",
+ " return args, device\n",
+ "\n",
+ "\n",
+ "\n",
+ "class Bunch(object):\n",
+ " def __init__(self, adict):\n",
+ " self.__dict__.update(adict)\n",
+ "\n",
+ "def extract_args_from_json(json_file_path, args_dict):\n",
+ " import json\n",
+ " summary_filename = json_file_path\n",
+ " with open(summary_filename) as f:\n",
+ " summary_dict = json.load(fp=f)\n",
+ "\n",
+ " for key in summary_dict.keys():\n",
+ " if \"continue_from\" in key:\n",
+ " pass\n",
+ " elif \"gpu_to_use\" in key:\n",
+ " pass\n",
+ " else:\n",
+ " args_dict[key] = summary_dict[key]\n",
+ "\n",
+ " return args_dict\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
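+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "How the config loading above behaves, in miniature: argparse defaults become a dict, matching keys from the json file overwrite them (except excluded ones such as gpu_to_use), \"true\"/\"false\" strings are coerced to booleans, and Bunch exposes the result as attributes. A small self-contained sketch with a hypothetical inline config rather than the experiment's real json file:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Hedged sketch of the args pipeline: defaults dict -> json overrides -> bool coercion -> Bunch.\n",
+ "class Bunch(object):\n",
+ "    def __init__(self, adict):\n",
+ "        self.__dict__.update(adict)\n",
+ "\n",
+ "defaults = {'batch_size': 32, 'second_order': 'False', 'gpu_to_use': None}\n",
+ "json_overrides = {'batch_size': 8, 'second_order': 'True', 'gpu_to_use': 0}  # hypothetical config\n",
+ "\n",
+ "for key, value in json_overrides.items():\n",
+ "    if 'continue_from' in key or 'gpu_to_use' in key:  # these keys keep their defaults\n",
+ "        continue\n",
+ "    defaults[key] = value\n",
+ "\n",
+ "for key in list(defaults.keys()):  # coerce \"true\"/\"false\" strings to real booleans\n",
+ "    if str(defaults[key]).lower() == 'true':\n",
+ "        defaults[key] = True\n",
+ "    elif str(defaults[key]).lower() == 'false':\n",
+ "        defaults[key] = False\n",
+ "\n",
+ "args = Bunch(defaults)\n",
+ "print(args.batch_size, args.second_order, args.gpu_to_use)  # -> 8 True None"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },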
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "GmWqo4GQ66oG",
+ "colab_type": "code",
+ "outputId": "bf944b05-3d91-4b35-bd3e-506d94ab78e5",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 2675
+ }
+ },
+ "source": [
+ "# Combines the arguments, model, data and experiment builders to run an experiment\n",
+ "args, device = get_args()\n",
+ "print(args.image_channels)\n",
+ "model = MAMLFewShotClassifier(args=args, device=device,\n",
+ " im_shape=(2, args.image_channels, args.image_height, args.image_width))\n",
+ "# maybe_unzip_dataset(args=args)\n",
+ "data = MetaLearningSystemDataLoader\n",
+ "maml_system = ExperimentBuilder(model=model, data=data, args=args, device=device)\n",
+ "maml_system.run_experiment()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "batch_size 8 \n",
+ "image_height 28 \n",
+ "image_width 28 \n",
+ "image_channels 1 \n",
+ "reset_stored_filepaths False \n",
+ "reverse_channels False \n",
+ "num_of_gpus 1 \n",
+ "indexes_of_folders_indicating_class [-3, -2] \n",
+ "train_val_test_split [0.70918052988, 0.03080714725, 0.2606284658] \n",
+ "samples_per_iter 1 \n",
+ "labels_as_int False \n",
+ "seed 104 \n",
+ "gpu_to_use None \n",
+ "num_dataprovider_workers 4 \n",
+ "max_models_to_save 5 \n",
+ "dataset_name omniglot_dataset \n",
+ "dataset_path ./datasets/./datasets/omniglot_dataset\n",
+ "dataset_path ./datasets/omniglot_dataset \n",
+ "reset_stored_paths False \n",
+ "experiment_name omniglot_1_8_0.1_64_5_0 \n",
+ "architecture_name None \n",
+ "continue_from_epoch latest \n",
+ "dropout_rate_value 0.0 \n",
+ "num_target_samples 1 \n",
+ "second_order True \n",
+ "total_epochs 100 \n",
+ "total_iter_per_epoch 500 \n",
+ "min_learning_rate 1e-05 \n",
+ "meta_learning_rate 0.001 \n",
+ "meta_opt_bn False \n",
+ "task_learning_rate 0.1 \n",
+ "norm_layer batch_norm \n",
+ "max_pooling True \n",
+ "per_step_bn_statistics True \n",
+ "num_classes_per_set 5 \n",
+ "cnn_num_blocks 4 \n",
+ "number_of_training_steps_per_iter 5 \n",
+ "number_of_evaluation_steps_per_iter 5 \n",
+ "cnn_num_filters 64 \n",
+ "cnn_blocks_per_stage 1 \n",
+ "num_samples_per_class 1 \n",
+ "name_of_args_json_file ./experiment_config/omniglot_maml++-omniglot_1_8_0.1_64_5_0.json \n",
+ "train_seed 0 \n",
+ "val_seed 0 \n",
+ "load_from_npz_files False \n",
+ "sets_are_pre_split False \n",
+ "load_into_memory True \n",
+ "init_inner_loop_learning_rate 0.1 \n",
+ "train_in_stages False \n",
+ "multi_step_loss_num_epochs 10 \n",
+ "minimum_per_task_contribution 0.01 \n",
+ "num_evaluation_tasks 600 \n",
+ "learnable_per_layer_per_step_inner_loop_learning_rate True \n",
+ "enable_inner_loop_optimizable_bn_params False \n",
+ "evaluate_on_test_set_only False \n",
+ "learnable_batch_norm_momentum False \n",
+ "evalute_on_test_set_only False \n",
+ "learnable_bn_gamma True \n",
+ "learnable_bn_beta True \n",
+ "weight_decay 0.0 \n",
+ "total_epochs_before_pause 100 \n",
+ "first_order_to_second_order_epoch -1 \n",
+ "num_stages 4 \n",
+ "conv_padding True \n",
+ "use_multi_step_loss_optimization True \n",
+ "1\n",
+ "Using max pooling\n",
+ "torch.Size([2, 64, 28, 28])\n",
+ "torch.Size([2, 64, 14, 14])\n",
+ "torch.Size([2, 64, 7, 7])\n",
+ "torch.Size([2, 64, 3, 3])\n",
+ "VGGNetwork build torch.Size([2, 5])\n",
+ "meta network params\n",
+ "layer_dict.conv0.conv.weight torch.Size([64, 1, 3, 3])\n",
+ "layer_dict.conv0.conv.bias torch.Size([64])\n",
+ "layer_dict.conv0.norm_layer.running_mean torch.Size([5, 64])\n",
+ "layer_dict.conv0.norm_layer.running_var torch.Size([5, 64])\n",
+ "layer_dict.conv0.norm_layer.bias torch.Size([5, 64])\n",
+ "layer_dict.conv0.norm_layer.weight torch.Size([5, 64])\n",
+ "layer_dict.conv1.conv.weight torch.Size([64, 64, 3, 3])\n",
+ "layer_dict.conv1.conv.bias torch.Size([64])\n",
+ "layer_dict.conv1.norm_layer.running_mean torch.Size([5, 64])\n",
+ "layer_dict.conv1.norm_layer.running_var torch.Size([5, 64])\n",
+ "layer_dict.conv1.norm_layer.bias torch.Size([5, 64])\n",
+ "layer_dict.conv1.norm_layer.weight torch.Size([5, 64])\n",
+ "layer_dict.conv2.conv.weight torch.Size([64, 64, 3, 3])\n",
+ "layer_dict.conv2.conv.bias torch.Size([64])\n",
+ "layer_dict.conv2.norm_layer.running_mean torch.Size([5, 64])\n",
+ "layer_dict.conv2.norm_layer.running_var torch.Size([5, 64])\n",
+ "layer_dict.conv2.norm_layer.bias torch.Size([5, 64])\n",
+ "layer_dict.conv2.norm_layer.weight torch.Size([5, 64])\n",
+ "layer_dict.conv3.conv.weight torch.Size([64, 64, 3, 3])\n",
+ "layer_dict.conv3.conv.bias torch.Size([64])\n",
+ "layer_dict.conv3.norm_layer.running_mean torch.Size([5, 64])\n",
+ "layer_dict.conv3.norm_layer.running_var torch.Size([5, 64])\n",
+ "layer_dict.conv3.norm_layer.bias torch.Size([5, 64])\n",
+ "layer_dict.conv3.norm_layer.weight torch.Size([5, 64])\n",
+ "layer_dict.linear.weights torch.Size([5, 64])\n",
+ "layer_dict.linear.bias torch.Size([5])\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "\r 0%| | 0/1150 [00:00, ?it/s]IOPub data rate exceeded.\n",
+ "The notebook server will temporarily stop sending output\n",
+ "to the client in order to avoid crashing it.\n",
+ "To change this limit, set the config variable\n",
+ "`--NotebookApp.iopub_data_rate_limit`.\n",
+ "\n",
+ "Current values:\n",
+ "NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)\n",
+ "NotebookApp.rate_limit_window=3.0 (secs)\n",
+ "\n",
+ "100%|██████████| 1150/1150 [00:03<00:00, 334.29it/s]\n",
+ " 0%| | 0/50 [00:00, ?it/s]"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Currently loading into memory the val set\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 50/50 [00:00<00:00, 114.76it/s]\n",
+ " 0%| | 0/423 [00:00, ?it/s]"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Currently loading into memory the test set\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 423/423 [00:01<00:00, 294.29it/s]\n",
+ " 57%|█████▋ | 28500/50000 [00:00, ?it/s]"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "data {'train': 23000, 'val': 1000, 'test': 8460}\n",
+ "train_seed 985773, val_seed: 985773, at start time\n",
+ "28500 50000\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "error",
+ "ename": "AssertionError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMetaLearningSystemDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mmaml_system\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExperimentBuilder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmaml_system\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36mrun_experiment\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal_epochs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m 'current_iter'],\n\u001b[0;32m--> 311\u001b[0;31m augment_images=self.augment_flag)):\n\u001b[0m\u001b[1;32m 312\u001b[0m \u001b[0;31m# print(self.state['current_iter'], (self.args.total_epochs * self.args.total_iter_per_epoch))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m train_losses, total_losses, self.state['current_iter'] = self.train_iteration(\n",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36mget_train_batches\u001b[0;34m(self, total_batches, augment_images)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_augmentation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maugment_images\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maugment_images\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal_train_iters_produced\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_of_gpus\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples_per_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msample_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_batched\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_dataloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0msample_batched\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 582\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 583\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0mnext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__next__\u001b[0m \u001b[0;31m# Python 2 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_next_batch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"KeyError:\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 607\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 608\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mAssertionError\u001b[0m: Traceback (most recent call last):\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py\", line 99, in _worker_loop\n samples = collate_fn([dataset[i] for i in batch_indices])\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py\", line 99, in \n samples = collate_fn([dataset[i] for i in batch_indices])\n File \"\", line 438, in __getitem__\n writer.add_images('support_set_images', support_set_images, idx)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/writer.py\", line 404, in add_images\n image(tag, img_tensor, dataformats=dataformats), global_step, walltime)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/summary.py\", line 221, in image\n tensor = convert_to_HWC(tensor, dataformats)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/_utils.py\", line 95, in convert_to_HWC\n tensor shape: {}, input_format: {}\".format(tensor.shape, input_format)\nAssertionError: size of input tensor and input format are different. tensor shape: (5, 1, 1, 28, 28), input_format: NCHW\n"
+ ]
+ }
+ ]
+ },
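+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The run above crashes inside a dataloader worker: the dataset's __getitem__ calls writer.add_images with a 5-d support-set tensor of shape (5, 1, 1, 28, 28), while add_images expects 4-d NCHW input. A hedged sketch of one possible fix, collapsing the extra leading dimension before logging; support_set_images, writer and idx are the names taken from the traceback:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# Possible fix for the add_images AssertionError: flatten everything before the\n",
+ "# (C, H, W) dims so the tensor is 4-d NCHW when handed to tensorboard. Sketch only.\n",
+ "import torch\n",
+ "from torch.utils.tensorboard import SummaryWriter\n",
+ "\n",
+ "writer = SummaryWriter('./runs/debug')\n",
+ "support_set_images = torch.rand(5, 1, 1, 28, 28)  # stand-in for the failing batch\n",
+ "idx = 0\n",
+ "\n",
+ "images_nchw = support_set_images.reshape(-1, *support_set_images.shape[-3:])\n",
+ "print(images_nchw.shape)  # torch.Size([5, 1, 28, 28])\n",
+ "writer.add_images('support_set_images', images_nchw, idx)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },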
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "gMlXOtEcMl8B",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file