diff --git a/maml++.ipynb b/maml++.ipynb new file mode 100644 index 000000000..de5920c43 --- /dev/null +++ b/maml++.ipynb @@ -0,0 +1,3227 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "maml++.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "_bgT2a-_HIbT", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# mount google drive\n", + "from google.colab import drive\n", + "drive.mount('/gdrive')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "0y79kypKJ_J3", + "colab": {} + }, + "source": [ + "# copy files from google colab to google drive\n", + "!cp -r omniglot_1_8_0.1_64_5_0/ ../gdrive/My\\ Drive/\n", + "!cp -r runs/ ../gdrive/My\\ Drive/" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab_type": "code", + "id": "f5ohNe6LJ_JX", + "colab": {} + }, + "source": [ + "# retrieve files from google drive to google colab\n", + "!cp -r ../gdrive/My\\ Drive/omniglot_1_8_0.1_64_5_0/ ./omniglot_1_8_0.1_64_5_0/\n", + "!cp -r ../gdrive/My\\ Drive/runs/ ./runs/ " + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "avyMhUz2yTg4", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# clone repository, for dataset and json files\n", + "!git clone https://github.com/AntreasAntoniou/HowToTrainYourMAMLPytorch.git\n", + "!mv HowToTrainYourMAMLPytorch/* ./" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZfXN4qJM9zSr", + "colab_type": "code", + "colab": {} + }, + "source": [ + "!rm -r HowToTrainYourMAMLPytorch/" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lX1pF70a-KZG", + "colab_type": "code", + "outputId": "bf311efe-ffaa-4b08-d4a9-b5c7191da8f2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 246 + } + }, + "source": [ + "# ngrok, dynamic visualization\n", + "!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip\n", + "!unzip ngrok-stable-linux-amd64.zip" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "--2019-06-02 02:16:38-- https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip\n", + "Resolving bin.equinox.io (bin.equinox.io)... 52.73.94.166, 52.7.169.168, 34.232.40.183, ...\n", + "Connecting to bin.equinox.io (bin.equinox.io)|52.73.94.166|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 16648024 (16M) [application/octet-stream]\n", + "Saving to: ‘ngrok-stable-linux-amd64.zip’\n", + "\n", + "ngrok-stable-linux- 100%[===================>] 15.88M 42.3MB/s in 0.4s \n", + "\n", + "2019-06-02 02:16:39 (42.3 MB/s) - ‘ngrok-stable-linux-amd64.zip’ saved [16648024/16648024]\n", + "\n", + "Archive: ngrok-stable-linux-amd64.zip\n", + " inflating: ngrok \n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nw5CGKbw-KYC", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# logs are saved in ./runs directory\n", + "LOG_DIR = './runs'\n", + "get_ipython().system_raw(\n", + " 'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'\n", + " .format(LOG_DIR)\n", + ")" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "EoqoTvFz-KVe", + "colab_type": "code", + "outputId": "d7254665-941b-4ba0-cf8a-813f8c5de3c2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "get_ipython().system_raw('./ngrok http 6006 &')\n", + "! curl -s http://localhost:4040/api/tunnels | python3 -c \\\n", + " \"import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])\"\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "http://e43be063.ngrok.io\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "aTf0w978-RNd", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# for using TPU\n", + "# !pip install \\\n", + "# http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl \\\n", + "# http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dlYcoGLm-RMF", + "colab_type": "code", + "outputId": "3c771328-b72f-41ca-a166-510602dfecf0", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1005 + } + }, + "source": [ + "# for using ngrok, tensorboard with PyTorch, dynamic visualization, we need to upgrade tb-nightly\n", + "%%shell\n", + "pip install --upgrade tb-nightly" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting tensorflow==1.14.0rc0\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/6f/0c/355160095d16e6fe06a06474471b4e7560218b152db88a06f0190e7401c9/tensorflow-1.14.0rc0-cp36-cp36m-manylinux1_x86_64.whl (109.2MB)\n", + "\u001b[K |████████████████████████████████| 109.2MB 205kB/s \n", + "\u001b[?25hRequirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.7.1)\n", + "Requirement already satisfied: tensorboard<1.14.0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.13.1)\n", + "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.15.0)\n", + "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.0.9)\n", + "Collecting tf-estimator-nightly<1.14.0.dev2019042302,>=1.14.0.dev2019042301 (from tensorflow==1.14.0rc0)\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/6a/7d343697ad80f824f9464e63fb910983174767f099ba6ed2c49cc6873f96/tf_estimator_nightly-1.14.0.dev2019042301-py2.py3-none-any.whl (480kB)\n", + "\u001b[K 
|████████████████████████████████| 481kB 34.6MB/s \n", + "\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.12.0)\n", + "Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.2.2)\n", + "Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.8.0)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.1.0)\n", + "Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.0.7)\n", + "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (3.7.1)\n", + "Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (1.16.4)\n", + "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow==1.14.0rc0) (0.33.4)\n", + "Collecting wrapt>=1.11.1 (from tensorflow==1.14.0rc0)\n", + " Downloading https://files.pythonhosted.org/packages/67/b2/0f71ca90b0ade7fad27e3d20327c996c6252a2ffe88f50a95bba7434eda9/wrapt-1.11.1.tar.gz\n", + "Collecting google-pasta>=0.1.6 (from tensorflow==1.14.0rc0)\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d0/33/376510eb8d6246f3c30545f416b2263eee461e40940c2a4413c711bdf62d/google_pasta-0.1.7-py3-none-any.whl (52kB)\n", + "\u001b[K |████████████████████████████████| 61kB 16.3MB/s \n", + "\u001b[?25hRequirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow==1.14.0rc0) (3.1.1)\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow==1.14.0rc0) (0.15.4)\n", + "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow==1.14.0rc0) (2.8.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow==1.14.0rc0) (41.0.1)\n", + "Building wheels for collected packages: wrapt\n", + " Building wheel for wrapt (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + " Stored in directory: /root/.cache/pip/wheels/89/67/41/63cbf0f6ac0a6156588b9587be4db5565f8c6d8ccef98202fc\n", + "Successfully built wrapt\n", + "\u001b[31mERROR: thinc 6.12.1 has requirement wrapt<1.11.0,>=1.10.0, but you'll have wrapt 1.11.1 which is incompatible.\u001b[0m\n", + "Installing collected packages: tf-estimator-nightly, wrapt, google-pasta, tensorflow\n", + " Found existing installation: wrapt 1.10.11\n", + " Uninstalling wrapt-1.10.11:\n", + " Successfully uninstalled wrapt-1.10.11\n", + " Found existing installation: tensorflow 1.13.1\n", + " Uninstalling tensorflow-1.13.1:\n", + " Successfully uninstalled tensorflow-1.13.1\n", + "Successfully installed google-pasta-0.1.7 tensorflow-1.14.0rc0 tf-estimator-nightly-1.14.0.dev2019042301 wrapt-1.11.1\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (0.16.0)\n", + "Collecting tb-nightly\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/3e/85/f4cfd5ba1d5f67702e05a785d4fa2732c1ba58c66038c7280bb427c6430b/tb_nightly-1.14.0a20190601-py3-none-any.whl (3.1MB)\n", + "\u001b[K |████████████████████████████████| 3.1MB 3.4MB/s \n", + "\u001b[?25hRequirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.33.4)\n", + "Requirement already satisfied, skipping upgrade: numpy>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.16.4)\n", + "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.12.0)\n", + "Requirement already satisfied, skipping upgrade: grpcio>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (1.15.0)\n", + "Requirement already satisfied, skipping upgrade: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.7.1)\n", + "Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (41.0.1)\n", + "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (0.15.4)\n", + "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (3.1.1)\n", + "Requirement already satisfied, skipping upgrade: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly) (3.7.1)\n", + "Installing collected packages: tb-nightly\n", + "Successfully installed tb-nightly-1.14.0a20190601\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 15 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CXSGKa4l-RKj", + "colab_type": "code", + "outputId": "78cc8153-1aab-4643-9e7c-32410f7f65fc", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "from __future__ import print_function, unicode_literals, division\n", + "from IPython.core.debugger import set_trace\n", + "from IPython.display import HTML, Math\n", + "from pprint import pprint\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.autograd as autograd\n", + "from torchvision import datasets\n", + "from torch.nn.utils.weight_norm import WeightNorm\n", + "from torch.optim.lr_scheduler import StepLR\n", + "from torch.optim 
import Adam\n",
+        "from torchvision.transforms import Compose, RandomHorizontalFlip, RandomResizedCrop, ToTensor, Normalize, \\\n",
+        "    CenterCrop, Resize, ColorJitter, ToPILImage, RandomCrop, RandomRotation\n",
+        "import torchvision.models as models\n",
+        "from torch.autograd import Variable\n",
+        "from torch.utils.data import DataLoader, Dataset, sampler\n",
+        "from torchvision.utils import make_grid, save_image\n",
+        "from torch.utils.tensorboard import SummaryWriter\n",
+        "import torch.utils.checkpoint as cp\n",
+        "\n",
+        "# for using TPU\n",
+        "# import torch_xla\n",
+        "# import torch_xla_py.utils as xu\n",
+        "# import torch_xla_py.xla_model as xm\n",
+        "\n",
+        "import json, glob, time, math, os, datetime, string, random, re, warnings, itertools, \\\n",
+        "    logging, numbers, csv, pickle, tqdm\n",
+        "from copy import copy\n",
+        "\n",
+        "import numpy as np\n",
+        "from PIL import Image, ImageEnhance, ImageFile\n",
+        "from collections import OrderedDict\n",
+        "import concurrent.futures\n",
+        "\n",
+        "ImageFile.LOAD_TRUNCATED_IMAGES = True\n",
+        "\n",
+        "print(\"CUDA available: \", torch.cuda.is_available())\n",
+        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+        "\n",
+        "\n",
+        "os.environ['DATASET_DIR'] = './datasets'\n",
+        "\n",
+        "# for writing into tensorboard\n",
+        "writer = SummaryWriter()"
+      ],
+      "execution_count": 0,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "CUDA available:  True\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "d8LgqZ6qy1L7",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class rotate_image(object):\n",
+        "\n",
+        "    def __init__(self, k, channels):\n",
+        "        self.k = k\n",
+        "        self.channels = channels\n",
+        "\n",
+        "    def __call__(self, image):\n",
+        "        if self.channels == 1 and len(image.shape) == 3:\n",
+        "            image = image[:, :, 0]\n",
+        "            image = np.expand_dims(image, axis=2)\n",
+        "\n",
+        "        elif self.channels == 1 and len(image.shape) == 4:\n",
+        "            image = image[:, :, :, 0]\n",
+        "            image = np.expand_dims(image, axis=3)\n",
+        "\n",
+        "        image = np.rot90(image, k=self.k).copy()\n",
+        "        return image"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "GZ6A6o8oJLil",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class torch_rotate_image(object):\n",
+        "\n",
+        "    def __init__(self, k, channels):\n",
+        "        self.k = k\n",
+        "        self.channels = channels\n",
+        "\n",
+        "    def __call__(self, image):\n",
+        "        rotate = RandomRotation(degrees=self.k * 90)\n",
+        "        if image.shape[-1] == 1:\n",
+        "            image = image[:, :, 0]\n",
+        "        image = Image.fromarray(image)\n",
+        "        image = rotate(image)\n",
+        "        image = np.array(image)\n",
+        "        if len(image.shape) == 2:\n",
+        "            image = np.expand_dims(image, axis=2)\n",
+        "        return image"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "Fcg1Bxs3ythc",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "def augment_image(image, k, channels, augment_bool, args, dataset_name):\n",
+        "    transform_train, transform_evaluation = get_transforms_for_dataset(dataset_name=dataset_name,\n",
+        "                                                                       args=args, k=k)\n",
+        "    if len(image.shape) > 3:\n",
+        "        images = [item for item in image]\n",
+        "        output_images = []\n",
+        "        for image in images:\n",
+        "            if augment_bool is True:\n",
+        "                for transform_current in transform_train:\n",
+        "                    image = transform_current(image)\n",
+        "            else:\n",
+        "                for transform_current in transform_evaluation:\n",
+        "                    image = transform_current(image)\n",
+        "            output_images.append(image)\n",
+        "        image = torch.stack(output_images)\n",
+        "    else:\n",
+        "        if augment_bool is True:\n",
+        "            # meanstd transformation\n",
+        "            for transform_current in transform_train:\n",
+        "                image = transform_current(image)\n",
+        "        else:\n",
+        "            for transform_current in transform_evaluation:\n",
+        "                image = transform_current(image)\n",
+        "    return image"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "9NVCa1zGywce",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "def get_transforms_for_dataset(dataset_name, args, k):\n",
+        "    if \"cifar10\" in dataset_name or \"cifar100\" in dataset_name:\n",
+        "        transform_train = [\n",
+        "            RandomCrop(32, padding=4),\n",
+        "            RandomHorizontalFlip(),\n",
+        "            ToTensor(),\n",
+        "            Normalize(args.classification_mean, args.classification_std)]\n",
+        "\n",
+        "        transform_evaluate = [\n",
+        "            ToTensor(),\n",
+        "            Normalize(args.classification_mean, args.classification_std)]\n",
+        "\n",
+        "    elif 'omniglot' in dataset_name:\n",
+        "\n",
+        "        transform_train = [rotate_image(k=k, channels=args.image_channels), ToTensor()]\n",
+        "        transform_evaluate = [ToTensor()]\n",
+        "\n",
+        "    elif 'imagenet' in dataset_name:\n",
+        "\n",
+        "        transform_train = [Compose([\n",
+        "            ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])]\n",
+        "\n",
+        "        transform_evaluate = [Compose([\n",
+        "            ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])]\n",
+        "\n",
+        "    else:\n",
+        "        # fail loudly instead of returning undefined transforms for an unknown dataset\n",
+        "        raise ValueError(\"Unsupported dataset: {}\".format(dataset_name))\n",
+        "\n",
+        "    return transform_train, transform_evaluate"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
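+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "transforms-sanity-sketch",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# A minimal sanity-check sketch, left commented out like the TPU cell above. It assumes an `args`\n",
+        "# object with image_channels=1 (built later in this notebook): for omniglot, the train pipeline\n",
+        "# rotates the array by k*90 degrees and converts it to a (C, H, W) tensor.\n",
+        "# dummy = np.random.rand(28, 28, 1).astype(np.float32)\n",
+        "# train_tfs, eval_tfs = get_transforms_for_dataset('omniglot', args=args, k=1)\n",
+        "# x = dummy\n",
+        "# for t in train_tfs:\n",
+        "#     x = t(x)\n",
+        "# print(x.shape)  # expected: torch.Size([1, 28, 28])\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },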
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "8E3bStPExiaz",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class FewShotLearningDatasetParallel(Dataset):\n",
+        "    def __init__(self, args):\n",
+        "        \"\"\"\n",
+        "        A data provider class inheriting from Pytorch's Dataset class. It takes care of creating task sets for\n",
+        "        our few-shot learning model training and evaluation.\n",
+        "        :param args: Arguments in the form of a Bunch object. Includes all hyperparameters necessary for the\n",
+        "        data-provider. For transparency and readability, every argument the data provider needs is explicitly\n",
+        "        set as self.object_name, so the reader knows exactly what the data provider requires.\n",
+        "        \"\"\"\n",
+        "        self.data_path = args.dataset_path\n",
+        "        self.dataset_name = args.dataset_name\n",
+        "        self.data_loaded_in_memory = False\n",
+        "        self.image_height, self.image_width, self.image_channel = args.image_height, args.image_width, args.image_channels\n",
+        "        self.args = args\n",
+        "        self.indexes_of_folders_indicating_class = args.indexes_of_folders_indicating_class\n",
+        "        self.reverse_channels = args.reverse_channels\n",
+        "        self.labels_as_int = args.labels_as_int\n",
+        "        self.train_val_test_split = args.train_val_test_split\n",
+        "        self.current_set_name = \"train\"\n",
+        "        self.num_target_samples = args.num_target_samples\n",
+        "        self.reset_stored_filepaths = args.reset_stored_filepaths\n",
+        "        val_rng = np.random.RandomState(seed=args.val_seed)\n",
+        "        val_seed = val_rng.randint(1, 999999)\n",
+        "        train_rng = np.random.RandomState(seed=args.train_seed)\n",
+        "        train_seed = train_rng.randint(1, 999999)\n",
+        "        test_rng = np.random.RandomState(seed=args.val_seed)\n",
+        "        test_seed = test_rng.randint(1, 999999)\n",
+        "        args.val_seed = val_seed\n",
+        "        args.train_seed = train_seed\n",
+        "        args.test_seed = test_seed\n",
+        "        self.init_seed = {\"train\": args.train_seed, \"val\": args.val_seed, 'test': args.val_seed}\n",
+        "        self.seed = {\"train\": args.train_seed, \"val\": args.val_seed, 'test': args.val_seed}\n",
+        "        self.num_of_gpus = args.num_of_gpus\n",
+        "        self.batch_size = args.batch_size\n",
+        "\n",
+        "        self.train_index = 0\n",
+        "        self.val_index = 0\n",
+        "        self.test_index = 0\n",
+        "\n",
+        "        self.augment_images = False\n",
+        "        self.num_samples_per_class = args.num_samples_per_class\n",
+        "        self.num_classes_per_set = args.num_classes_per_set\n",
+        "\n",
+        "        self.rng = np.random.RandomState(seed=self.seed['val'])\n",
+        "        self.datasets = self.load_dataset()\n",
+        "\n",
+        "        self.indexes = {\"train\": 0, \"val\": 0, 'test': 0}\n",
+        "        self.dataset_size_dict = {\n",
+        "            \"train\": {key: len(self.datasets['train'][key]) for key in list(self.datasets['train'].keys())},\n",
+        "            \"val\": {key: len(self.datasets['val'][key]) for key in list(self.datasets['val'].keys())},\n",
+        "            'test': {key: len(self.datasets['test'][key]) for key in list(self.datasets['test'].keys())}}\n",
+        "        self.label_set = self.get_label_set()\n",
+        "        self.data_length = {name: np.sum([len(self.datasets[name][key])\n",
+        "                                          for key in self.datasets[name]]) for name in self.datasets.keys()}\n",
+        "\n",
+        "        print(\"data\", self.data_length)\n",
+        "        self.observed_seed_set = None\n",
+        "\n",
+        "    def load_dataset(self):\n",
+        "        \"\"\"\n",
+        "        Loads a dataset's dictionary files and splits the data according to the train_val_test_split variable stored\n",
+        "        in the args object.\n",
+        "        :return: Three sets, the training set, validation set and test sets (referred to as the meta-train,\n",
+        "        meta-val and meta-test in the paper)\n",
+        "        \"\"\"\n",
+        "        rng = np.random.RandomState(seed=self.seed['val'])\n",
+        "\n",
+        "        if self.args.sets_are_pre_split == True:\n",
+        "            data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths()\n",
+        "            dataset_splits = dict()\n",
+        "            for key, value in data_image_paths.items():\n",
+        "                key = self.get_label_from_index(index=key)\n",
+        "                bits = key.split(\"/\")\n",
+        "                set_name = bits[0]\n",
+        "                class_label = bits[1]\n",
+        "                if set_name not in
dataset_splits:\n", + " dataset_splits[set_name] = {class_label: value}\n", + " else:\n", + " dataset_splits[set_name][class_label] = value\n", + " else:\n", + " data_image_paths, index_to_label_name_dict_file, label_to_index = self.load_datapaths()\n", + " total_label_types = len(data_image_paths)\n", + " num_classes_idx = np.arange(len(data_image_paths.keys()), dtype=np.int32)\n", + " rng.shuffle(num_classes_idx)\n", + " keys = list(data_image_paths.keys())\n", + " values = list(data_image_paths.values())\n", + " new_keys = [keys[idx] for idx in num_classes_idx]\n", + " new_values = [values[idx] for idx in num_classes_idx]\n", + " data_image_paths = dict(zip(new_keys, new_values))\n", + " # data_image_paths = self.shuffle(data_image_paths)\n", + " x_train_id, x_val_id, x_test_id = int(self.train_val_test_split[0] * total_label_types), \\\n", + " int(np.sum(self.train_val_test_split[:2]) * total_label_types), \\\n", + " int(total_label_types)\n", + " print(x_train_id, x_val_id, x_test_id)\n", + " x_train_classes = (class_key for class_key in list(data_image_paths.keys())[:x_train_id])\n", + " x_val_classes = (class_key for class_key in list(data_image_paths.keys())[x_train_id:x_val_id])\n", + " x_test_classes = (class_key for class_key in list(data_image_paths.keys())[x_val_id:x_test_id])\n", + " x_train, x_val, x_test = {class_key: data_image_paths[class_key] for class_key in x_train_classes}, \\\n", + " {class_key: data_image_paths[class_key] for class_key in x_val_classes}, \\\n", + " {class_key: data_image_paths[class_key] for class_key in x_test_classes},\n", + " dataset_splits = {\"train\": x_train, \"val\":x_val , \"test\": x_test}\n", + "\n", + " if self.args.load_into_memory is True:\n", + "\n", + " print(\"Loading data into RAM\")\n", + " x_loaded = {\"train\": [], \"val\": [], \"test\": []}\n", + "\n", + " for set_key, set_value in dataset_splits.items():\n", + " print(\"Currently loading into memory the {} set\".format(set_key))\n", + " x_loaded[set_key] = {key: np.zeros(len(value), ) for key, value in set_value.items()}\n", + " # for class_key, class_value in set_value.items():\n", + " with tqdm.tqdm(total=len(set_value)) as pbar_memory_load:\n", + " with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:\n", + " # Process the list of files, but split the work across the process pool to use all CPUs!\n", + " for (class_label, class_images_loaded) in executor.map(self.load_parallel_batch, (set_value.items())):\n", + " x_loaded[set_key][class_label] = class_images_loaded\n", + " pbar_memory_load.update(1)\n", + "\n", + " dataset_splits = x_loaded\n", + " self.data_loaded_in_memory = True\n", + "\n", + " return dataset_splits\n", + "\n", + " def load_datapaths(self):\n", + " \"\"\"\n", + " If saved json dictionaries of the data are available, then this method loads the dictionaries such that the\n", + " data is ready to be read. 
If the json dictionaries do not exist, then this method calls get_data_paths()\n", + " which will build the json dictionary containing the class to filepath samples, and then store them.\n", + " :return: data_image_paths: dict containing class to filepath list pairs.\n", + " index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable\n", + " string-names of the class\n", + " label_to_index: dictionary containing human understandable string mapped to numerical indexes\n", + " \"\"\"\n", + " dataset_dir = os.environ['DATASET_DIR']\n", + " data_path_file = \"{}/{}.json\".format(dataset_dir, self.dataset_name)\n", + " print(dataset_dir)\n", + " print(data_path_file)\n", + " self.index_to_label_name_dict_file = \"{}/map_to_label_name_{}.json\".format(dataset_dir, self.dataset_name)\n", + " self.label_name_to_map_dict_file = \"{}/label_name_to_map_{}.json\".format(dataset_dir, self.dataset_name)\n", + "\n", + " if not os.path.exists(data_path_file):\n", + " self.reset_stored_filepaths = True\n", + "\n", + " if self.reset_stored_filepaths == True:\n", + " if os.path.exists(data_path_file):\n", + " os.remove(data_path_file)\n", + " self.reset_stored_filepaths = False\n", + "\n", + " try:\n", + " data_image_paths = self.load_from_json(filename=data_path_file)\n", + " label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file)\n", + " index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file)\n", + " print(data_image_paths, index_to_label_name_dict_file, label_to_index)\n", + " return data_image_paths, index_to_label_name_dict_file, label_to_index\n", + " except:\n", + " print(\"Mapped data paths can't be found, remapping paths..\")\n", + " data_image_paths, code_to_label_name, label_name_to_code = self.get_data_paths()\n", + " self.save_to_json(dict_to_store=data_image_paths, filename=data_path_file)\n", + " self.save_to_json(dict_to_store=code_to_label_name, filename=self.index_to_label_name_dict_file)\n", + " self.save_to_json(dict_to_store=label_name_to_code, filename=self.label_name_to_map_dict_file)\n", + " return self.load_datapaths()\n", + "\n", + " def save_to_json(self, filename, dict_to_store):\n", + " with open(os.path.abspath(filename), 'w') as f:\n", + " json.dump(dict_to_store, fp=f)\n", + "\n", + " def load_from_json(self, filename):\n", + " with open(filename, mode=\"r\") as f:\n", + " load_dict = json.load(fp=f)\n", + "\n", + " return load_dict\n", + "\n", + " def load_test_image(self, filepath):\n", + " \"\"\"\n", + " Tests whether a target filepath contains an uncorrupted image. 
If image is corrupted, attempt to fix.\n", + " :param filepath: Filepath of image to be tested\n", + " :return: Return filepath of image if image exists and is uncorrupted (or attempt to fix has succeeded),\n", + " else return None\n", + " \"\"\"\n", + " image = None\n", + " try:\n", + " image = Image.open(filepath)\n", + " except RuntimeWarning:\n", + " os.system(\"convert {} -strip {}\".format(filepath, filepath))\n", + " print(\"converting\")\n", + " image = Image.open(filepath)\n", + " except:\n", + " print(\"Broken image\")\n", + "\n", + " if image is not None:\n", + " return filepath\n", + " else:\n", + " return None\n", + "\n", + " def get_data_paths(self):\n", + " \"\"\"\n", + " Method that scans the dataset directory and generates class to image-filepath list dictionaries.\n", + " :return: data_image_paths: dict containing class to filepath list pairs.\n", + " index_to_label_name_dict_file: dict containing numerical indexes mapped to the human understandable\n", + " string-names of the class\n", + " label_to_index: dictionary containing human understandable string mapped to numerical indexes\n", + " \"\"\"\n", + " print(\"Get images from\", self.data_path)\n", + " data_image_path_list_raw = []\n", + " labels = set()\n", + " for subdir, dir, files in os.walk(self.data_path):\n", + " for file in files:\n", + " if (\".jpeg\") in file.lower() or (\".png\") in file.lower() or (\".jpg\") in file.lower():\n", + " filepath = os.path.abspath(os.path.join(subdir, file))\n", + " label = self.get_label_from_path(filepath)\n", + " data_image_path_list_raw.append(filepath)\n", + " labels.add(label)\n", + "\n", + " labels = sorted(labels)\n", + " idx_to_label_name = {idx: label for idx, label in enumerate(labels)}\n", + " label_name_to_idx = {label: idx for idx, label in enumerate(labels)}\n", + " data_image_path_dict = {idx: [] for idx in list(idx_to_label_name.keys())}\n", + " with tqdm.tqdm(total=len(data_image_path_list_raw)) as pbar_error:\n", + " with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:\n", + " # Process the list of files, but split the work across the process pool to use all CPUs!\n", + " for image_file in executor.map(self.load_test_image, (data_image_path_list_raw)):\n", + " pbar_error.update(1)\n", + " if image_file is not None:\n", + " label = self.get_label_from_path(image_file)\n", + " data_image_path_dict[label_name_to_idx[label]].append(image_file)\n", + "\n", + " return data_image_path_dict, idx_to_label_name, label_name_to_idx\n", + "\n", + " def get_label_set(self):\n", + " \"\"\"\n", + " Generates a set containing all class numerical indexes\n", + " :return: A set containing all class numerical indexes\n", + " \"\"\"\n", + " index_to_label_name_dict_file = self.load_from_json(filename=self.index_to_label_name_dict_file)\n", + " return set(list(index_to_label_name_dict_file.keys()))\n", + "\n", + " def get_index_from_label(self, label):\n", + " \"\"\"\n", + " Given a class's (human understandable) string, returns the numerical index of that class\n", + " :param label: A string of a human understandable class contained in the dataset\n", + " :return: An int containing the numerical index of the given class-string\n", + " \"\"\"\n", + " label_to_index = self.load_from_json(filename=self.label_name_to_map_dict_file)\n", + " return label_to_index[label]\n", + "\n", + " def get_label_from_index(self, index):\n", + " \"\"\"\n", + " Given an index return the human understandable label mapping to it.\n", + " :param index: A numerical index (int)\n", + " 
:return: A human understandable label (str)\n", + " \"\"\"\n", + " index_to_label_name = self.load_from_json(filename=self.index_to_label_name_dict_file)\n", + " return index_to_label_name[index]\n", + "\n", + " def get_label_from_path(self, filepath):\n", + " \"\"\"\n", + " Given a path of an image generate the human understandable label for that image.\n", + " :param filepath: The image's filepath\n", + " :return: A human understandable label.\n", + " \"\"\"\n", + " label_bits = filepath.split(\"/\")\n", + " label = \"/\".join([label_bits[idx] for idx in self.indexes_of_folders_indicating_class])\n", + " if self.labels_as_int:\n", + " label = int(label)\n", + " return label\n", + "\n", + " def load_image(self, image_path, channels):\n", + " \"\"\"\n", + " Given an image filepath and the number of channels to keep, load an image and keep the specified channels\n", + " :param image_path: The image's filepath\n", + " :param channels: The number of channels to keep\n", + " :return: An image array of shape (h, w, channels), whose values range between 0.0 and 1.0.\n", + " \"\"\"\n", + " if not self.data_loaded_in_memory:\n", + " image = Image.open(image_path)\n", + " if 'omniglot' in self.dataset_name:\n", + " image = image.resize((self.image_height, self.image_width), resample=Image.LANCZOS)\n", + " image = np.array(image, np.float32)\n", + " if channels == 1:\n", + " image = np.expand_dims(image, axis=2)\n", + " else:\n", + " image = image.resize((self.image_height, self.image_width)).convert('RGB')\n", + " image = np.array(image, np.float32)\n", + " image = image / 255.0\n", + " else:\n", + " image = image_path\n", + "\n", + " return image\n", + "\n", + " def load_batch(self, batch_image_paths):\n", + " \"\"\"\n", + " Load a batch of images, given a list of filepaths\n", + " :param batch_image_paths: A list of filepaths\n", + " :return: A numpy array of images of shape batch, height, width, channels\n", + " \"\"\"\n", + " image_batch = []\n", + "\n", + " if self.data_loaded_in_memory:\n", + " for image_path in batch_image_paths:\n", + " image_batch.append(image_path)\n", + " image_batch = np.array(image_batch, dtype=np.float32)\n", + " #print(image_batch.shape)\n", + " else:\n", + " image_batch = [self.load_image(image_path=image_path, channels=self.image_channel)\n", + " for image_path in batch_image_paths]\n", + " image_batch = np.array(image_batch, dtype=np.float32)\n", + " image_batch = self.preprocess_data(image_batch)\n", + "\n", + " return image_batch\n", + "\n", + " def load_parallel_batch(self, inputs):\n", + " \"\"\"\n", + " Load a batch of images, given a list of filepaths\n", + " :param batch_image_paths: A list of filepaths\n", + " :return: A numpy array of images of shape batch, height, width, channels\n", + " \"\"\"\n", + " class_label, batch_image_paths = inputs\n", + " image_batch = []\n", + "\n", + " if self.data_loaded_in_memory:\n", + " for image_path in batch_image_paths:\n", + " image_batch.append(np.copy(image_path))\n", + " image_batch = np.array(image_batch, dtype=np.float32)\n", + " else:\n", + " #with tqdm.tqdm(total=1) as load_pbar:\n", + " image_batch = [self.load_image(image_path=image_path, channels=self.image_channel)\n", + " for image_path in batch_image_paths]\n", + " #load_pbar.update(1)\n", + "\n", + " image_batch = np.array(image_batch, dtype=np.float32)\n", + " image_batch = self.preprocess_data(image_batch)\n", + "\n", + " return class_label, image_batch\n", + "\n", + " def preprocess_data(self, x):\n", + " \"\"\"\n", + " Preprocesses data such that 
their shapes match the specified structures\n",
+        "        :param x: A data batch to preprocess\n",
+        "        :return: A preprocessed data batch\n",
+        "        \"\"\"\n",
+        "        x_shape = x.shape\n",
+        "        x = np.reshape(x, (-1, x_shape[-3], x_shape[-2], x_shape[-1]))\n",
+        "        if self.reverse_channels is True:\n",
+        "            reverse_photos = np.ones(shape=x.shape)\n",
+        "            for channel in range(x.shape[-1]):\n",
+        "                reverse_photos[:, :, :, x.shape[-1] - 1 - channel] = x[:, :, :, channel]\n",
+        "            x = reverse_photos\n",
+        "        x = x.reshape(x_shape)\n",
+        "        return x\n",
+        "\n",
+        "    def reconstruct_original(self, x):\n",
+        "        \"\"\"\n",
+        "        Applies the reverse operations that preprocess_data() applies such that the data returns to their original form\n",
+        "        :param x: A batch of data to reconstruct\n",
+        "        :return: A reconstructed batch of data\n",
+        "        \"\"\"\n",
+        "        x = x * 255.0\n",
+        "        return x\n",
+        "\n",
+        "    def shuffle(self, x, rng):\n",
+        "        \"\"\"\n",
+        "        Shuffles the data batch along its first axis\n",
+        "        :param x: A data batch\n",
+        "        :return: A shuffled data batch\n",
+        "        \"\"\"\n",
+        "        indices = np.arange(len(x))\n",
+        "        rng.shuffle(indices)\n",
+        "        x = x[indices]\n",
+        "        return x\n",
+        "\n",
+        "    def get_set(self, dataset_name, seed, augment_images=False, step=0):\n",
+        "        \"\"\"\n",
+        "        Generates a task-set to be used for training or evaluation\n",
+        "        :param dataset_name: The name of the set to use, e.g. \"train\", \"val\" etc.\n",
+        "        :param seed: The seed used when sampling classes and images for this task.\n",
+        "        :param augment_images: Whether to apply data augmentation to the sampled images.\n",
+        "        :param step: The current iteration, used when logging sampled episodes to tensorboard.\n",
+        "        :return: A task-set containing an image and label support set, and an image and label target set.\n",
+        "        \"\"\"\n",
+        "        #seed = seed % self.args.total_unique_tasks\n",
+        "        rng = np.random.RandomState(seed)\n",
+        "        selected_classes = rng.choice(list(self.dataset_size_dict[dataset_name].keys()),\n",
+        "                                      size=self.num_classes_per_set, replace=False)\n",
+        "        rng.shuffle(selected_classes)\n",
+        "        k_list = rng.randint(0, 4, size=self.num_classes_per_set)\n",
+        "        k_dict = {selected_class: k_item for (selected_class, k_item) in zip(selected_classes, k_list)}\n",
+        "        episode_labels = [i for i in range(self.num_classes_per_set)]\n",
+        "        class_to_episode_label = {selected_class: episode_label for (selected_class, episode_label) in\n",
+        "                                  zip(selected_classes, episode_labels)}\n",
+        "\n",
+        "        x_images = []\n",
+        "        y_labels = []\n",
+        "\n",
+        "        for class_entry in selected_classes:\n",
+        "            choose_samples_list = rng.choice(self.dataset_size_dict[dataset_name][class_entry],\n",
+        "                                             size=self.num_samples_per_class + self.num_target_samples, replace=False)\n",
+        "            class_image_samples = []\n",
+        "            class_labels = []\n",
+        "            for sample in choose_samples_list:\n",
+        "                choose_samples = self.datasets[dataset_name][class_entry][sample]\n",
+        "                x_class_data = self.load_batch([choose_samples])[0]\n",
+        "                k = k_dict[class_entry]\n",
+        "                x_class_data = augment_image(image=x_class_data, k=k,\n",
+        "                                             channels=self.image_channel, augment_bool=augment_images,\n",
+        "                                             dataset_name=self.dataset_name, args=self.args)\n",
+        "                class_image_samples.append(x_class_data)\n",
+        "                class_labels.append(int(class_to_episode_label[class_entry]))\n",
+        "            class_image_samples = torch.stack(class_image_samples)\n",
+        "            x_images.append(class_image_samples)\n",
+        "            y_labels.append(class_labels)\n",
+        "\n",
+        "        x_images = torch.stack(x_images)\n",
+        "        y_labels = np.array(y_labels, dtype=np.float32)\n",
+        "\n",
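+        "        # The first num_samples_per_class images of each class form the support set;\n",
+        "        # the remaining num_target_samples per class form the target (query) set.\n",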
+        "        support_set_images = x_images[:, :self.num_samples_per_class]\n",
+        "        support_set_labels = y_labels[:, :self.num_samples_per_class]\n",
+        "        target_set_images = x_images[:, self.num_samples_per_class:]\n",
+        "        target_set_labels = y_labels[:, self.num_samples_per_class:]\n",
+        "        # add_images expects a 4D (N, C, H, W) batch, so merge the class and sample dimensions before logging\n",
+        "        writer.add_images('support_set_images', support_set_images.reshape(-1, *support_set_images.shape[2:]), step)\n",
+        "        writer.add_images('target_set_images', target_set_images.reshape(-1, *target_set_images.shape[2:]), step)\n",
+        "\n",
+        "        return support_set_images, target_set_images, support_set_labels, target_set_labels, seed\n",
+        "\n",
+        "    def __len__(self):\n",
+        "        total_samples = self.data_length[self.current_set_name]\n",
+        "        return total_samples\n",
+        "\n",
+        "    def length(self, set_name):\n",
+        "        self.switch_set(set_name=set_name)\n",
+        "        return len(self)\n",
+        "\n",
+        "    def set_augmentation(self, augment_images):\n",
+        "        self.augment_images = augment_images\n",
+        "\n",
+        "    def switch_set(self, set_name, current_iter=None):\n",
+        "        self.current_set_name = set_name\n",
+        "        if set_name == \"train\":\n",
+        "            self.update_seed(dataset_name=set_name, seed=self.init_seed[set_name] + current_iter)\n",
+        "\n",
+        "    def update_seed(self, dataset_name, seed=100):\n",
+        "        self.seed[dataset_name] = seed\n",
+        "\n",
+        "    def __getitem__(self, idx):\n",
+        "\n",
+        "        support_set_images, target_set_image, support_set_labels, target_set_label, seed = \\\n",
+        "            self.get_set(self.current_set_name, seed=self.seed[self.current_set_name] + idx,\n",
+        "                         augment_images=self.augment_images, step=idx)\n",
+        "\n",
+        "        return support_set_images, target_set_image, support_set_labels, target_set_label, seed\n",
+        "\n",
+        "    def reset_seed(self):\n",
+        "        self.seed = self.init_seed"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
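+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "dataset-usage-sketch",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Usage sketch for the data provider above (hypothetical; assumes the `args` Bunch built later in\n",
+        "# this notebook). Kept commented out so it only serves as documentation:\n",
+        "# dataset = FewShotLearningDatasetParallel(args=args)\n",
+        "# sx, tx, sy, ty, seed = dataset.get_set('train', seed=0, augment_images=False, step=0)\n",
+        "# sx: (num_classes_per_set, num_samples_per_class, c, h, w)\n",
+        "# tx: (num_classes_per_set, num_target_samples, c, h, w)\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },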
Is used to make sure the data loader continues where it left\n", + " of previously.\n", + " \"\"\"\n", + " self.num_of_gpus = args.num_of_gpus\n", + " self.batch_size = args.batch_size\n", + " self.samples_per_iter = args.samples_per_iter\n", + " self.num_workers = args.num_dataprovider_workers\n", + " self.total_train_iters_produced = 0\n", + " self.dataset = FewShotLearningDatasetParallel(args=args)\n", + " self.batches_per_iter = args.samples_per_iter\n", + " self.full_data_length = self.dataset.data_length\n", + " self.continue_from_iter(current_iter=current_iter)\n", + " self.args = args\n", + "\n", + " def get_dataloader(self):\n", + " \"\"\"\n", + " Returns a data loader with the correct set (train, val or test), continuing from the current iter.\n", + " :return:\n", + " \"\"\"\n", + " return DataLoader(self.dataset, batch_size=(self.num_of_gpus * self.batch_size * self.samples_per_iter),\n", + " shuffle=False, num_workers=self.num_workers, drop_last=True)\n", + "\n", + " def continue_from_iter(self, current_iter):\n", + " \"\"\"\n", + " Makes sure the data provider is aware of where we are in terms of training iterations in the experiment.\n", + " :param current_iter:\n", + " \"\"\"\n", + " self.total_train_iters_produced += (current_iter * (self.num_of_gpus * self.batch_size * self.samples_per_iter))\n", + "\n", + " def get_train_batches(self, total_batches=-1, augment_images=False):\n", + " \"\"\"\n", + " Returns a training batches data_loader\n", + " :param total_batches: The number of batches we want the data loader to sample\n", + " :param augment_images: Whether we want the images to be augmented.\n", + " \"\"\"\n", + " if total_batches == -1:\n", + " self.dataset.data_length = self.full_data_length\n", + " else:\n", + " self.dataset.data_length[\"train\"] = total_batches * self.dataset.batch_size\n", + " self.dataset.switch_set(set_name=\"train\", current_iter=self.total_train_iters_produced)\n", + " self.dataset.set_augmentation(augment_images=augment_images)\n", + " self.total_train_iters_produced += (self.num_of_gpus * self.batch_size * self.samples_per_iter)\n", + " for sample_id, sample_batched in enumerate(self.get_dataloader()):\n", + " yield sample_batched\n", + "\n", + "\n", + " def get_val_batches(self, total_batches=-1, augment_images=False):\n", + " \"\"\"\n", + " Returns a validation batches data_loader\n", + " :param total_batches: The number of batches we want the data loader to sample\n", + " :param augment_images: Whether we want the images to be augmented.\n", + " \"\"\"\n", + " if total_batches == -1:\n", + " self.dataset.data_length = self.full_data_length\n", + " else:\n", + " self.dataset.data_length['val'] = total_batches * self.dataset.batch_size\n", + " self.dataset.switch_set(set_name=\"val\")\n", + " self.dataset.set_augmentation(augment_images=augment_images)\n", + " for sample_id, sample_batched in enumerate(self.get_dataloader()):\n", + " yield sample_batched\n", + "\n", + "\n", + " def get_test_batches(self, total_batches=-1, augment_images=False):\n", + " \"\"\"\n", + " Returns a testing batches data_loader\n", + " :param total_batches: The number of batches we want the data loader to sample\n", + " :param augment_images: Whether we want the images to be augmented.\n", + " \"\"\"\n", + " if total_batches == -1:\n", + " self.dataset.data_length = self.full_data_length\n", + " else:\n", + " self.dataset.data_length['test'] = total_batches * self.dataset.batch_size\n", + " self.dataset.switch_set(set_name='test')\n", + " 
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "HCshYy8PvXly",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "def extract_top_level_dict(current_dict):\n",
+        "    \"\"\"\n",
+        "    Builds a nested (graph) dictionary from a flat dict whose keys are dot-separated parameter names.\n",
+        "    Useful for dynamically passing external params to modules that contain sub-modules.\n",
+        "    :param current_dict: A flat dict mapping dot-separated parameter names (e.g. \"conv0.weight\") to values.\n",
+        "    :return: A dictionary graph of the params, keyed by top-level module name.\n",
+        "    \"\"\"\n",
+        "    output_dict = dict()\n",
+        "    for key in current_dict.keys():\n",
+        "        name = key.replace(\"layer_dict.\", \"\")\n",
+        "        top_level = name.split(\".\")[0]\n",
+        "        sub_level = \".\".join(name.split(\".\")[1:])\n",
+        "\n",
+        "        if top_level not in output_dict:\n",
+        "            if sub_level == \"\":\n",
+        "                output_dict[top_level] = current_dict[key]\n",
+        "            else:\n",
+        "                output_dict[top_level] = {sub_level: current_dict[key]}\n",
+        "        else:\n",
+        "            new_item = {key: value for key, value in output_dict[top_level].items()}\n",
+        "            new_item[sub_level] = current_dict[key]\n",
+        "            output_dict[top_level] = new_item\n",
+        "\n",
+        "    #print(current_dict.keys(), output_dict.keys())\n",
+        "    return output_dict\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "brbh1ISiqIwN",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class MetaConv2dLayer(nn.Module):\n",
+        "    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_bias, groups=1, dilation_rate=1):\n",
+        "        \"\"\"\n",
+        "        A MetaConv2D layer. Applies the same functionality of a standard Conv2D layer with the added functionality of\n",
+        "        being able to receive a parameter dictionary at the forward pass which allows the convolution to use external\n",
+        "        weights instead of the internal ones stored in the conv layer. Useful for inner loop optimization in the meta\n",
+        "        learning setting.\n",
+        "        :param in_channels: Number of input channels\n",
+        "        :param out_channels: Number of output channels\n",
+        "        :param kernel_size: Convolutional kernel size\n",
+        "        :param stride: Convolutional stride\n",
+        "        :param padding: Convolution padding\n",
+        "        :param use_bias: Boolean indicating whether to use a bias or not.\n",
+        "        :param groups: Number of convolution groups\n",
+        "        :param dilation_rate: Convolution dilation rate\n",
+        "        \"\"\"\n",
+        "        super(MetaConv2dLayer, self).__init__()\n",
+        "        num_filters = out_channels\n",
+        "        self.stride = int(stride)\n",
+        "        self.padding = int(padding)\n",
+        "        self.dilation_rate = int(dilation_rate)\n",
+        "        self.use_bias = use_bias\n",
+        "        self.groups = int(groups)\n",
+        "        self.weight = nn.Parameter(torch.empty(num_filters, in_channels, kernel_size, kernel_size))\n",
+        "        nn.init.xavier_uniform_(self.weight)\n",
+        "\n",
+        "        if self.use_bias:\n",
+        "            self.bias = nn.Parameter(torch.zeros(num_filters))\n",
+        "\n",
+        "    def forward(self, x, params=None):\n",
+        "        \"\"\"\n",
+        "        Applies a conv2D forward pass.
If params are not None will use the passed params as the conv weights and biases\n", + " :param x: Input image batch.\n", + " :param params: If none, then conv layer will use the stored self.weights and self.bias, if they are not none\n", + " then the conv layer will use the passed params as its parameters.\n", + " :return: The output of a convolutional function.\n", + " \"\"\"\n", + " if params is not None:\n", + " params = extract_top_level_dict(current_dict=params)\n", + " if self.use_bias:\n", + " (weight, bias) = params[\"weight\"], params[\"bias\"]\n", + " else:\n", + " (weight) = params[\"weight\"]\n", + " bias = None\n", + " else:\n", + " #print(\"No inner loop params\")\n", + " if self.use_bias:\n", + " weight, bias = self.weight, self.bias\n", + " else:\n", + " weight = self.weight\n", + " bias = None\n", + "\n", + " out = F.conv2d(input=x, weight=weight, bias=bias, stride=self.stride,\n", + " padding=self.padding, dilation=self.dilation_rate, groups=self.groups)\n", + " return out\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5_yy9Vluuu1E", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MetaBatchNormLayer(nn.Module):\n", + " def __init__(self, num_features, device, args, eps=1e-5, momentum=0.1, affine=True,\n", + " track_running_stats=True, meta_batch_norm=True, no_learnable_params=False,\n", + " use_per_step_bn_statistics=False):\n", + " \"\"\"\n", + " A MetaBatchNorm layer. Applies the same functionality of a standard BatchNorm layer with the added functionality of\n", + " being able to receive a parameter dictionary at the forward pass which allows the convolution to use external\n", + " weights instead of the internal ones stored in the conv layer. Useful for inner loop optimization in the meta\n", + " learning setting. 
Also has the additional functionality of being able to store per step running stats and per step beta and gamma.\n", + " :param num_features:\n", + " :param device:\n", + " :param args:\n", + " :param eps:\n", + " :param momentum:\n", + " :param affine:\n", + " :param track_running_stats:\n", + " :param meta_batch_norm:\n", + " :param no_learnable_params:\n", + " :param use_per_step_bn_statistics:\n", + " \"\"\"\n", + " super(MetaBatchNormLayer, self).__init__()\n", + " self.num_features = num_features\n", + " self.eps = eps\n", + "\n", + " self.affine = affine\n", + " self.track_running_stats = track_running_stats\n", + " self.meta_batch_norm = meta_batch_norm\n", + " self.num_features = num_features\n", + " self.device = device\n", + " self.use_per_step_bn_statistics = use_per_step_bn_statistics\n", + " self.args = args\n", + " self.learnable_gamma = self.args.learnable_bn_gamma\n", + " self.learnable_beta = self.args.learnable_bn_beta\n", + "\n", + " if use_per_step_bn_statistics:\n", + " self.running_mean = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features),\n", + " requires_grad=False)\n", + " self.running_var = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features),\n", + " requires_grad=False)\n", + " self.bias = nn.Parameter(torch.zeros(args.number_of_training_steps_per_iter, num_features),\n", + " requires_grad=self.learnable_beta)\n", + " self.weight = nn.Parameter(torch.ones(args.number_of_training_steps_per_iter, num_features),\n", + " requires_grad=self.learnable_gamma)\n", + " else:\n", + " self.running_mean = nn.Parameter(torch.zeros(num_features), requires_grad=False)\n", + " self.running_var = nn.Parameter(torch.zeros(num_features), requires_grad=False)\n", + " self.bias = nn.Parameter(torch.zeros(num_features),\n", + " requires_grad=self.learnable_beta)\n", + " self.weight = nn.Parameter(torch.ones(num_features),\n", + " requires_grad=self.learnable_gamma)\n", + "\n", + " if self.args.enable_inner_loop_optimizable_bn_params:\n", + " self.bias = nn.Parameter(torch.zeros(num_features),\n", + " requires_grad=self.learnable_beta)\n", + " self.weight = nn.Parameter(torch.ones(num_features),\n", + " requires_grad=self.learnable_gamma)\n", + "\n", + " self.backup_running_mean = torch.zeros(self.running_mean.shape)\n", + " self.backup_running_var = torch.ones(self.running_var.shape)\n", + "\n", + " self.momentum = momentum\n", + "\n", + " def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False):\n", + " \"\"\"\n", + " Forward propagates by applying a bach norm function. If params are none then internal params are used.\n", + " Otherwise passed params will be used to execute the function.\n", + " :param input: input data batch, size either can be any.\n", + " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n", + " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n", + " :param params: A dictionary containing 'weight' and 'bias'.\n", + " :param training: Whether this is currently the training or evaluation phase.\n", + " :param backup_running_statistics: Whether to backup the running statistics. 
This is used\n", + " at evaluation time, when after the pass is complete we want to throw away the collected validation stats.\n", + " :return: The result of the batch norm operation.\n", + " \"\"\"\n", + " if params is not None:\n", + " params = extract_top_level_dict(current_dict=params)\n", + " (weight, bias) = params[\"weight\"], params[\"bias\"]\n", + " #print(num_step, params['weight'])\n", + " else:\n", + " #print(num_step, \"no params\")\n", + " weight, bias = self.weight, self.bias\n", + "\n", + " if self.use_per_step_bn_statistics:\n", + " running_mean = self.running_mean[num_step]\n", + " running_var = self.running_var[num_step]\n", + " if params is None:\n", + " if not self.args.enable_inner_loop_optimizable_bn_params:\n", + " bias = self.bias[num_step]\n", + " weight = self.weight[num_step]\n", + " else:\n", + " running_mean = None\n", + " running_var = None\n", + "\n", + "\n", + " if backup_running_statistics and self.use_per_step_bn_statistics:\n", + " self.backup_running_mean.data = copy(self.running_mean.data)\n", + " self.backup_running_var.data = copy(self.running_var.data)\n", + "\n", + " momentum = self.momentum\n", + "\n", + " output = F.batch_norm(input, running_mean, running_var, weight, bias,\n", + " training=True, momentum=momentum, eps=self.eps)\n", + "\n", + " return output\n", + "\n", + " def restore_backup_stats(self):\n", + " \"\"\"\n", + " Resets batch statistics to their backup values which are collected after each forward pass.\n", + " \"\"\"\n", + " if self.use_per_step_bn_statistics:\n", + " self.running_mean = nn.Parameter(self.backup_running_mean.to(device=self.device), requires_grad=False)\n", + " self.running_var = nn.Parameter(self.backup_running_var.to(device=self.device), requires_grad=False)\n", + "\n", + " def extra_repr(self):\n", + " return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \\\n", + " 'track_running_stats={track_running_stats}'.format(**self.__dict__)\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "LsbTVhQ9vply", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class MetaLayerNormLayer(nn.Module):\n", + " def __init__(self, input_feature_shape, eps=1e-5, elementwise_affine=True):\n", + " \"\"\"\n", + " A MetaLayerNorm layer. A layer that applies the same functionality as a layer norm layer with the added\n", + " capability of being able to receive params at inference time to use instead of the internal ones. As well as\n", + " being able to use its own internal weights.\n", + " :param input_feature_shape: The input shape without the batch dimension, e.g. 
c, h, w\n", + " :param eps: Epsilon to use for protection against overflows\n", + " :param elementwise_affine: Whether to learn a multiplicative interaction parameter 'w' in addition to\n", + " the biases.\n", + " \"\"\"\n", + " super(MetaLayerNormLayer, self).__init__()\n", + " if isinstance(input_feature_shape, numbers.Integral):\n", + " input_feature_shape = (input_feature_shape,)\n", + " self.normalized_shape = torch.Size(input_feature_shape)\n", + " self.eps = eps\n", + " self.elementwise_affine = elementwise_affine\n", + " if self.elementwise_affine:\n", + " self.weight = nn.Parameter(torch.Tensor(*input_feature_shape), requires_grad=False)\n", + " self.bias = nn.Parameter(torch.Tensor(*input_feature_shape))\n", + " else:\n", + " self.register_parameter('weight', None)\n", + " self.register_parameter('bias', None)\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " \"\"\"\n", + " Reset parameters to their initialization values.\n", + " \"\"\"\n", + " if self.elementwise_affine:\n", + " self.weight.data.fill_(1)\n", + " self.bias.data.zero_()\n", + "\n", + " def forward(self, input, num_step, params=None, training=False, backup_running_statistics=False):\n", + " \"\"\"\n", + " Forward propagates by applying a layer norm function. If params are none then internal params are used.\n", + " Otherwise passed params will be used to execute the function.\n", + " :param input: input data batch, size either can be any.\n", + " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n", + " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n", + " :param params: A dictionary containing 'weight' and 'bias'.\n", + " :param training: Whether this is currently the training or evaluation phase.\n", + " :param backup_running_statistics: Whether to backup the running statistics. 
This is used\n",
+        "        at evaluation time, when we want to throw away the validation statistics collected during the pass.\n",
+        "        :return: The result of the layer norm operation.\n",
+        "        \"\"\"\n",
+        "        if params is not None:\n",
+        "            params = extract_top_level_dict(current_dict=params)\n",
+        "            bias = params[\"bias\"]\n",
+        "        else:\n",
+        "            bias = self.bias\n",
+        "\n",
+        "        return F.layer_norm(\n",
+        "            input, self.normalized_shape, self.weight, bias, self.eps)\n",
+        "\n",
+        "    def restore_backup_stats(self):\n",
+        "        pass\n",
+        "\n",
+        "    def extra_repr(self):\n",
+        "        return '{normalized_shape}, eps={eps}, ' \\\n",
+        "            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "-Up-rbPipql9",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class MetaConvNormLayerReLU(nn.Module):\n",
+        "    def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True,\n",
+        "                 meta_layer=True, no_bn_learnable_params=False, device=None):\n",
+        "        \"\"\"\n",
+        "        Initializes a Conv->Norm->ReLU layer which applies those operations in that order.\n",
+        "        :param args: A named tuple containing the system's hyperparameters.\n",
+        "        :param device: The device to run the layer on.\n",
+        "        :param normalization: Whether to normalize the output; the norm type ('batch_norm' or 'layer_norm') comes from args.norm_layer.\n",
+        "        :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm,\n",
+        "        meta-conv etc.\n",
+        "        :param input_shape: The image input shape in the form (b, c, h, w)\n",
+        "        :param num_filters: number of filters for convolutional layer\n",
+        "        :param kernel_size: the kernel size of the convolutional layer\n",
+        "        :param stride: the stride of the convolutional layer\n",
+        "        :param padding: the padding of the convolutional layer\n",
+        "        :param use_bias: whether the convolutional layer utilizes a bias\n",
+        "        \"\"\"\n",
+        "        super(MetaConvNormLayerReLU, self).__init__()\n",
+        "        self.normalization = normalization\n",
+        "        self.use_per_step_bn_statistics = args.per_step_bn_statistics\n",
+        "        self.input_shape = input_shape\n",
+        "        self.args = args\n",
+        "        self.num_filters = num_filters\n",
+        "        self.kernel_size = kernel_size\n",
+        "        self.stride = stride\n",
+        "        self.padding = padding\n",
+        "        self.use_bias = use_bias\n",
+        "        self.meta_layer = meta_layer\n",
+        "        self.no_bn_learnable_params = no_bn_learnable_params\n",
+        "        self.device = device\n",
+        "        self.layer_dict = nn.ModuleDict()\n",
+        "        self.build_block()\n",
+        "\n",
+        "    def build_block(self):\n",
+        "\n",
+        "        x = torch.zeros(self.input_shape)\n",
+        "\n",
+        "        out = x\n",
+        "\n",
+        "        self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters,\n",
+        "                                    kernel_size=self.kernel_size,\n",
+        "                                    stride=self.stride, padding=self.padding, use_bias=self.use_bias)\n",
+        "\n",
+        "        out = self.conv(out)\n",
+        "\n",
+        "        if self.normalization:\n",
+        "            if self.args.norm_layer == \"batch_norm\":\n",
+        "                self.norm_layer = MetaBatchNormLayer(out.shape[1], track_running_stats=True,\n",
+        "                                                     meta_batch_norm=self.meta_layer,\n",
+        "                                                     no_learnable_params=self.no_bn_learnable_params,\n",
+        "                                                     device=self.device,\n",
+        "                                                     use_per_step_bn_statistics=self.use_per_step_bn_statistics,\n",
+        "                                                     args=self.args)\n",
+        "            elif self.args.norm_layer == \"layer_norm\":\n",
+        "                self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:])\n",
+        "\n",
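+        "            # The norm flavour is a hyperparameter: MAML++-style batch norm keeps a separate set of\n",
+        "            # running statistics (and, optionally, weights/biases) for every inner-loop step, whereas\n",
+        "            # layer norm is stateless, which is why its restore_backup_stats above is a no-op.\n",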
+        "            out = self.norm_layer(out, num_step=0)\n",
+        "\n",
+        "        out = F.leaky_relu(out)\n",
+        "\n",
+        "        print(out.shape)\n",
+        "\n",
+        "    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n",
+        "        \"\"\"\n",
+        "        Forward propagates by applying the function. If params are none then internal params are used.\n",
+        "        Otherwise passed params will be used to execute the function.\n",
+        "        :param x: input data batch; can be of any size.\n",
+        "        :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n",
+        "        collecting per step batch statistics. It indexes the correct object to use for the current time-step\n",
+        "        :param params: A dictionary containing 'weight' and 'bias'.\n",
+        "        :param training: Whether this is currently the training or evaluation phase.\n",
+        "        :param backup_running_statistics: Whether to backup the running statistics. This is used\n",
+        "        at evaluation time, when we want to throw away the validation statistics collected during the pass.\n",
+        "        :return: The output of the conv->norm->activation block.\n",
+        "        \"\"\"\n",
+        "        batch_norm_params = None\n",
+        "        conv_params = None\n",
+        "        activation_function_pre_params = None\n",
+        "\n",
+        "        if params is not None:\n",
+        "            params = extract_top_level_dict(current_dict=params)\n",
+        "\n",
+        "            if self.normalization:\n",
+        "                if 'norm_layer' in params:\n",
+        "                    batch_norm_params = params['norm_layer']\n",
+        "\n",
+        "                if 'activation_function_pre' in params:\n",
+        "                    activation_function_pre_params = params['activation_function_pre']\n",
+        "\n",
+        "            conv_params = params['conv']\n",
+        "\n",
+        "        out = x\n",
+        "\n",
+        "        out = self.conv(out, params=conv_params)\n",
+        "\n",
+        "        if self.normalization:\n",
+        "            out = self.norm_layer.forward(out, num_step=num_step,\n",
+        "                                          params=batch_norm_params, training=training,\n",
+        "                                          backup_running_statistics=backup_running_statistics)\n",
+        "\n",
+        "        out = F.leaky_relu(out)\n",
+        "\n",
+        "        return out\n",
+        "\n",
+        "    def restore_backup_stats(self):\n",
+        "        \"\"\"\n",
+        "        Restore stored statistics from the backup, replacing the current ones.\n",
+        "        \"\"\"\n",
+        "        if self.normalization:\n",
+        "            self.norm_layer.restore_backup_stats()\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "qujCuA12uLby",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class MetaNormLayerConvReLU(nn.Module):\n",
+        "    def __init__(self, input_shape, num_filters, kernel_size, stride, padding, use_bias, args, normalization=True,\n",
+        "                 meta_layer=True, no_bn_learnable_params=False, device=None):\n",
+        "        \"\"\"\n",
+        "        Initializes a Norm->Conv->ReLU layer which applies those operations in that order.\n",
+        "        :param args: A named tuple containing the system's hyperparameters.\n",
+        "        :param device: The device to run the layer on.\n",
+        "        :param normalization: Whether to normalize the input; the norm type ('batch_norm' or 'layer_norm') comes from args.norm_layer.\n",
+        "        :param meta_layer: Whether this layer will require meta-layer capabilities such as meta-batch norm,\n",
+        "        meta-conv etc.\n",
+        "        :param input_shape: The image input shape in the form (b, c, h, w)\n",
+        "        :param num_filters: number of filters for convolutional layer\n",
+        "        :param kernel_size: the kernel size of the convolutional layer\n",
+        "        :param stride: the stride of the convolutional layer\n",
+        "        :param padding: the padding of the convolutional layer\n",
+        "        :param use_bias: whether the convolutional layer utilizes a bias\n",
+        "        \"\"\"\n",
+        "        
super(MetaNormLayerConvReLU, self).__init__()\n", + " self.normalization = normalization\n", + " self.use_per_step_bn_statistics = args.per_step_bn_statistics\n", + " self.input_shape = input_shape\n", + " self.args = args\n", + " self.num_filters = num_filters\n", + " self.kernel_size = kernel_size\n", + " self.stride = stride\n", + " self.padding = padding\n", + " self.use_bias = use_bias\n", + " self.meta_layer = meta_layer\n", + " self.no_bn_learnable_params = no_bn_learnable_params\n", + " self.device = device\n", + " self.layer_dict = nn.ModuleDict()\n", + " self.build_block()\n", + "\n", + " def build_block(self):\n", + "\n", + " x = torch.zeros(self.input_shape)\n", + "\n", + " out = x\n", + " if self.normalization:\n", + " if self.args.norm_layer == \"batch_norm\":\n", + " self.norm_layer = MetaBatchNormLayer(self.input_shape[1], track_running_stats=True,\n", + " meta_batch_norm=self.meta_layer,\n", + " no_learnable_params=self.no_bn_learnable_params,\n", + " device=self.device,\n", + " use_per_step_bn_statistics=self.use_per_step_bn_statistics,\n", + " args=self.args)\n", + " elif self.args.norm_layer == \"layer_norm\":\n", + " self.norm_layer = MetaLayerNormLayer(input_feature_shape=out.shape[1:])\n", + "\n", + " out = self.norm_layer.forward(out, num_step=0)\n", + " self.conv = MetaConv2dLayer(in_channels=out.shape[1], out_channels=self.num_filters,\n", + " kernel_size=self.kernel_size,\n", + " stride=self.stride, padding=self.padding, use_bias=self.use_bias)\n", + "\n", + "\n", + " self.layer_dict['activation_function_pre'] = nn.LeakyReLU()\n", + "\n", + "\n", + " out = self.layer_dict['activation_function_pre'].forward(self.conv.forward(out))\n", + " print(out.shape)\n", + "\n", + " def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n", + " \"\"\"\n", + " Forward propagates by applying the function. If params are none then internal params are used.\n", + " Otherwise passed params will be used to execute the function.\n", + " :param input: input data batch, size either can be any.\n", + " :param num_step: The current inner loop step being taken. This is used when we are learning per step params and\n", + " collecting per step batch statistics. It indexes the correct object to use for the current time-step\n", + " :param params: A dictionary containing 'weight' and 'bias'.\n", + " :param training: Whether this is currently the training or evaluation phase.\n", + " :param backup_running_statistics: Whether to backup the running statistics. 
This is used\n",
+        "        at evaluation time, when we want to throw away the validation statistics collected during the pass.\n",
+        "        :return: The output of the norm->conv->activation block.\n",
+        "        \"\"\"\n",
+        "        batch_norm_params = None\n",
+        "\n",
+        "        if params is not None:\n",
+        "            params = extract_top_level_dict(current_dict=params)\n",
+        "\n",
+        "            if self.normalization:\n",
+        "                if 'norm_layer' in params:\n",
+        "                    batch_norm_params = params['norm_layer']\n",
+        "\n",
+        "            conv_params = params['conv']\n",
+        "        else:\n",
+        "            conv_params = None\n",
+        "\n",
+        "        out = x\n",
+        "\n",
+        "        if self.normalization:\n",
+        "            out = self.norm_layer.forward(out, num_step=num_step,\n",
+        "                                          params=batch_norm_params, training=training,\n",
+        "                                          backup_running_statistics=backup_running_statistics)\n",
+        "\n",
+        "        out = self.conv.forward(out, params=conv_params)\n",
+        "        out = self.layer_dict['activation_function_pre'].forward(out)\n",
+        "\n",
+        "        return out\n",
+        "\n",
+        "    def restore_backup_stats(self):\n",
+        "        \"\"\"\n",
+        "        Restore stored statistics from the backup, replacing the current ones.\n",
+        "        \"\"\"\n",
+        "        if self.normalization:\n",
+        "            self.norm_layer.restore_backup_stats()\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "LLwmEEAgv5EO",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class MetaLinearLayer(nn.Module):\n",
+        "    def __init__(self, input_shape, num_filters, use_bias):\n",
+        "        \"\"\"\n",
+        "        A MetaLinear layer. Applies the same functionality as a standard linear layer with the added functionality of\n",
+        "        being able to receive a parameter dictionary at the forward pass, which allows the layer to use external\n",
+        "        weights instead of the internal ones stored in it. Useful for inner loop optimization in the meta\n",
+        "        learning setting.\n",
+        "        :param input_shape: The shape of the input data, in the form (b, f)\n",
+        "        :param num_filters: Number of output units\n",
+        "        :param use_bias: Whether to use biases or not.\n",
+        "        \"\"\"\n",
+        "        super(MetaLinearLayer, self).__init__()\n",
+        "        b, c = input_shape\n",
+        "\n",
+        "        self.use_bias = use_bias\n",
+        "        self.weights = nn.Parameter(torch.ones(num_filters, c))\n",
+        "        nn.init.xavier_uniform_(self.weights)\n",
+        "        if self.use_bias:\n",
+        "            self.bias = nn.Parameter(torch.zeros(num_filters))\n",
+        "\n",
+        "    def forward(self, x, params=None):\n",
+        "        \"\"\"\n",
+        "        Forward propagates by applying a linear function (Wx + b). If params are none then internal params are used.\n",
+        "        Otherwise passed params will be used to execute the function.\n",
+        "        :param x: Input data batch, in the form (b, f)\n",
+        "        :param params: A dictionary containing 'weights' and 'bias'. 
If params are none then internal params are used.\n",
+        "        Otherwise the external ones are used.\n",
+        "        :return: The result of the linear function.\n",
+        "        \"\"\"\n",
+        "        if params is not None:\n",
+        "            params = extract_top_level_dict(current_dict=params)\n",
+        "            if self.use_bias:\n",
+        "                (weight, bias) = params[\"weights\"], params[\"bias\"]\n",
+        "            else:\n",
+        "                weight = params[\"weights\"]\n",
+        "                bias = None\n",
+        "        else:\n",
+        "            # no external params were passed, so fall back to the layer's own weights;\n",
+        "            # this fallback must live in the else branch, otherwise it would silently\n",
+        "            # overwrite the external (fast) weights on every call\n",
+        "            if self.use_bias:\n",
+        "                weight, bias = self.weights, self.bias\n",
+        "            else:\n",
+        "                weight = self.weights\n",
+        "                bias = None\n",
+        "\n",
+        "        out = F.linear(input=x, weight=weight, bias=bias)\n",
+        "        return out"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
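+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "MetaLinearDemo0",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Illustrative sanity check (not part of the original repository): every Meta* layer can run\n",
+        "# either on its internal weights or on an externally supplied parameter dict, which is how the\n",
+        "# inner loop forward-propagates through fast weights. The shapes, the cell itself and the 0.01\n",
+        "# offset are arbitrary demo choices, and we assume extract_top_level_dict passes flat keys through.\n",
+        "linear_demo = MetaLinearLayer(input_shape=(4, 16), num_filters=8, use_bias=True)\n",
+        "x_demo = torch.randn(4, 16)\n",
+        "out_internal = linear_demo(x_demo)  # uses linear_demo.weights / linear_demo.bias\n",
+        "fast_weights = {'weights': linear_demo.weights - 0.01 * torch.ones_like(linear_demo.weights),\n",
+        "                'bias': linear_demo.bias}\n",
+        "out_external = linear_demo(x_demo, params=fast_weights)  # same layer, external weights\n",
+        "print(out_internal.shape, out_external.shape)  # both torch.Size([4, 8])"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },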
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "FoWzgA2qtpT9",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class VGGReLUNormNetwork(nn.Module):\n",
+        "    def __init__(self, im_shape, num_output_classes, args, device, meta_classifier=True):\n",
+        "        \"\"\"\n",
+        "        Builds a multilayer convolutional network. It also provides functionality for passing external parameters\n",
+        "        to be used at inference time. Enables inner loop optimization readily.\n",
+        "        :param im_shape: The input image batch shape.\n",
+        "        :param num_output_classes: The number of output classes of the network.\n",
+        "        :param args: A named tuple containing the system's hyperparameters.\n",
+        "        :param device: The device to run this on.\n",
+        "        :param meta_classifier: A flag indicating whether the system's meta-learning (inner-loop)\n",
+        "        functionalities should be enabled.\n",
+        "        \"\"\"\n",
+        "        super(VGGReLUNormNetwork, self).__init__()\n",
+        "        b, c, self.h, self.w = im_shape\n",
+        "        self.device = device\n",
+        "        self.total_layers = 0\n",
+        "        self.args = args\n",
+        "        self.upscale_shapes = []\n",
+        "        self.cnn_filters = args.cnn_num_filters\n",
+        "        self.input_shape = list(im_shape)\n",
+        "        self.num_stages = args.num_stages\n",
+        "        self.num_output_classes = num_output_classes\n",
+        "\n",
+        "        if args.max_pooling:\n",
+        "            print(\"Using max pooling\")\n",
+        "            self.conv_stride = 1\n",
+        "        else:\n",
+        "            print(\"Using strided convolutions\")\n",
+        "            self.conv_stride = 2\n",
+        "        self.meta_classifier = meta_classifier\n",
+        "\n",
+        "        self.build_network()\n",
+        "        print(\"meta network params\")\n",
+        "        for name, param in self.named_parameters():\n",
+        "            print(name, param.shape)\n",
+        "\n",
+        "    def build_network(self):\n",
+        "        \"\"\"\n",
+        "        Builds the network before inference is required by creating some dummy inputs with the same shape as the\n",
+        "        self.input_shape tuple. Then passes that through the network, dynamically computing input shapes and\n",
+        "        setting output shapes for each layer.\n",
+        "        \"\"\"\n",
+        "        x = torch.zeros(self.input_shape)\n",
+        "        out = x\n",
+        "        self.layer_dict = nn.ModuleDict()\n",
+        "        self.upscale_shapes.append(x.shape)\n",
+        "\n",
+        "        for i in range(self.num_stages):\n",
+        "            self.layer_dict['conv{}'.format(i)] = MetaConvNormLayerReLU(input_shape=out.shape,\n",
+        "                                                                        num_filters=self.cnn_filters,\n",
+        "                                                                        kernel_size=3, stride=self.conv_stride,\n",
+        "                                                                        padding=self.args.conv_padding,\n",
+        "                                                                        use_bias=True, args=self.args,\n",
+        "                                                                        normalization=True,\n",
+        "                                                                        meta_layer=self.meta_classifier,\n",
+        "                                                                        no_bn_learnable_params=False,\n",
+        "                                                                        device=self.device)\n",
+        "            out = self.layer_dict['conv{}'.format(i)](out, training=True, num_step=0)\n",
+        "\n",
+        "            if self.args.max_pooling:\n",
+        "                out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0)\n",
+        "\n",
+        "        if not self.args.max_pooling:\n",
+        "            out = F.avg_pool2d(out, out.shape[2])\n",
+        "\n",
+        "        self.encoder_features_shape = list(out.shape)\n",
+        "        out = out.view(out.shape[0], -1)\n",
+        "\n",
+        "        self.layer_dict['linear'] = MetaLinearLayer(input_shape=(out.shape[0], np.prod(out.shape[1:])),\n",
+        "                                                    num_filters=self.num_output_classes, use_bias=True)\n",
+        "\n",
+        "        out = self.layer_dict['linear'](out)\n",
+        "        print(\"VGGNetwork build\", out.shape)\n",
+        "\n",
+        "    def forward(self, x, num_step, params=None, training=False, backup_running_statistics=False):\n",
+        "        \"\"\"\n",
+        "        Forward propagates through the network. If any params are passed then they are used instead of stored params.\n",
+        "        :param x: Input image batch.\n",
+        "        :param num_step: The current inner loop step number\n",
+        "        :param params: If params are None then internal parameters are used. If params are a dictionary with keys the\n",
+        "        same as the layer names then they will be used instead.\n",
+        "        :param training: Whether this is training (True) or eval time.\n",
+        "        :param backup_running_statistics: Whether to backup the running statistics in their backup store, 
which is\n",
+        "        then used to reset the stats back to a previous state (usually after an eval loop, when we want to throw away stored statistics).\n",
+        "        :return: Logits of shape b, num_output_classes.\n",
+        "        \"\"\"\n",
+        "        param_dict = dict()\n",
+        "\n",
+        "        if params is not None:\n",
+        "            param_dict = extract_top_level_dict(current_dict=params)\n",
+        "\n",
+        "        for name, param in self.layer_dict.named_parameters():\n",
+        "            path_bits = name.split(\".\")\n",
+        "            layer_name = path_bits[0]\n",
+        "            if layer_name not in param_dict:\n",
+        "                param_dict[layer_name] = None\n",
+        "\n",
+        "        out = x\n",
+        "\n",
+        "        for i in range(self.num_stages):\n",
+        "            out = self.layer_dict['conv{}'.format(i)](out, params=param_dict['conv{}'.format(i)], training=training,\n",
+        "                                                      backup_running_statistics=backup_running_statistics,\n",
+        "                                                      num_step=num_step)\n",
+        "            if self.args.max_pooling:\n",
+        "                out = F.max_pool2d(input=out, kernel_size=(2, 2), stride=2, padding=0)\n",
+        "\n",
+        "        if not self.args.max_pooling:\n",
+        "            out = F.avg_pool2d(out, out.shape[2])\n",
+        "\n",
+        "        out = out.view(out.size(0), -1)\n",
+        "        out = self.layer_dict['linear'](out, param_dict['linear'])\n",
+        "\n",
+        "        return out\n",
+        "\n",
+        "    def zero_grad(self, params=None):\n",
+        "        # zero the gradients of either the internal parameters or an external param dict\n",
+        "        if params is None:\n",
+        "            for param in self.parameters():\n",
+        "                if param.requires_grad and param.grad is not None:\n",
+        "                    param.grad.zero_()\n",
+        "        else:\n",
+        "            for name, param in params.items():\n",
+        "                if param.requires_grad and param.grad is not None:\n",
+        "                    param.grad.zero_()\n",
+        "                    params[name].grad = None\n",
+        "\n",
+        "    def restore_backup_stats(self):\n",
+        "        \"\"\"\n",
+        "        Reset stored batch statistics from the stored backup.\n",
+        "        \"\"\"\n",
+        "        for i in range(self.num_stages):\n",
+        "            self.layer_dict['conv{}'.format(i)].restore_backup_stats()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "-P0Tewg7iNRF",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class GradientDescentLearningRule(nn.Module):\n",
+        "    \"\"\"Simple (stochastic) gradient descent learning rule.\n",
+        "    For a scalar error function `E(p[0], p_[1] ... )` of some set of\n",
+        "    potentially multidimensional parameters this attempts to find a local\n",
+        "    minimum of the loss function by applying updates to each parameter of the\n",
+        "    form\n",
+        "    p[i] := p[i] - learning_rate * dE/dp[i]\n",
+        "    With `learning_rate` a positive scaling parameter.\n",
+        "    The error function used in successive applications of these updates may be\n",
+        "    a stochastic estimator of the true error function (e.g. when the error with\n",
+        "    respect to only a subset of data-points is calculated) in which case this\n",
+        "    will correspond to a stochastic gradient descent learning rule.\n",
+        "    \"\"\"\n",
+        "\n",
+        "    def __init__(self, device, learning_rate=1e-3):\n",
+        "        \"\"\"Creates a new learning rule object.\n",
+        "        Args:\n",
+        "            learning_rate: A positive scalar to scale gradient updates to the\n",
+        "                parameters by. 
This needs to be carefully set - if too large\n",
+        "                the learning dynamic will be unstable and may diverge, while\n",
+        "                if set too small learning will proceed very slowly.\n",
+        "        \"\"\"\n",
+        "        super(GradientDescentLearningRule, self).__init__()\n",
+        "        assert learning_rate > 0., 'learning_rate should be positive.'\n",
+        "        self.learning_rate = torch.ones(1) * learning_rate\n",
+        "        # .to() is not in-place for tensors, so the result has to be reassigned\n",
+        "        self.learning_rate = self.learning_rate.to(device)\n",
+        "\n",
+        "    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9):\n",
+        "        \"\"\"Applies a single gradient descent update to all parameters.\n",
+        "        The updates are computed out of place and returned as a new dictionary,\n",
+        "        which keeps the computational graph intact for meta-gradient computation.\n",
+        "        Args:\n",
+        "            names_weights_dict: A dictionary mapping parameter names to parameters.\n",
+        "            names_grads_wrt_params_dict: A dictionary mapping parameter names to\n",
+        "                gradients of the scalar loss function with respect to those parameters.\n",
+        "            num_step: The current inner-loop step index (unused by this rule).\n",
+        "        Returns:\n",
+        "            A dictionary with the updated parameters.\n",
+        "        \"\"\"\n",
+        "        updated_names_weights_dict = dict()\n",
+        "        for key in names_weights_dict.keys():\n",
+        "            updated_names_weights_dict[key] = names_weights_dict[key] - \\\n",
+        "                self.learning_rate * names_grads_wrt_params_dict[key]\n",
+        "\n",
+        "        return updated_names_weights_dict\n",
+        "\n"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
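+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "LearningRuleDemo0",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "# Illustrative sanity check (not part of the original repository): the learning rule maps a\n",
+        "# weight dict plus a gradient dict to a new dict of fast weights without mutating the inputs.\n",
+        "# All values below (shapes, 0.1, 0.5) are arbitrary demo numbers.\n",
+        "demo_rule = GradientDescentLearningRule(device=torch.device('cpu'), learning_rate=0.1)\n",
+        "demo_w = {'layer.weight': torch.ones(2, 2)}\n",
+        "demo_g = {'layer.weight': torch.full((2, 2), 0.5)}\n",
+        "demo_updated = demo_rule.update_params(names_weights_dict=demo_w,\n",
+        "                                       names_grads_wrt_params_dict=demo_g, num_step=0)\n",
+        "print(demo_updated['layer.weight'])  # 1 - 0.1 * 0.5 = 0.95 everywhere"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },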
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "t4ogOl5vqpqM",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class LSLRGradientDescentLearningRule(nn.Module):\n",
+        "    \"\"\"Gradient descent learning rule with per-layer, per-step learnable learning rates (LSLR).\n",
+        "    For a scalar error function `E(p[0], p_[1] ... )` of some set of\n",
+        "    potentially multidimensional parameters this attempts to find a local\n",
+        "    minimum of the loss function by applying updates to each parameter of the\n",
+        "    form\n",
+        "    p[i] := p[i] - learning_rate[i, num_step] * dE/dp[i]\n",
+        "    where each parameter tensor gets its own learning rate for every inner-loop step,\n",
+        "    and those learning rates are themselves meta-learned when\n",
+        "    use_learnable_learning_rates is True.\n",
+        "    \"\"\"\n",
+        "\n",
+        "    def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, init_learning_rate=1e-3):\n",
+        "        \"\"\"Creates a new learning rule object.\n",
+        "        Args:\n",
+        "            init_learning_rate: A positive scalar to scale gradient updates to the\n",
+        "                parameters by. This needs to be carefully set - if too large\n",
+        "                the learning dynamic will be unstable and may diverge, while\n",
+        "                if set too small learning will proceed very slowly.\n",
+        "        \"\"\"\n",
+        "        super(LSLRGradientDescentLearningRule, self).__init__()\n",
+        "        print(init_learning_rate)\n",
+        "        assert init_learning_rate > 0., 'learning_rate should be positive.'\n",
+        "\n",
+        "        self.init_learning_rate = torch.ones(1) * init_learning_rate\n",
+        "        # .to() is not in-place for tensors, so the result has to be reassigned\n",
+        "        self.init_learning_rate = self.init_learning_rate.to(device)\n",
+        "        self.total_num_inner_loop_steps = total_num_inner_loop_steps\n",
+        "        self.use_learnable_learning_rates = use_learnable_learning_rates\n",
+        "\n",
+        "    def initialise(self, names_weights_dict):\n",
+        "        # one learnable learning-rate vector (one entry per inner-loop step) per parameter tensor\n",
+        "        self.names_learning_rates_dict = nn.ParameterDict()\n",
+        "        for idx, (key, param) in enumerate(names_weights_dict.items()):\n",
+        "            self.names_learning_rates_dict[key.replace(\".\", \"-\")] = nn.Parameter(\n",
+        "                data=torch.ones(self.total_num_inner_loop_steps + 1) * self.init_learning_rate,\n",
+        "                requires_grad=self.use_learnable_learning_rates)\n",
+        "\n",
+        "    def reset(self):\n",
+        "        pass\n",
+        "\n",
+        "    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.1):\n",
+        "        \"\"\"Applies a single gradient descent update to all parameters, using the\n",
+        "        learned learning rate of the given inner-loop step for each parameter.\n",
+        "        Args:\n",
+        "            names_weights_dict: A dictionary mapping parameter names to parameters.\n",
+        "            names_grads_wrt_params_dict: A dictionary mapping parameter names to\n",
+        "                gradients of the scalar loss function with respect to those parameters.\n",
+        "            num_step: The current inner-loop step index, used to select the learning rate.\n",
+        "        Returns:\n",
+        "            A dictionary with the updated parameters.\n",
+        "        \"\"\"\n",
+        "        updated_names_weights_dict = dict()\n",
+        "        for key in names_grads_wrt_params_dict.keys():\n",
+        "            updated_names_weights_dict[key] = names_weights_dict[key] - \\\n",
+        "                self.names_learning_rates_dict[key.replace(\".\", \"-\")][num_step] \\\n",
+        "                * names_grads_wrt_params_dict[key]\n",
+        "\n",
+        "        return updated_names_weights_dict"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "nJtmTvZ8nqYN",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "def set_torch_seed(seed):\n",
+        "    \"\"\"\n",
+        "    Sets the pytorch seeds for current experiment run\n",
+        "    :param seed: The seed (int)\n",
+        "    :return: A random number generator to use\n",
+        "    \"\"\"\n",
+        "    rng = np.random.RandomState(seed=seed)\n",
+        "    torch_seed = rng.randint(0, 999999)\n",
+        "    torch.manual_seed(seed=torch_seed)\n",
+        "\n",
+        "    return rng"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "iS8_p22sp3hW",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "class MAMLFewShotClassifier(nn.Module):\n",
+        "    def __init__(self, im_shape, device, args):\n",
+        "        \"\"\"\n",
+        "        Initializes a MAML few shot learning system\n",
+        "        :param im_shape: The images input size, in batch, c, h, w shape\n",
+        "        :param device: The device to run the model on.\n",
+        "        :param args: A namedtuple of arguments specifying various hyperparameters.\n",
+        "        \"\"\"\n",
+        "        super(MAMLFewShotClassifier, self).__init__()\n",
+        "        self.args = args\n",
+        "        self.device = device\n",
+        "        self.batch_size = args.batch_size\n",
+        "        self.use_cuda = args.use_cuda\n",
+        "        self.im_shape = im_shape\n",
+        "        self.current_epoch = 0\n",
+        "\n",
+        "        self.rng = set_torch_seed(seed=args.seed)\n",
+        "        self.classifier = 
VGGReLUNormNetwork(im_shape=self.im_shape, num_output_classes=self.args.\n", + " num_classes_per_set,\n", + " args=args, device=device, meta_classifier=True).to(device=self.device)\n", + " self.task_learning_rate = args.task_learning_rate\n", + "\n", + " self.inner_loop_optimizer = LSLRGradientDescentLearningRule(device=device,\n", + " init_learning_rate=self.task_learning_rate,\n", + " total_num_inner_loop_steps=self.args.number_of_training_steps_per_iter,\n", + " use_learnable_learning_rates=self.args.learnable_per_layer_per_step_inner_loop_learning_rate)\n", + " self.inner_loop_optimizer.initialise(\n", + " names_weights_dict=self.get_inner_loop_parameter_dict(params=self.classifier.named_parameters()))\n", + "\n", + " print(\"Inner Loop parameters\")\n", + " for key, value in self.inner_loop_optimizer.named_parameters():\n", + " print(key, value.shape)\n", + "\n", + " self.use_cuda = args.use_cuda\n", + " self.device = device\n", + " self.args = args\n", + " self.to(device)\n", + " print(\"Outer Loop parameters\")\n", + " for name, param in self.named_parameters():\n", + " if param.requires_grad:\n", + " print(name, param.shape, param.device, param.requires_grad)\n", + "\n", + "\n", + " self.optimizer = optim.Adam(self.trainable_parameters(), lr=args.meta_learning_rate, amsgrad=False)\n", + " self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.args.total_epochs,\n", + " eta_min=self.args.min_learning_rate)\n", + "\n", + " def get_per_step_loss_importance_vector(self):\n", + " \"\"\"\n", + " Generates a tensor of dimensionality (num_inner_loop_steps) indicating the importance of each step's target\n", + " loss towards the optimization loss.\n", + " :return: A tensor to be used to compute the weighted average of the loss, useful for\n", + " the MSL (Multi Step Loss) mechanism.\n", + " \"\"\"\n", + " loss_weights = np.ones(shape=(self.args.number_of_training_steps_per_iter)) * (\n", + " 1.0 / self.args.number_of_training_steps_per_iter)\n", + " decay_rate = 1.0 / self.args.number_of_training_steps_per_iter / self.args.multi_step_loss_num_epochs\n", + " min_value_for_non_final_losses = 0.03 / self.args.number_of_training_steps_per_iter\n", + " for i in range(len(loss_weights) - 1):\n", + " curr_value = np.maximum(loss_weights[i] - (self.current_epoch * decay_rate), min_value_for_non_final_losses)\n", + " loss_weights[i] = curr_value\n", + "\n", + " curr_value = np.minimum(\n", + " loss_weights[-1] + (self.current_epoch * (self.args.number_of_training_steps_per_iter - 1) * decay_rate),\n", + " 1.0 - ((self.args.number_of_training_steps_per_iter - 1) * min_value_for_non_final_losses))\n", + " loss_weights[-1] = curr_value\n", + " loss_weights = torch.Tensor(loss_weights).to(device=self.device)\n", + " return loss_weights\n", + "\n", + " def get_inner_loop_parameter_dict(self, params):\n", + " \"\"\"\n", + " Returns a dictionary with the parameters to use for inner loop updates.\n", + " :param params: A dictionary of the network's parameters.\n", + " :return: A dictionary of the parameters to use for the inner loop optimization process.\n", + " \"\"\"\n", + " param_dict = dict()\n", + " for name, param in params:\n", + " if param.requires_grad:\n", + " if self.args.enable_inner_loop_optimizable_bn_params:\n", + " param_dict[name] = param.to(device=self.device)\n", + " else:\n", + " if \"norm_layer\" not in name:\n", + " param_dict[name] = param.to(device=self.device)\n", + "\n", + " return param_dict\n", + "\n", + " def apply_inner_loop_update(self, 
loss, names_weights_copy, use_second_order, current_step_idx):\n",
+        "        \"\"\"\n",
+        "        Applies an inner loop update given current step's loss, the weights to update, a flag indicating whether to use\n",
+        "        second order derivatives and the current step's index.\n",
+        "        :param loss: Current step's loss with respect to the support set.\n",
+        "        :param names_weights_copy: A dictionary with names to parameters to update.\n",
+        "        :param use_second_order: A boolean flag of whether to use second order derivatives.\n",
+        "        :param current_step_idx: Current step's index.\n",
+        "        :return: A dictionary with the updated weights (name, param)\n",
+        "        \"\"\"\n",
+        "        self.classifier.zero_grad(names_weights_copy)\n",
+        "\n",
+        "        grads = torch.autograd.grad(loss, names_weights_copy.values(),\n",
+        "                                    create_graph=use_second_order)\n",
+        "        names_grads_wrt_params = dict(zip(names_weights_copy.keys(), grads))\n",
+        "\n",
+        "        names_weights_copy = self.inner_loop_optimizer.update_params(names_weights_dict=names_weights_copy,\n",
+        "                                                                      names_grads_wrt_params_dict=names_grads_wrt_params,\n",
+        "                                                                      num_step=current_step_idx)\n",
+        "\n",
+        "        return names_weights_copy\n",
+        "\n",
+        "    def get_across_task_loss_metrics(self, total_losses, total_accuracies):\n",
+        "        losses = dict()\n",
+        "\n",
+        "        losses['loss'] = torch.mean(torch.stack(total_losses))\n",
+        "        losses['accuracy'] = np.mean(total_accuracies)\n",
+        "\n",
+        "        return losses\n",
+        "\n",
+        "    def forward(self, data_batch, epoch, use_second_order, use_multi_step_loss_optimization, num_steps, training_phase):\n",
+        "        \"\"\"\n",
+        "        Runs a forward outer loop pass on the batch of tasks using the MAML/++ framework.\n",
+        "        :param data_batch: A data batch containing the support and target sets.\n",
+        "        :param epoch: Current epoch's index\n",
+        "        :param use_second_order: A boolean saying whether to use second order derivatives.\n",
+        "        :param use_multi_step_loss_optimization: Whether to optimize the outer loop with the multi step loss,\n",
+        "        i.e. a weighted sum of every step's target loss, which improves the stability of the system (True),\n",
+        "        or with just the last step's target loss (False).\n",
+        "        :param num_steps: Number of inner loop steps.\n",
+        "        :param training_phase: Whether this is a training phase (True) or an evaluation phase (False)\n",
+        "        :return: A dictionary with the collected losses of the current outer forward propagation.\n",
+        "        \"\"\"\n",
+        "        x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n",
+        "\n",
+        "        [b, ncs, spc] = y_support_set.shape\n",
+        "\n",
+        "        self.num_classes_per_set = ncs\n",
+        "\n",
+        "        total_losses = []\n",
+        "        total_accuracies = []\n",
+        "        per_task_target_preds = [[] for i in range(len(x_target_set))]\n",
+        "        self.classifier.zero_grad()\n",
+        "        for task_id, (x_support_set_task, y_support_set_task, x_target_set_task, y_target_set_task) in \\\n",
+        "                enumerate(zip(x_support_set,\n",
+        "                              y_support_set,\n",
+        "                              x_target_set,\n",
+        "                              y_target_set)):\n",
+        "            task_losses = []\n",
+        "            task_accuracies = []\n",
+        "            per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector()\n",
+        "            names_weights_copy = self.get_inner_loop_parameter_dict(self.classifier.named_parameters())\n",
+        "\n",
+        "            n, s, c, h, w = x_target_set_task.shape\n",
+        "\n",
+        "            x_support_set_task = x_support_set_task.view(-1, c, h, w)\n",
+        "            y_support_set_task = y_support_set_task.view(-1)\n",
+        "            x_target_set_task = x_target_set_task.view(-1, c, h, w)\n",
+        "            y_target_set_task = y_target_set_task.view(-1)\n",
+        "\n",
+        "            for num_step in range(num_steps):\n",
+        "\n",
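+        "                # One inner-loop step: (1) evaluate the support-set loss under the current fast\n",
+        "                # weights, (2) take one LSLR gradient step on those weights, and (3) if the\n",
+        "                # multi-step loss is active for this epoch, add this step's weighted target-set\n",
+        "                # loss to the task loss; otherwise only the final step's target loss is used.\n",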
+        "                support_loss, support_preds = self.net_forward(x=x_support_set_task,\n",
+        "                                                                y=y_support_set_task,\n",
+        "                                                                weights=names_weights_copy,\n",
+        "                                                                backup_running_statistics=(num_step == 0),\n",
+        "                                                                training=True, num_step=num_step)\n",
+        "\n",
+        "                names_weights_copy = self.apply_inner_loop_update(loss=support_loss,\n",
+        "                                                                  names_weights_copy=names_weights_copy,\n",
+        "                                                                  use_second_order=use_second_order,\n",
+        "                                                                  current_step_idx=num_step)\n",
+        "\n",
+        "                if use_multi_step_loss_optimization and training_phase and epoch < self.args.multi_step_loss_num_epochs:\n",
+        "                    target_loss, target_preds = self.net_forward(x=x_target_set_task,\n",
+        "                                                                 y=y_target_set_task, weights=names_weights_copy,\n",
+        "                                                                 backup_running_statistics=False, training=True,\n",
+        "                                                                 num_step=num_step)\n",
+        "\n",
+        "                    task_losses.append(per_step_loss_importance_vectors[num_step] * target_loss)\n",
+        "                else:\n",
+        "                    if num_step == (self.args.number_of_training_steps_per_iter - 1):\n",
+        "                        target_loss, target_preds = self.net_forward(x=x_target_set_task,\n",
+        "                                                                     y=y_target_set_task, weights=names_weights_copy,\n",
+        "                                                                     backup_running_statistics=False, training=True,\n",
+        "                                                                     num_step=num_step)\n",
+        "                        task_losses.append(target_loss)\n",
+        "\n",
+        "            per_task_target_preds[task_id] = target_preds.detach().cpu().numpy()\n",
+        "            _, predicted = torch.max(target_preds.data, 1)\n",
+        "\n",
+        "            accuracy = predicted.float().eq(y_target_set_task.data.float()).cpu().float()\n",
+        "            task_losses = torch.sum(torch.stack(task_losses))\n",
+        "            total_losses.append(task_losses)\n",
+        "            total_accuracies.extend(accuracy)\n",
+        "\n",
+        "            if not training_phase:\n",
+        "                self.classifier.restore_backup_stats()\n",
+        "\n",
+        "        losses = self.get_across_task_loss_metrics(total_losses=total_losses,\n",
+        "                                                   total_accuracies=total_accuracies)\n",
+        "\n",
+        "        for idx, item in enumerate(per_step_loss_importance_vectors):\n",
+        "            losses['loss_importance_vector_{}'.format(idx)] = item.detach().cpu().numpy()\n",
+        "\n",
+        "        return losses, per_task_target_preds\n",
+        "\n",
+        "    def net_forward(self, x, y, weights, backup_running_statistics, training, num_step):\n",
+        "        \"\"\"\n",
+        "        A base model forward pass on some data points x, using the parameters in the weights dictionary. 
Also requires\n",
+        "        boolean flags indicating whether to reset the running statistics at the end of the run (if at evaluation phase).\n",
+        "        A flag indicating whether this is the training session and an int indicating the current step's number in the\n",
+        "        inner loop.\n",
+        "        :param x: A data batch of shape b, c, h, w\n",
+        "        :param y: A data targets batch of shape b, n_classes\n",
+        "        :param weights: A dictionary containing the weights to pass to the network.\n",
+        "        :param backup_running_statistics: A flag indicating whether to reset the batch norm running statistics to their\n",
+        "        previous values after the run (only for evaluation)\n",
+        "        :param training: A flag indicating whether the current process phase is a training or evaluation.\n",
+        "        :param num_step: An integer indicating the number of the step in the inner loop.\n",
+        "        :return: the crossentropy losses with respect to the given y, the predictions of the base model.\n",
+        "        \"\"\"\n",
+        "        preds = self.classifier.forward(x=x, params=weights,\n",
+        "                                        training=training,\n",
+        "                                        backup_running_statistics=backup_running_statistics, num_step=num_step)\n",
+        "\n",
+        "        loss = F.cross_entropy(input=preds, target=y)\n",
+        "\n",
+        "        return loss, preds\n",
+        "\n",
+        "    def trainable_parameters(self):\n",
+        "        \"\"\"\n",
+        "        Returns an iterator over the trainable parameters of the model.\n",
+        "        \"\"\"\n",
+        "        for param in self.parameters():\n",
+        "            if param.requires_grad:\n",
+        "                yield param\n",
+        "\n",
+        "    def train_forward_prop(self, data_batch, epoch):\n",
+        "        \"\"\"\n",
+        "        Runs an outer loop forward prop using the meta-model and base-model.\n",
+        "        :param data_batch: A data batch containing the support set and the target set input, output pairs.\n",
+        "        :param epoch: The index of the current epoch.\n",
+        "        :return: A dictionary of losses for the current step.\n",
+        "        \"\"\"\n",
+        "        losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch,\n",
+        "                                                      use_second_order=self.args.second_order and\n",
+        "                                                      epoch > self.args.first_order_to_second_order_epoch,\n",
+        "                                                      use_multi_step_loss_optimization=self.args.use_multi_step_loss_optimization,\n",
+        "                                                      num_steps=self.args.number_of_training_steps_per_iter,\n",
+        "                                                      training_phase=True)\n",
+        "        return losses, per_task_target_preds\n",
+        "\n",
+        "    def evaluation_forward_prop(self, data_batch, epoch):\n",
+        "        \"\"\"\n",
+        "        Runs an outer loop evaluation forward prop using the meta-model and base-model.\n",
+        "        :param data_batch: A data batch containing the support set and the target set input, output pairs.\n",
+        "        :param epoch: The index of the current epoch.\n",
+        "        :return: A dictionary of losses for the current step.\n",
+        "        \"\"\"\n",
+        "        losses, per_task_target_preds = self.forward(data_batch=data_batch, epoch=epoch, use_second_order=False,\n",
+        "                                                      use_multi_step_loss_optimization=True,\n",
+        "                                                      num_steps=self.args.number_of_evaluation_steps_per_iter,\n",
+        "                                                      training_phase=False)\n",
+        "\n",
+        "        return losses, per_task_target_preds\n",
+        "\n",
+        "    def meta_update(self, loss):\n",
+        "        \"\"\"\n",
+        "        Applies an outer loop update on the meta-parameters of the model.\n",
+        "        :param loss: The current crossentropy loss.\n",
+        "        \"\"\"\n",
+        "        self.optimizer.zero_grad()\n",
+        "        loss.backward()\n",
+        "        if 'imagenet' in self.args.dataset_name:\n",
+        "            for name, param in self.classifier.named_parameters():\n",
+        "                if param.requires_grad:\n",
+        "                    param.grad.data.clamp_(-10, 10)  # not sure if this is necessary, more experiments are needed\n",
+        "        self.optimizer.step()\n",
+        "\n",
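+        "    # Outer-loop overview: run_train_iter (below) casts a sampled task batch to device tensors,\n",
+        "    # runs train_forward_prop to obtain the across-task loss, and then meta_update backpropagates\n",
+        "    # through the entire inner loop (second order once enabled) into the shared initialization,\n",
+        "    # the per-step learning rates and the per-step batch-norm parameters.\n",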
"\n", + " def run_train_iter(self, data_batch, epoch):\n", + " \"\"\"\n", + " Runs an outer loop update step on the meta-model's parameters.\n", + " :param data_batch: input data batch containing the support set and target set input, output pairs\n", + " :param epoch: the index of the current epoch\n", + " :return: The losses of the ran iteration.\n", + " \"\"\"\n", + " epoch = int(epoch)\n", + " self.scheduler.step(epoch=epoch)\n", + " if self.current_epoch != epoch:\n", + " self.current_epoch = epoch\n", + "\n", + " if not self.training:\n", + " self.train()\n", + "\n", + " x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n", + "\n", + " x_support_set = torch.Tensor(x_support_set).float().to(device=self.device)\n", + " x_target_set = torch.Tensor(x_target_set).float().to(device=self.device)\n", + " y_support_set = torch.Tensor(y_support_set).long().to(device=self.device)\n", + " y_target_set = torch.Tensor(y_target_set).long().to(device=self.device)\n", + "\n", + " data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n", + "\n", + " losses, per_task_target_preds = self.train_forward_prop(data_batch=data_batch, epoch=epoch)\n", + "\n", + " self.meta_update(loss=losses['loss'])\n", + " losses['learning_rate'] = self.scheduler.get_lr()[0]\n", + " self.optimizer.zero_grad()\n", + " self.zero_grad()\n", + "\n", + " return losses, per_task_target_preds\n", + "\n", + " def run_validation_iter(self, data_batch):\n", + " \"\"\"\n", + " Runs an outer loop evaluation step on the meta-model's parameters.\n", + " :param data_batch: input data batch containing the support set and target set input, output pairs\n", + " :param epoch: the index of the current epoch\n", + " :return: The losses of the ran iteration.\n", + " \"\"\"\n", + "\n", + " if self.training:\n", + " self.eval()\n", + "\n", + " x_support_set, x_target_set, y_support_set, y_target_set = data_batch\n", + "\n", + " x_support_set = torch.Tensor(x_support_set).float().to(device=self.device)\n", + " x_target_set = torch.Tensor(x_target_set).float().to(device=self.device)\n", + " y_support_set = torch.Tensor(y_support_set).long().to(device=self.device)\n", + " y_target_set = torch.Tensor(y_target_set).long().to(device=self.device)\n", + "\n", + " data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n", + "\n", + " losses, per_task_target_preds = self.evaluation_forward_prop(data_batch=data_batch, epoch=self.current_epoch)\n", + "\n", + " # losses['loss'].backward() # uncomment if you get the weird memory error\n", + " # self.zero_grad()\n", + " # self.optimizer.zero_grad()\n", + "\n", + " return losses, per_task_target_preds\n", + "\n", + " def save_model(self, model_save_dir, state):\n", + " \"\"\"\n", + " Save the network parameter state and experiment state dictionary.\n", + " :param model_save_dir: The directory to store the state at.\n", + " :param state: The state containing the experiment state and the network. It's in the form of a dictionary\n", + " object.\n", + " \"\"\"\n", + " state['network'] = self.state_dict()\n", + " torch.save(state, f=model_save_dir)\n", + "\n", + " def load_model(self, model_save_dir, model_name, model_idx):\n", + " \"\"\"\n", + " Load checkpoint and return the state dictionary containing the network state params and experiment state.\n", + " :param model_save_dir: The directory from which to load the files.\n", + " :param model_name: The model_name to be loaded from the direcotry.\n", + " :param model_idx: The index of the model (i.e. 
epoch number or 'latest' for the latest saved model of the current\n", + " experiment)\n", + " :return: A dictionary containing the experiment state and the saved model parameters.\n", + " \"\"\"\n", + " filepath = os.path.join(model_save_dir, \"{}_{}\".format(model_name, model_idx))\n", + " state = torch.load(filepath)\n", + " state_dict_loaded = state['network']\n", + " self.load_state_dict(state_dict=state_dict_loaded)\n", + " return state" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UmUi50wLrYr4", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def save_to_json(filename, dict_to_store):\n", + " with open(os.path.abspath(filename), 'w') as f:\n", + " json.dump(dict_to_store, fp=f)\n", + "\n", + "def load_from_json(filename):\n", + " with open(filename, mode=\"r\") as f:\n", + " load_dict = json.load(fp=f)\n", + "\n", + " return load_dict\n", + "\n", + "def save_statistics(experiment_name, line_to_add, filename=\"summary_statistics.csv\", create=False):\n", + " summary_filename = \"{}/{}\".format(experiment_name, filename)\n", + " if create:\n", + " with open(summary_filename, 'w') as f:\n", + " writer = csv.writer(f)\n", + " writer.writerow(line_to_add)\n", + " else:\n", + " with open(summary_filename, 'a') as f:\n", + " writer = csv.writer(f)\n", + " writer.writerow(line_to_add)\n", + "\n", + " return summary_filename\n", + "\n", + "def load_statistics(experiment_name, filename=\"summary_statistics.csv\"):\n", + " data_dict = dict()\n", + " summary_filename = \"{}/{}\".format(experiment_name, filename)\n", + " with open(summary_filename, 'r') as f:\n", + " lines = f.readlines()\n", + " data_labels = lines[0].replace(\"\\n\", \"\").split(\",\")\n", + " del lines[0]\n", + "\n", + " for label in data_labels:\n", + " data_dict[label] = []\n", + "\n", + " for line in lines:\n", + " data = line.replace(\"\\n\", \"\").split(\",\")\n", + " for key, item in zip(data_labels, data):\n", + " data_dict[key].append(item)\n", + " return data_dict\n", + "\n", + "\n", + "def build_experiment_folder(experiment_name):\n", + " experiment_path = os.path.abspath(experiment_name)\n", + " saved_models_filepath = \"{}/{}\".format(experiment_path, \"saved_models\")\n", + " logs_filepath = \"{}/{}\".format(experiment_path, \"logs\")\n", + " samples_filepath = \"{}/{}\".format(experiment_path, \"visual_outputs\")\n", + "\n", + " if not os.path.exists(experiment_path):\n", + " os.makedirs(experiment_path)\n", + " if not os.path.exists(logs_filepath):\n", + " os.makedirs(logs_filepath)\n", + " if not os.path.exists(samples_filepath):\n", + " os.makedirs(samples_filepath)\n", + " if not os.path.exists(saved_models_filepath):\n", + " os.makedirs(saved_models_filepath)\n", + "\n", + " outputs = (saved_models_filepath, logs_filepath, samples_filepath)\n", + " outputs = (os.path.abspath(item) for item in outputs)\n", + " return outputs\n", + "\n", + "def get_best_validation_model_statistics(experiment_name, filename=\"summary_statistics.csv\"):\n", + " \"\"\"\n", + " Returns the best val epoch and val accuracy from a log csv file\n", + " :param log_dir: The log directory the file is saved in\n", + " :param statistics_file_name: The log file name\n", + " :return: The best validation accuracy and the epoch at which it is produced\n", + " \"\"\"\n", + " log_file_dict = load_statistics(filename=filename, experiment_name=experiment_name)\n", + " d_val_loss = np.array(log_file_dict['total_d_val_loss_mean'], dtype=np.float32)\n", + " best_d_val_loss = 
np.min(d_val_loss)\n", + " best_d_val_epoch = np.argmin(d_val_loss)\n", + "\n", + " return best_d_val_loss, best_d_val_epoch\n", + "\n", + "def create_json_experiment_log(experiment_log_dir, args, log_name=\"experiment_log.json\"):\n", + " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n", + "\n", + " experiment_summary_dict = dict()\n", + "\n", + " for key, value in vars(args).items():\n", + " experiment_summary_dict[key] = value\n", + "\n", + " experiment_summary_dict[\"epoch_stats\"] = dict()\n", + " timestamp = datetime.datetime.now().timestamp()\n", + " experiment_summary_dict[\"experiment_status\"] = [(timestamp, \"initialization\")]\n", + " experiment_summary_dict[\"experiment_initialization_time\"] = timestamp\n", + " with open(os.path.abspath(summary_filename), 'w') as f:\n", + " json.dump(experiment_summary_dict, fp=f)\n", + "\n", + "def update_json_experiment_log_dict(key, value, experiment_log_dir, log_name=\"experiment_log.json\"):\n", + " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n", + " with open(summary_filename) as f:\n", + " summary_dict = json.load(fp=f)\n", + "\n", + " summary_dict[key].append(value)\n", + "\n", + " with open(summary_filename, 'w') as f:\n", + " json.dump(summary_dict, fp=f)\n", + "\n", + "def change_json_log_experiment_status(experiment_status, experiment_log_dir, log_name=\"experiment_log.json\"):\n", + " timestamp = datetime.datetime.now().timestamp()\n", + " experiment_status = (timestamp, experiment_status)\n", + " update_json_experiment_log_dict(key=\"experiment_status\", value=experiment_status,\n", + " experiment_log_dir=experiment_log_dir, log_name=log_name)\n", + "\n", + "def update_json_experiment_log_epoch_stats(epoch_stats, experiment_log_dir, log_name=\"experiment_log.json\"):\n", + " summary_filename = \"{}/{}\".format(experiment_log_dir, log_name)\n", + " with open(summary_filename) as f:\n", + " summary_dict = json.load(fp=f)\n", + "\n", + " epoch_stats_dict = summary_dict[\"epoch_stats\"]\n", + "\n", + " for key in epoch_stats.keys():\n", + " entry = float(epoch_stats[key])\n", + " if key in epoch_stats_dict:\n", + " epoch_stats_dict[key].append(entry)\n", + " else:\n", + " epoch_stats_dict[key] = [entry]\n", + "\n", + " summary_dict['epoch_stats'] = epoch_stats_dict\n", + "\n", + " with open(summary_filename, 'w') as f:\n", + " json.dump(summary_dict, fp=f)\n", + " return summary_filename" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "xCz8qhr0w2g3", + "colab_type": "code", + "colab": {} + }, + "source": [ + "class ExperimentBuilder(object):\n", + " def __init__(self, args, data, model, device):\n", + " \"\"\"\n", + " Initializes an experiment builder using a named tuple (args), a data provider (data), a meta learning system\n", + " (model) and a device (e.g. 
gpu/cpu/n)\n",
+        "        :param args: A namedtuple containing all experiment hyperparameters\n",
+        "        :param data: A data provider of instance MetaLearningSystemDataLoader\n",
+        "        :param model: A meta learning system instance\n",
+        "        :param device: Device/s to use for the experiment\n",
+        "        \"\"\"\n",
+        "        self.args, self.device = args, device\n",
+        "\n",
+        "        self.model = model\n",
+        "        self.saved_models_filepath, self.logs_filepath, self.samples_filepath = build_experiment_folder(\n",
+        "            experiment_name=self.args.experiment_name)\n",
+        "\n",
+        "        self.total_losses = dict()\n",
+        "        self.state = dict()\n",
+        "        self.state['best_val_acc'] = 0.\n",
+        "        self.state['best_val_iter'] = 0\n",
+        "        self.state['current_iter'] = 0\n",
+        "        self.start_epoch = 0\n",
+        "        self.max_models_to_save = self.args.max_models_to_save\n",
+        "        self.create_summary_csv = False\n",
+        "\n",
+        "        if self.args.continue_from_epoch == 'from_scratch':\n",
+        "            self.create_summary_csv = True\n",
+        "\n",
+        "        elif self.args.continue_from_epoch == 'latest':\n",
+        "            checkpoint = os.path.join(self.saved_models_filepath, \"train_model_latest\")\n",
+        "            print(\"attempting to find existing checkpoint\")\n",
+        "            if os.path.exists(checkpoint):\n",
+        "                self.state = \\\n",
+        "                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+        "                                          model_idx='latest')\n",
+        "                self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+        "\n",
+        "            else:\n",
+        "                self.args.continue_from_epoch = 'from_scratch'\n",
+        "                self.create_summary_csv = True\n",
+        "        elif int(self.args.continue_from_epoch) >= 0:\n",
+        "            self.state = \\\n",
+        "                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+        "                                      model_idx=self.args.continue_from_epoch)\n",
+        "            self.start_epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+        "\n",
+        "        self.data = data(args=args, current_iter=self.state['current_iter'])\n",
+        "\n",
+        "        print(\"train_seed {}, val_seed: {}, at start time\".format(self.data.dataset.seed[\"train\"],\n",
+        "                                                                  self.data.dataset.seed[\"val\"]))\n",
+        "        self.total_epochs_before_pause = self.args.total_epochs_before_pause\n",
+        "        self.state['best_epoch'] = int(self.state['best_val_iter'] / self.args.total_iter_per_epoch)\n",
+        "        self.epoch = int(self.state['current_iter'] / self.args.total_iter_per_epoch)\n",
+        "        self.augment_flag = 'omniglot' in self.args.dataset_name.lower()\n",
+        "        self.start_time = time.time()\n",
+        "        self.epochs_done_in_this_run = 0\n",
+        "        print(self.state['current_iter'], int(self.args.total_iter_per_epoch * self.args.total_epochs))\n",
+        "\n",
+        "    def build_summary_dict(self, total_losses, phase, summary_losses=None):\n",
+        "        \"\"\"\n",
+        "        Builds/Updates a summary dict directly from the metric dict of the current iteration.\n",
+        "        :param total_losses: Current dict with total losses (not aggregations) from experiment\n",
+        "        :param phase: Current training phase\n",
+        "        :param summary_losses: Current summarised (aggregated/summarised) losses stats means, stdv etc.\n",
+        "        :return: A new summary dict with the updated summary statistics information.\n",
+        "        \"\"\"\n",
+        "        if summary_losses is None:\n",
+        "            summary_losses = dict()\n",
+        "\n",
+        "        for key in total_losses:\n",
+        "            summary_losses[\"{}_{}_mean\".format(phase, key)] = np.mean(total_losses[key])\n",
+        "            summary_losses[\"{}_{}_std\".format(phase, key)] = np.std(total_losses[key])\n",
+        "\n",
+        "        return summary_losses\n",
+        "\n",
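+        "    # Naming convention: build_summary_dict (above) flattens every logged metric into\n",
+        "    # '{phase}_{metric}_mean' / '{phase}_{metric}_std' keys (e.g. 'val_accuracy_mean'),\n",
+        "    # and the best-model selection in run_experiment keys off exactly those names.\n",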
+        "    def build_loss_summary_string(self, summary_losses):\n",
+        "        \"\"\"\n",
+        "        Builds a progress bar summary string given current summary losses dictionary\n",
+        "        :param summary_losses: Current summary statistics\n",
+        "        :return: A summary string ready to be shown to humans.\n",
+        "        \"\"\"\n",
+        "        output_update = \"\"\n",
+        "        for key, value in summary_losses.items():\n",
+        "            if \"loss\" in key or \"accuracy\" in key:\n",
+        "                value = float(value)\n",
+        "                output_update += \"{}: {:.4f}, \".format(key, value)\n",
+        "\n",
+        "        return output_update\n",
+        "\n",
+        "    def merge_two_dicts(self, first_dict, second_dict):\n",
+        "        \"\"\"Given two dicts, merge them into a new dict as a shallow copy.\"\"\"\n",
+        "        z = first_dict.copy()\n",
+        "        z.update(second_dict)\n",
+        "        return z\n",
+        "\n",
+        "    def train_iteration(self, train_sample, sample_idx, epoch_idx, total_losses, current_iter, pbar_train):\n",
+        "        \"\"\"\n",
+        "        Runs a training iteration, updates the progress bar and returns the total and current epoch train losses.\n",
+        "        :param train_sample: A sample from the data provider\n",
+        "        :param sample_idx: The index of the incoming sample, in relation to the current training run.\n",
+        "        :param epoch_idx: The epoch index.\n",
+        "        :param total_losses: The current total losses dictionary to be updated.\n",
+        "        :param current_iter: The current training iteration in relation to the whole experiment.\n",
+        "        :param pbar_train: The progress bar of the training.\n",
+        "        :return: Updates total_losses, train_losses, current_iter\n",
+        "        \"\"\"\n",
+        "        x_support_set, x_target_set, y_support_set, y_target_set, seed = train_sample\n",
+        "        data_batch = (x_support_set, x_target_set, y_support_set, y_target_set)\n",
+        "\n",
+        "        if sample_idx == 0:\n",
+        "            print(\"shape of data\", x_support_set.shape, x_target_set.shape, y_support_set.shape,\n",
+        "                  y_target_set.shape)\n",
+        "\n",
+        "        losses, _ = self.model.run_train_iter(data_batch=data_batch, epoch=epoch_idx)\n",
+        "\n",
+        "        for key, value in losses.items():\n",
+        "            if key not in total_losses:\n",
+        "                total_losses[key] = [float(value)]\n",
+        "            else:\n",
+        "                total_losses[key].append(float(value))\n",
+        "\n",
+        "        train_losses = self.build_summary_dict(total_losses=total_losses, phase=\"train\")\n",
+        "        train_output_update = self.build_loss_summary_string(losses)\n",
+        "\n",
+        "        pbar_train.update(1)\n",
+        "        pbar_train.set_description(\"training phase {} -> {}\".format(self.epoch, train_output_update))\n",
+        "\n",
+        "        current_iter += 1\n",
+        "\n",
+        "        return train_losses, total_losses, current_iter\n",
+        "\n",
+        "    def evaluation_iteration(self, val_sample, total_losses, pbar_val, phase):\n",
+        "        \"\"\"\n",
+        "        Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses.\n",
+        "        :param val_sample: A sample from the data provider\n",
+        "        :param total_losses: The current total losses dictionary to be updated.\n",
+        "        :param pbar_val: The progress bar of the val stage.\n",
+        "        :return: The updated val_losses, total_losses\n",
+        "        \"\"\"\n",
+        "        x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample\n",
+        "        data_batch = (\n",
+        "            x_support_set, x_target_set, y_support_set, y_target_set)\n",
+        "\n",
+        "        losses, _ = self.model.run_validation_iter(data_batch=data_batch)\n",
+        "        for key, value in losses.items():\n",
+        "            if key not in total_losses:\n",
+        "                total_losses[key] = 
[float(value)]\n", + " else:\n", + " total_losses[key].append(float(value))\n", + "\n", + " val_losses = self.build_summary_dict(total_losses=total_losses, phase=phase)\n", + " val_output_update = self.build_loss_summary_string(losses)\n", + "\n", + " pbar_val.update(1)\n", + " pbar_val.set_description(\n", + " \"val_phase {} -> {}\".format(self.epoch, val_output_update))\n", + "\n", + " return val_losses, total_losses\n", + "\n", + " def test_evaluation_iteration(self, val_sample, model_idx, sample_idx, per_model_per_batch_preds, pbar_test):\n", + " \"\"\"\n", + " Runs a validation iteration, updates the progress bar and returns the total and current epoch val losses.\n", + " :param val_sample: A sample from the data provider\n", + " :param total_losses: The current total losses dictionary to be updated.\n", + " :param pbar_test: The progress bar of the val stage.\n", + " :return: The updated val_losses, total_losses\n", + " \"\"\"\n", + " x_support_set, x_target_set, y_support_set, y_target_set, seed = val_sample\n", + " data_batch = (\n", + " x_support_set, x_target_set, y_support_set, y_target_set)\n", + "\n", + " losses, per_task_preds = self.model.run_validation_iter(data_batch=data_batch)\n", + "\n", + " per_model_per_batch_preds[model_idx].extend(list(per_task_preds))\n", + "\n", + " test_output_update = self.build_loss_summary_string(losses)\n", + "\n", + " pbar_test.update(1)\n", + " pbar_test.set_description(\n", + " \"test_phase {} -> {}\".format(self.epoch, test_output_update))\n", + "\n", + " return per_model_per_batch_preds\n", + "\n", + " def save_models(self, model, epoch, state):\n", + " \"\"\"\n", + " Saves two separate instances of the current model. One to be kept for history and reloading later and another\n", + " one marked as \"latest\" to be used by the system for the next epoch training. Useful when the training/val\n", + " process is interrupted or stopped. Leads to fault tolerant training and validation systems that can continue\n", + " from where they left off before.\n", + " :param model: Current meta learning model of any instance within the few_shot_learning_system.py\n", + " :param epoch: Current epoch\n", + " :param state: Current model and experiment state dict.\n", + " \"\"\"\n", + " model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, \"train_model_{}\".format(int(epoch))),\n", + " state=state)\n", + "\n", + " model.save_model(model_save_dir=os.path.join(self.saved_models_filepath, \"train_model_latest\"),\n", + " state=state)\n", + "\n", + " print(\"saved models to\", self.saved_models_filepath)\n", + "\n", + " def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses, state, step):\n", + " \"\"\"\n", + " Given current epochs start_time, train losses, val losses and whether to create a new stats csv file, pack stats\n", + " and save into a statistics csv file. 
+        "    def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses, state, step):\n",
+        "        \"\"\"\n",
+        "        Given the current epoch's start_time, train losses, val losses and whether to create a new stats csv file,\n",
+        "        packs the stats and saves them into a statistics csv file. Returns a new start time for the next epoch.\n",
+        "        :param start_time: The start time of the current epoch\n",
+        "        :param create_summary_csv: A boolean indicating whether to create a new statistics file or\n",
+        "         append results to the existing one\n",
+        "        :param train_losses: A dictionary with the current train losses\n",
+        "        :param val_losses: A dictionary with the current val losses\n",
+        "        :param state: The experiment state dict, to which per-epoch statistics are appended.\n",
+        "        :param step: The global step used when logging to TensorBoard.\n",
+        "        :return: The current time, to be used for the next epoch, and the updated state.\n",
+        "        \"\"\"\n",
+        "        epoch_summary_losses = self.merge_two_dicts(first_dict=train_losses, second_dict=val_losses)\n",
+        "\n",
+        "        if 'per_epoch_statistics' not in state:\n",
+        "            state['per_epoch_statistics'] = dict()\n",
+        "\n",
+        "        for key, value in epoch_summary_losses.items():\n",
+        "\n",
+        "            if key not in state['per_epoch_statistics']:\n",
+        "                state['per_epoch_statistics'][key] = [value]\n",
+        "            else:\n",
+        "                state['per_epoch_statistics'][key].append(value)\n",
+        "\n",
+        "        epoch_summary_string = self.build_loss_summary_string(epoch_summary_losses)\n",
+        "        epoch_summary_losses[\"epoch\"] = self.epoch\n",
+        "        epoch_summary_losses['epoch_run_time'] = time.time() - start_time\n",
+        "\n",
+        "        if create_summary_csv:\n",
+        "            self.summary_statistics_filepath = save_statistics(self.logs_filepath, list(epoch_summary_losses.keys()),\n",
+        "                                                               create=True)\n",
+        "            self.create_summary_csv = False\n",
+        "\n",
+        "        start_time = time.time()\n",
+        "\n",
+        "        # logs the epoch index to TensorBoard so progress can be tracked per step\n",
+        "        writer.add_scalar('epoch_summary_losses', epoch_summary_losses[\"epoch\"], step)\n",
+        "\n",
+        "        print(\"epoch {} -> {}\".format(epoch_summary_losses[\"epoch\"], epoch_summary_string))\n",
+        "\n",
+        "        self.summary_statistics_filepath = save_statistics(self.logs_filepath,\n",
+        "                                                           list(epoch_summary_losses.values()))\n",
+        "        return start_time, state\n",
+        "\n",
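+        "    # The method below ranks saved epochs by mean validation accuracy, reloads\n",
+        "    # the top-n checkpoints, and averages their test-set predictions into a\n",
+        "    # single ensemble estimate (see the sketch in the cell after this one).\n",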
+        "    def evaluated_test_set_using_the_best_models(self, top_n_models):\n",
+        "        per_epoch_statistics = self.state['per_epoch_statistics']\n",
+        "        val_acc = np.copy(per_epoch_statistics['val_accuracy_mean'])\n",
+        "        val_idx = np.arange(len(val_acc))\n",
+        "        sorted_idx = np.argsort(val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]\n",
+        "\n",
+        "        sorted_val_acc = val_acc[sorted_idx]\n",
+        "        val_idx = val_idx[sorted_idx]\n",
+        "        print(sorted_idx)\n",
+        "        print(sorted_val_acc)\n",
+        "\n",
+        "        top_n_idx = val_idx[:top_n_models]\n",
+        "        per_model_per_batch_preds = [[] for i in range(top_n_models)]\n",
+        "        per_model_per_batch_targets = [[] for i in range(top_n_models)]\n",
+        "        for idx, model_idx in enumerate(top_n_idx):\n",
+        "            self.state = \\\n",
+        "                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name=\"train_model\",\n",
+        "                                      model_idx=model_idx + 1)\n",
+        "            with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_test:\n",
+        "                for sample_idx, test_sample in enumerate(\n",
+        "                        self.data.get_test_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size),\n",
+        "                                                   augment_images=False)):\n",
+        "                    per_model_per_batch_targets[idx].extend(np.array(test_sample[3]))\n",
+        "                    per_model_per_batch_preds = self.test_evaluation_iteration(val_sample=test_sample,\n",
+        "                                                                                sample_idx=sample_idx,\n",
+        "                                                                                model_idx=idx,\n",
+        "                                                                                per_model_per_batch_preds=per_model_per_batch_preds,\n",
+        "                                                                                pbar_test=pbar_test)\n",
+        "        # sanity check (disabled): every model should have seen the same targets\n",
+        "        # for i in range(top_n_models):\n",
+        "        #     assert np.array_equal(np.array(per_model_per_batch_targets[0]),\n",
+        "        #                           np.array(per_model_per_batch_targets[i]))\n",
+        "\n",
+        "        per_batch_preds = np.mean(per_model_per_batch_preds, axis=0)\n",
+        "        per_batch_max = np.argmax(per_batch_preds, axis=2)\n",
+        "        per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(per_batch_max.shape)\n",
+        "        accuracy = np.mean(np.equal(per_batch_targets, per_batch_max))\n",
+        "        accuracy_std = np.std(np.equal(per_batch_targets, per_batch_max))\n",
+        "\n",
+        "        test_losses = {\"test_accuracy_mean\": accuracy, \"test_accuracy_std\": accuracy_std}\n",
+        "\n",
+        "        _ = save_statistics(self.logs_filepath,\n",
+        "                            list(test_losses.keys()),\n",
+        "                            create=True, filename=\"test_summary.csv\")\n",
+        "\n",
+        "        summary_statistics_filepath = save_statistics(self.logs_filepath,\n",
+        "                                                       list(test_losses.values()),\n",
+        "                                                       create=False, filename=\"test_summary.csv\")\n",
+        "        print(test_losses)\n",
+        "        print(\"saved test performance at\", summary_statistics_filepath)\n",
+        "\n",
+        "    def run_experiment(self):\n",
+        "        \"\"\"\n",
+        "        Runs a full training experiment, evaluating the model on the val set at every epoch. At the end, evaluates\n",
+        "        an ensemble of the top-5 validation checkpoints on the test set and saves the results.\n",
+        "        \"\"\"\n",
+        "        with tqdm.tqdm(initial=self.state['current_iter'],\n",
+        "                       total=int(self.args.total_iter_per_epoch * self.args.total_epochs)) as pbar_train:\n",
+        "\n",
+        "            while (self.state['current_iter'] < (self.args.total_epochs * self.args.total_iter_per_epoch)) and (not self.args.evaluate_on_test_set_only):\n",
+        "\n",
+        "                for train_sample_idx, train_sample in enumerate(\n",
+        "                        self.data.get_train_batches(total_batches=int(self.args.total_iter_per_epoch *\n",
+        "                                                                      self.args.total_epochs) - self.state[\n",
+        "                            'current_iter'],\n",
+        "                                                    augment_images=self.augment_flag)):\n",
+        "                    train_losses, total_losses, self.state['current_iter'] = self.train_iteration(\n",
+        "                        train_sample=train_sample,\n",
+        "                        total_losses=self.total_losses,\n",
+        "                        epoch_idx=(self.state['current_iter'] /\n",
+        "                                   self.args.total_iter_per_epoch),\n",
+        "                        pbar_train=pbar_train,\n",
+        "                        current_iter=self.state['current_iter'],\n",
+        "                        sample_idx=self.state['current_iter'])\n",
+        "\n",
+        "                    if self.state['current_iter'] % self.args.total_iter_per_epoch == 0:\n",
+        "\n",
+        "                        total_losses = dict()\n",
+        "                        val_losses = dict()\n",
+        "                        with tqdm.tqdm(total=int(self.args.num_evaluation_tasks / self.args.batch_size)) as pbar_val:\n",
+        "                            for _, val_sample in enumerate(\n",
+        "                                    self.data.get_val_batches(total_batches=int(self.args.num_evaluation_tasks / self.args.batch_size),\n",
+        "                                                              augment_images=False)):\n",
+        "                                val_losses, total_losses = self.evaluation_iteration(val_sample=val_sample,\n",
+        "                                                                                     total_losses=total_losses,\n",
+        "                                                                                     pbar_val=pbar_val, phase='val')\n",
+        "\n",
+        "                        if val_losses[\"val_accuracy_mean\"] > self.state['best_val_acc']:\n",
+        "                            print(\"Best validation accuracy\", val_losses[\"val_accuracy_mean\"])\n",
+        "                            writer.add_scalar(\"best validation accuracy\", val_losses[\"val_accuracy_mean\"],\n",
+        "                                              self.epoch)\n",
+        "                            self.state['best_val_acc'] = val_losses[\"val_accuracy_mean\"]\n",
+        "                            self.state['best_val_iter'] = self.state['current_iter']\n",
+        "                            self.state['best_epoch'] = int(\n",
+        "                                self.state['best_val_iter'] / self.args.total_iter_per_epoch)\n",
+        "\n",
+        "                        self.epoch += 1\n",
+        "                        self.state = self.merge_two_dicts(first_dict=self.merge_two_dicts(first_dict=self.state,\n",
+        "                                                                                          second_dict=train_losses),\n",
+        "                                                          second_dict=val_losses)\n",
+        "\n",
+        "                        self.save_models(model=self.model, epoch=self.epoch, state=self.state)\n",
+        "\n",
+        "                        self.start_time, self.state = self.pack_and_save_metrics(start_time=self.start_time,\n",
+        "                                                                                  create_summary_csv=self.create_summary_csv,\n",
+        "                                                                                  train_losses=train_losses,\n",
+        "                                                                                  val_losses=val_losses,\n",
+        "                                                                                  state=self.state,\n",
+        "                                                                                  step=self.epoch)\n",
+        "\n",
+        "                        self.total_losses = dict()\n",
+        "\n",
+        "                        self.epochs_done_in_this_run += 1\n",
+        "\n",
+        "                        save_to_json(filename=os.path.join(self.logs_filepath, \"summary_statistics.json\"),\n",
+        "                                     dict_to_store=self.state['per_epoch_statistics'])\n",
+        "\n",
+        "                        if self.epochs_done_in_this_run >= self.total_epochs_before_pause:\n",
+        "                            print(\"train_seed {}, val_seed: {}, at pause time\".format(self.data.dataset.seed[\"train\"],\n",
+        "                                                                                      self.data.dataset.seed[\"val\"]))\n",
+        "                            sys.exit()\n",
+        "        self.evaluated_test_set_using_the_best_models(top_n_models=5)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
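+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text"
+      },
+      "source": [
+        "A minimal sketch (not from the repository) of the ensembling step in evaluated_test_set_using_the_best_models above: per-model class scores are averaged over the model axis, then argmaxed per sample. The array shapes below are illustrative assumptions, not the exact shapes produced by run_validation_iter."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "# illustrative shapes: 3 models, 4 batches, 8 tasks per batch, 5 classes\n",
+        "rng = np.random.RandomState(0)\n",
+        "per_model_preds = rng.rand(3, 4, 8, 5)\n",
+        "targets = rng.randint(0, 5, size=(4, 8))\n",
+        "\n",
+        "ensemble_scores = np.mean(per_model_preds, axis=0)  # average scores over models\n",
+        "pred_labels = np.argmax(ensemble_scores, axis=2)    # predicted class per task\n",
+        "print(\"ensemble accuracy:\", np.mean(np.equal(targets, pred_labels)))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },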
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "ClIPkzMEXvKS",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "from torch import cuda\n",
+        "\n",
+        "\n",
+        "def get_args():\n",
+        "    import argparse\n",
+        "    import os\n",
+        "    import torch\n",
+        "    import json\n",
+        "    parser = argparse.ArgumentParser(description='Welcome to the MAML++ training and inference system')\n",
+        "\n",
+        "    parser.add_argument('--batch_size', nargs=\"?\", type=int, default=32, help='Batch size for the experiment')\n",
+        "    parser.add_argument('--image_height', nargs=\"?\", type=int, default=28)\n",
+        "    parser.add_argument('--image_width', nargs=\"?\", type=int, default=28)\n",
+        "    parser.add_argument('--image_channels', nargs=\"?\", type=int, default=1)\n",
+        "    parser.add_argument('--reset_stored_filepaths', type=str, default=\"False\")\n",
+        "    parser.add_argument('--reverse_channels', type=str, default=\"False\")\n",
+        "    parser.add_argument('--num_of_gpus', type=int, default=1)\n",
+        "    parser.add_argument('--indexes_of_folders_indicating_class', nargs='+', default=[-2, -3])\n",
+        "    parser.add_argument('--train_val_test_split', nargs='+', default=[0.73982737361, 0.26, 0.13008631319])\n",
+        "    parser.add_argument('--samples_per_iter', nargs=\"?\", type=int, default=1)\n",
+        "    parser.add_argument('--labels_as_int', type=str, default=\"False\")\n",
+        "    parser.add_argument('--seed', type=int, default=104)\n",
+        "\n",
+        "    parser.add_argument('--gpu_to_use', type=int)\n",
+        "    parser.add_argument('--num_dataprovider_workers', nargs=\"?\", type=int, default=4)\n",
+        "    parser.add_argument('--max_models_to_save', nargs=\"?\", type=int, default=5)\n",
+        "    parser.add_argument('--dataset_name', type=str, default=\"omniglot_dataset\")\n",
+        "    parser.add_argument('--dataset_path', type=str, default=\"datasets/omniglot_dataset\")\n",
+        "    parser.add_argument('--reset_stored_paths', type=str, default=\"False\")\n",
+        "    parser.add_argument('--experiment_name', nargs=\"?\", type=str)\n",
+        "    parser.add_argument('--architecture_name', nargs=\"?\", type=str)\n",
+        "    parser.add_argument('--continue_from_epoch', nargs=\"?\", type=str, default='latest', help='Continue from checkpoint of epoch')\n",
+        "    parser.add_argument('--dropout_rate_value', type=float, default=0.3, help='Dropout rate value')\n",
+        "    parser.add_argument('--num_target_samples', type=int, default=15, help='Number of target (query) samples per class')\n",
+        "    parser.add_argument('--second_order', type=str, default=\"False\", help='Whether to use second-order gradients in the inner loop')\n",
+        "    parser.add_argument('--total_epochs', type=int, default=200, help='Number of epochs per experiment')\n",
+        "    parser.add_argument('--total_iter_per_epoch', type=int, default=500, help='Number of iters per epoch')\n",
+        "    parser.add_argument('--min_learning_rate', type=float, default=0.00001, help='Min learning rate')\n",
+        "    parser.add_argument('--meta_learning_rate', type=float, default=0.001, help='Learning rate of overall MAML system')\n",
+        "    parser.add_argument('--meta_opt_bn', type=str, default=\"False\")\n",
+        "    parser.add_argument('--task_learning_rate', type=float, default=0.1, help='Learning rate per task gradient step')\n",
+        "\n",
+        "    parser.add_argument('--norm_layer', type=str, default=\"batch_norm\")\n",
+        "    parser.add_argument('--max_pooling', type=str, default=\"False\")\n",
+        "    parser.add_argument('--per_step_bn_statistics', type=str, default=\"False\")\n",
+        "    parser.add_argument('--num_classes_per_set', type=int, default=20, help='Number of classes to sample per set')\n",
+        "    parser.add_argument('--cnn_num_blocks', type=int, default=4, help='Number of convolutional blocks in the network')\n",
+        "    parser.add_argument('--number_of_training_steps_per_iter', type=int, default=1, help='Number of inner-loop steps per task during training')\n",
+        "    parser.add_argument('--number_of_evaluation_steps_per_iter', type=int, default=1, help='Number of inner-loop steps per task during evaluation')\n",
+        "    parser.add_argument('--cnn_num_filters', type=int, default=64, help='Number of filters per convolutional layer')\n",
+        "    parser.add_argument('--cnn_blocks_per_stage', type=int, default=1,\n",
+        "                        help='Number of convolutional blocks per stage')\n",
+        "    parser.add_argument('--num_samples_per_class', type=int, default=1, help='Number of support samples per class')\n",
+        "    parser.add_argument('--name_of_args_json_file', type=str, default=\"./experiment_config/omniglot_maml++-omniglot_1_8_0.1_64_5_0.json\")\n",
+        "    # parser.add_argument('--num_stages', default=4)\n",
+        "    # parser.add_argument('--conv_padding', default=True)\n",
+        "\n",
+        "    # parse an empty argv so the notebook uses the defaults plus the JSON config\n",
+        "    args = parser.parse_args('')\n",
+        "    args_dict = vars(args)\n",
+        "    if args.name_of_args_json_file != \"None\":\n",
+        "        args_dict = extract_args_from_json(args.name_of_args_json_file, args_dict)\n",
+        "\n",
+        "    for key in list(args_dict.keys()):\n",
+        "\n",
+        "        if str(args_dict[key]).lower() == \"true\":\n",
+        "            args_dict[key] = True\n",
+        "        elif str(args_dict[key]).lower() == \"false\":\n",
+        "            args_dict[key] = False\n",
+        "        if key == \"dataset_path\":\n",
+        "            args_dict[key] = os.path.join(os.environ['DATASET_DIR'], args_dict[key])\n",
+        "            print(key, args_dict[key])\n",
+        "\n",
+        "        print(key, args_dict[key], type(args_dict[key]))\n",
+        "\n",
+        "    args = Bunch(args_dict)\n",
+        "\n",
+        "    args.use_cuda = torch.cuda.is_available()\n",
+        "\n",
+        "    if args.gpu_to_use == -1:\n",
+        "        args.use_cuda = False\n",
+        "\n",
+        "    if args.use_cuda:\n",
+        "        os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(args.gpu_to_use)\n",
+        "        device = cuda.current_device()\n",
+        "    else:\n",
+        "        device = torch.device('cpu')\n",
+        "\n",
+        "    return args, device\n",
+        "\n",
+        "\n",
+        "class Bunch(object):\n",
+        "    \"\"\"Wraps a dict so its keys are accessible as attributes.\"\"\"\n",
+        "    def __init__(self, adict):\n",
+        "        self.__dict__.update(adict)\n",
+        "\n",
+        "\n",
+        "def extract_args_from_json(json_file_path, args_dict):\n",
+        "    import json\n",
+        "    summary_filename = json_file_path\n",
+        "    with open(summary_filename) as f:\n",
+        "        summary_dict = json.load(fp=f)\n",
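+        "    # keys the notebook controls directly ('continue_from...' for checkpoint\n",
+        "    # resume, 'gpu_to_use' for hardware) are skipped in the loop below, so the\n",
+        "    # JSON config cannot override them\n",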
+ "\n", + " for key in summary_dict.keys():\n", + " if \"continue_from\" in key:\n", + " pass\n", + " elif \"gpu_to_use\" in key:\n", + " pass\n", + " else:\n", + " args_dict[key] = summary_dict[key]\n", + "\n", + " return args_dict\n", + "\n", + "\n", + "\n", + "\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "GmWqo4GQ66oG", + "colab_type": "code", + "outputId": "bf944b05-3d91-4b35-bd3e-506d94ab78e5", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 2675 + } + }, + "source": [ + "# Combines the arguments, model, data and experiment builders to run an experiment\n", + "args, device = get_args()\n", + "print(args.image_channels)\n", + "model = MAMLFewShotClassifier(args=args, device=device,\n", + " im_shape=(2, args.image_channels, args.image_height, args.image_width))\n", + "# maybe_unzip_dataset(args=args)\n", + "data = MetaLearningSystemDataLoader\n", + "maml_system = ExperimentBuilder(model=model, data=data, args=args, device=device)\n", + "maml_system.run_experiment()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "batch_size 8 \n", + "image_height 28 \n", + "image_width 28 \n", + "image_channels 1 \n", + "reset_stored_filepaths False \n", + "reverse_channels False \n", + "num_of_gpus 1 \n", + "indexes_of_folders_indicating_class [-3, -2] \n", + "train_val_test_split [0.70918052988, 0.03080714725, 0.2606284658] \n", + "samples_per_iter 1 \n", + "labels_as_int False \n", + "seed 104 \n", + "gpu_to_use None \n", + "num_dataprovider_workers 4 \n", + "max_models_to_save 5 \n", + "dataset_name omniglot_dataset \n", + "dataset_path ./datasets/./datasets/omniglot_dataset\n", + "dataset_path ./datasets/omniglot_dataset \n", + "reset_stored_paths False \n", + "experiment_name omniglot_1_8_0.1_64_5_0 \n", + "architecture_name None \n", + "continue_from_epoch latest \n", + "dropout_rate_value 0.0 \n", + "num_target_samples 1 \n", + "second_order True \n", + "total_epochs 100 \n", + "total_iter_per_epoch 500 \n", + "min_learning_rate 1e-05 \n", + "meta_learning_rate 0.001 \n", + "meta_opt_bn False \n", + "task_learning_rate 0.1 \n", + "norm_layer batch_norm \n", + "max_pooling True \n", + "per_step_bn_statistics True \n", + "num_classes_per_set 5 \n", + "cnn_num_blocks 4 \n", + "number_of_training_steps_per_iter 5 \n", + "number_of_evaluation_steps_per_iter 5 \n", + "cnn_num_filters 64 \n", + "cnn_blocks_per_stage 1 \n", + "num_samples_per_class 1 \n", + "name_of_args_json_file ./experiment_config/omniglot_maml++-omniglot_1_8_0.1_64_5_0.json \n", + "train_seed 0 \n", + "val_seed 0 \n", + "load_from_npz_files False \n", + "sets_are_pre_split False \n", + "load_into_memory True \n", + "init_inner_loop_learning_rate 0.1 \n", + "train_in_stages False \n", + "multi_step_loss_num_epochs 10 \n", + "minimum_per_task_contribution 0.01 \n", + "num_evaluation_tasks 600 \n", + "learnable_per_layer_per_step_inner_loop_learning_rate True \n", + "enable_inner_loop_optimizable_bn_params False \n", + "evaluate_on_test_set_only False \n", + "learnable_batch_norm_momentum False \n", + "evalute_on_test_set_only False \n", + "learnable_bn_gamma True \n", + "learnable_bn_beta True \n", + "weight_decay 0.0 \n", + "total_epochs_before_pause 100 \n", + "first_order_to_second_order_epoch -1 \n", + "num_stages 4 \n", + "conv_padding True \n", + "use_multi_step_loss_optimization True \n", + "1\n", + "Using max pooling\n", + "torch.Size([2, 64, 28, 28])\n", + "torch.Size([2, 64, 14, 14])\n", + "torch.Size([2, 
64, 7, 7])\n", + "torch.Size([2, 64, 3, 3])\n", + "VGGNetwork build torch.Size([2, 5])\n", + "meta network params\n", + "layer_dict.conv0.conv.weight torch.Size([64, 1, 3, 3])\n", + "layer_dict.conv0.conv.bias torch.Size([64])\n", + "layer_dict.conv0.norm_layer.running_mean torch.Size([5, 64])\n", + "layer_dict.conv0.norm_layer.running_var torch.Size([5, 64])\n", + "layer_dict.conv0.norm_layer.bias torch.Size([5, 64])\n", + "layer_dict.conv0.norm_layer.weight torch.Size([5, 64])\n", + "layer_dict.conv1.conv.weight torch.Size([64, 64, 3, 3])\n", + "layer_dict.conv1.conv.bias torch.Size([64])\n", + "layer_dict.conv1.norm_layer.running_mean torch.Size([5, 64])\n", + "layer_dict.conv1.norm_layer.running_var torch.Size([5, 64])\n", + "layer_dict.conv1.norm_layer.bias torch.Size([5, 64])\n", + "layer_dict.conv1.norm_layer.weight torch.Size([5, 64])\n", + "layer_dict.conv2.conv.weight torch.Size([64, 64, 3, 3])\n", + "layer_dict.conv2.conv.bias torch.Size([64])\n", + "layer_dict.conv2.norm_layer.running_mean torch.Size([5, 64])\n", + "layer_dict.conv2.norm_layer.running_var torch.Size([5, 64])\n", + "layer_dict.conv2.norm_layer.bias torch.Size([5, 64])\n", + "layer_dict.conv2.norm_layer.weight torch.Size([5, 64])\n", + "layer_dict.conv3.conv.weight torch.Size([64, 64, 3, 3])\n", + "layer_dict.conv3.conv.bias torch.Size([64])\n", + "layer_dict.conv3.norm_layer.running_mean torch.Size([5, 64])\n", + "layer_dict.conv3.norm_layer.running_var torch.Size([5, 64])\n", + "layer_dict.conv3.norm_layer.bias torch.Size([5, 64])\n", + "layer_dict.conv3.norm_layer.weight torch.Size([5, 64])\n", + "layer_dict.linear.weights torch.Size([5, 64])\n", + "layer_dict.linear.bias torch.Size([5])\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "\r 0%| | 0/1150 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMetaLearningSystemDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mmaml_system\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExperimentBuilder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmaml_system\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mrun_experiment\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal_epochs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m 'current_iter'],\n\u001b[0;32m--> 311\u001b[0;31m augment_images=self.augment_flag)):\n\u001b[0m\u001b[1;32m 312\u001b[0m \u001b[0;31m# print(self.state['current_iter'], (self.args.total_epochs * self.args.total_iter_per_epoch))\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
313\u001b[0m train_losses, total_losses, self.state['current_iter'] = self.train_iteration(\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mget_train_batches\u001b[0;34m(self, total_batches, augment_images)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_augmentation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maugment_images\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maugment_images\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtotal_train_iters_produced\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_of_gpus\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples_per_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msample_id\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_batched\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_dataloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0msample_batched\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreorder_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 582\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_next_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 583\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0mnext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m__next__\u001b[0m \u001b[0;31m# Python 2 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_next_batch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"KeyError:\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 607\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 608\u001b[0;31m \u001b[0;32mraise\u001b[0m 
\u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: Traceback (most recent call last):\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py\", line 99, in _worker_loop\n samples = collate_fn([dataset[i] for i in batch_indices])\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py\", line 99, in \n samples = collate_fn([dataset[i] for i in batch_indices])\n File \"\", line 438, in __getitem__\n writer.add_images('support_set_images', support_set_images, idx)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/writer.py\", line 404, in add_images\n image(tag, img_tensor, dataformats=dataformats), global_step, walltime)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/summary.py\", line 221, in image\n tensor = convert_to_HWC(tensor, dataformats)\n File \"/usr/local/lib/python3.6/dist-packages/torch/utils/tensorboard/_utils.py\", line 95, in convert_to_HWC\n tensor shape: {}, input_format: {}\".format(tensor.shape, input_format)\nAssertionError: size of input tensor and input format are different. tensor shape: (5, 1, 1, 28, 28), input_format: NCHW\n" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "gMlXOtEcMl8B", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file