From 82a5d0968cf69cf4d8bb1e1d2a8ac409d9f7f48e Mon Sep 17 00:00:00 2001 From: albertoesmp Date: Thu, 29 May 2025 12:32:08 +0200 Subject: [PATCH] Corrections for gradient centralization example. --- examples/vision/gradient_centralization.py | 48 +- .../ipynb/gradient_centralization.ipynb | 431 ++++++++++++++++-- examples/vision/md/gradient_centralization.md | 58 +-- 3 files changed, 446 insertions(+), 91 deletions(-) diff --git a/examples/vision/gradient_centralization.py b/examples/vision/gradient_centralization.py index aefc5e3e9f..ec2fee2aed 100644 --- a/examples/vision/gradient_centralization.py +++ b/examples/vision/gradient_centralization.py @@ -2,10 +2,11 @@ Title: Gradient Centralization for Better Training Performance Author: [Rishit Dagli](https://github.com/Rishit-dagli) Date created: 06/18/21 -Last modified: 07/25/23 +Last modified: 05/29/25 Description: Implement Gradient Centralization to improve training performance of DNNs. Accelerator: GPU Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com) +Debugged by: [Alberto M. Esmorís](https://github.com/albertoesmp) """ """ @@ -122,27 +123,28 @@ def prepare(ds, shuffle=False, augment=False): In this section we will define a Convolutional neural network. """ -model = keras.Sequential( - [ - layers.Input(shape=input_shape), - layers.Conv2D(16, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Conv2D(32, (3, 3), activation="relu"), - layers.Dropout(0.5), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.Dropout(0.5), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Flatten(), - layers.Dropout(0.5), - layers.Dense(512, activation="relu"), - layers.Dense(1, activation="sigmoid"), - ] -) +def make_model(): + return keras.Sequential( + [ + layers.Input(shape=input_shape), + layers.Conv2D(16, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Conv2D(32, (3, 3), activation="relu"), + layers.Dropout(0.5), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.Dropout(0.5), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(512, activation="relu"), + layers.Dense(1, activation="sigmoid"), + ] + ) """ ## Implement Gradient Centralization @@ -216,6 +218,7 @@ def on_epoch_end(self, batch, logs={}): """ time_callback_no_gc = TimeHistory() +model = make_model() model.compile( loss="binary_crossentropy", optimizer=RMSprop(learning_rate=1e-4), @@ -241,6 +244,7 @@ def on_epoch_end(self, batch, logs={}): """ time_callback_gc = TimeHistory() +model = make_model() model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"]) model.summary() diff --git a/examples/vision/ipynb/gradient_centralization.ipynb b/examples/vision/ipynb/gradient_centralization.ipynb index e2d9bef68a..b2c8864e95 100644 --- a/examples/vision/ipynb/gradient_centralization.ipynb +++ b/examples/vision/ipynb/gradient_centralization.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [Rishit Dagli](https://github.com/Rishit-dagli)
\n", "**Date created:** 06/18/21
\n", - "**Last modified:** 07/25/23
\n", + "**Last modified:** 05/29/25
\n", "**Description:** Implement Gradient Centralization to improve training performance of DNNs." ] }, @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 12, "metadata": { "colab_type": "code" }, @@ -81,11 +81,21 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 13, "metadata": { "colab_type": "code" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: (300, 300, 3)\n", + "Training images: 1027\n", + "Test images: 256\n" + ] + } + ], "source": [ "num_classes = 2\n", "input_shape = (300, 300, 3)\n", @@ -118,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 14, "metadata": { "colab_type": "code" }, @@ -172,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 15, "metadata": { "colab_type": "code" }, @@ -195,33 +205,34 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 16, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ - "model = keras.Sequential(\n", - " [\n", - " layers.Input(shape=input_shape),\n", - " layers.Conv2D(16, (3, 3), activation=\"relu\"),\n", - " layers.MaxPooling2D(2, 2),\n", - " layers.Conv2D(32, (3, 3), activation=\"relu\"),\n", - " layers.Dropout(0.5),\n", - " layers.MaxPooling2D(2, 2),\n", - " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", - " layers.Dropout(0.5),\n", - " layers.MaxPooling2D(2, 2),\n", - " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", - " layers.MaxPooling2D(2, 2),\n", - " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", - " layers.MaxPooling2D(2, 2),\n", - " layers.Flatten(),\n", - " layers.Dropout(0.5),\n", - " layers.Dense(512, activation=\"relu\"),\n", - " layers.Dense(1, activation=\"sigmoid\"),\n", - " ]\n", - ")" + "def make_model():\n", + " return keras.Sequential(\n", + " [\n", + " layers.Input(shape=input_shape),\n", + " layers.Conv2D(16, (3, 3), activation=\"relu\"),\n", + " layers.MaxPooling2D(2, 2),\n", + " layers.Conv2D(32, (3, 3), activation=\"relu\"),\n", + " layers.Dropout(0.5),\n", + " layers.MaxPooling2D(2, 2),\n", + " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", + " layers.Dropout(0.5),\n", + " layers.MaxPooling2D(2, 2),\n", + " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", + " layers.MaxPooling2D(2, 2),\n", + " layers.Conv2D(64, (3, 3), activation=\"relu\"),\n", + " layers.MaxPooling2D(2, 2),\n", + " layers.Flatten(),\n", + " layers.Dropout(0.5),\n", + " layers.Dense(512, activation=\"relu\"),\n", + " layers.Dense(1, activation=\"sigmoid\"),\n", + " ]\n", + " )" ] }, { @@ -255,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 17, "metadata": { "colab_type": "code" }, @@ -297,7 +308,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 18, "metadata": { "colab_type": "code" }, @@ -329,13 +340,148 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 19, "metadata": { "colab_type": "code" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"sequential_2\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"sequential_2\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ conv2d_10 (Conv2D)              │ (None, 298, 298, 16)   │           448 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_10 (MaxPooling2D) │ (None, 149, 149, 16)   │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_11 (Conv2D)              │ (None, 147, 147, 32)   │         4,640 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_6 (Dropout)             │ (None, 147, 147, 32)   │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_11 (MaxPooling2D) │ (None, 73, 73, 32)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_12 (Conv2D)              │ (None, 71, 71, 64)     │        18,496 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_7 (Dropout)             │ (None, 71, 71, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_12 (MaxPooling2D) │ (None, 35, 35, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_13 (Conv2D)              │ (None, 33, 33, 64)     │        36,928 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_13 (MaxPooling2D) │ (None, 16, 16, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_14 (Conv2D)              │ (None, 14, 14, 64)     │        36,928 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_14 (MaxPooling2D) │ (None, 7, 7, 64)       │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ flatten_2 (Flatten)             │ (None, 3136)           │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_8 (Dropout)             │ (None, 3136)           │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_4 (Dense)                 │ (None, 512)            │     1,606,144 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_5 (Dense)                 │ (None, 1)              │           513 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ conv2d_10 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m298\u001b[0m, \u001b[38;5;34m298\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m448\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_10 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_11 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m4,640\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_6 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_11 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m73\u001b[0m, \u001b[38;5;34m73\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_12 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_7 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_12 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m, \u001b[38;5;34m35\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_13 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m33\u001b[0m, \u001b[38;5;34m33\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_13 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_14 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_14 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ flatten_2 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3136\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_8 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3136\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m1,606,144\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m513\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 1,704,097 (6.50 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,704,097\u001b[0m (6.50 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 1,704,097 (6.50 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,704,097\u001b[0m (6.50 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "time_callback_no_gc = TimeHistory()\n", + "model = make_model()\n", "model.compile(\n", " loss=\"binary_crossentropy\",\n", " optimizer=RMSprop(learning_rate=1e-4),\n", @@ -357,11 +503,38 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 20, "metadata": { "colab_type": "code" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 318ms/step - accuracy: 0.4847 - loss: 0.7547\n", + "Epoch 2/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 57ms/step - accuracy: 0.5324 - loss: 0.6859\n", + "Epoch 3/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 57ms/step - accuracy: 0.6199 - loss: 0.6608\n", + "Epoch 4/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 58ms/step - accuracy: 0.6368 - loss: 0.6489\n", + "Epoch 5/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.6941 - loss: 0.6193\n", + "Epoch 6/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 54ms/step - accuracy: 0.7075 - loss: 0.6009\n", + "Epoch 7/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 49ms/step - accuracy: 0.6734 - loss: 0.5738\n", + "Epoch 8/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 48ms/step - accuracy: 0.7328 - loss: 0.5422\n", + "Epoch 9/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 52ms/step - accuracy: 0.7720 - loss: 0.5008\n", + "Epoch 10/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 55ms/step - accuracy: 0.8163 - loss: 0.4797\n" + ] + } + ], "source": [ "history_no_gc = model.fit(\n", " train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]\n", @@ -382,13 +555,174 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 21, "metadata": { "colab_type": "code" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
Model: \"sequential_3\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"sequential_3\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ conv2d_15 (Conv2D)              │ (None, 298, 298, 16)   │           448 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_15 (MaxPooling2D) │ (None, 149, 149, 16)   │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_16 (Conv2D)              │ (None, 147, 147, 32)   │         4,640 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_9 (Dropout)             │ (None, 147, 147, 32)   │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_16 (MaxPooling2D) │ (None, 73, 73, 32)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_17 (Conv2D)              │ (None, 71, 71, 64)     │        18,496 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_10 (Dropout)            │ (None, 71, 71, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_17 (MaxPooling2D) │ (None, 35, 35, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_18 (Conv2D)              │ (None, 33, 33, 64)     │        36,928 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_18 (MaxPooling2D) │ (None, 16, 16, 64)     │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ conv2d_19 (Conv2D)              │ (None, 14, 14, 64)     │        36,928 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ max_pooling2d_19 (MaxPooling2D) │ (None, 7, 7, 64)       │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ flatten_3 (Flatten)             │ (None, 3136)           │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dropout_11 (Dropout)            │ (None, 3136)           │             0 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_6 (Dense)                 │ (None, 512)            │     1,606,144 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_7 (Dense)                 │ (None, 1)              │           513 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ conv2d_15 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m298\u001b[0m, \u001b[38;5;34m298\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m448\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_15 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_16 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m4,640\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_9 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m147\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_16 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m73\u001b[0m, \u001b[38;5;34m73\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_17 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_10 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m71\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_17 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m35\u001b[0m, \u001b[38;5;34m35\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_18 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m33\u001b[0m, \u001b[38;5;34m33\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_18 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m16\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ conv2d_19 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m36,928\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ max_pooling2d_19 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ flatten_3 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3136\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dropout_11 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3136\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_6 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m1,606,144\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_7 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m513\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 1,704,097 (6.50 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,704,097\u001b[0m (6.50 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 1,704,097 (6.50 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,704,097\u001b[0m (6.50 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 325ms/step - accuracy: 0.5186 - loss: 0.7297\n", + "Epoch 2/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 58ms/step - accuracy: 0.5596 - loss: 0.7033\n", + "Epoch 3/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 56ms/step - accuracy: 0.5418 - loss: 0.7137\n", + "Epoch 4/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 58ms/step - accuracy: 0.5604 - loss: 0.6767\n", + "Epoch 5/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 54ms/step - accuracy: 0.6309 - loss: 0.6617\n", + "Epoch 6/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 63ms/step - accuracy: 0.6204 - loss: 0.6525\n", + "Epoch 7/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 60ms/step - accuracy: 0.7200 - loss: 0.6029\n", + "Epoch 8/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 57ms/step - accuracy: 0.6939 - loss: 0.6000\n", + "Epoch 9/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 60ms/step - accuracy: 0.7594 - loss: 0.5419\n", + "Epoch 10/10\n", + "\u001b[1m9/9\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 59ms/step - accuracy: 0.7197 - loss: 0.5385\n" + ] + } + ], "source": [ "time_callback_gc = TimeHistory()\n", + "model = make_model()\n", "model.compile(loss=\"binary_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"])\n", "\n", "model.summary()\n", @@ -407,11 +741,26 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 22, "metadata": { "colab_type": "code" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Not using Gradient Centralization\n", + "Loss: 0.48009756207466125\n", + "Accuracy: 0.8033106327056885\n", + "Training Time: 26.10426688194275\n", + "Using Gradient Centralization\n", + "Loss: 0.525072455406189\n", + "Accuracy: 0.7526776790618896\n", + "Training Time: 26.47221279144287\n" + ] + } + ], "source": [ "print(\"Not using Gradient Centralization\")\n", "print(f\"Loss: {history_no_gc.history['loss'][-1]}\")\n", @@ -451,7 +800,7 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -465,9 +814,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.12.3" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/examples/vision/md/gradient_centralization.md b/examples/vision/md/gradient_centralization.md index 3e72396e31..92a39707a7 100644 --- a/examples/vision/md/gradient_centralization.md +++ b/examples/vision/md/gradient_centralization.md @@ -140,27 +140,28 @@ In this section we will define a Convolutional neural network. ```python -model = keras.Sequential( - [ - layers.Input(shape=input_shape), - layers.Conv2D(16, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Conv2D(32, (3, 3), activation="relu"), - layers.Dropout(0.5), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.Dropout(0.5), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Conv2D(64, (3, 3), activation="relu"), - layers.MaxPooling2D(2, 2), - layers.Flatten(), - layers.Dropout(0.5), - layers.Dense(512, activation="relu"), - layers.Dense(1, activation="sigmoid"), - ] -) +def make_model(): + return keras.Sequential( + [ + layers.Input(shape=input_shape), + layers.Conv2D(16, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Conv2D(32, (3, 3), activation="relu"), + layers.Dropout(0.5), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.Dropout(0.5), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Conv2D(64, (3, 3), activation="relu"), + layers.MaxPooling2D(2, 2), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(512, activation="relu"), + layers.Dense(1, activation="sigmoid"), + ] + ) ``` --- @@ -240,6 +241,7 @@ compare to the training performance of the model trained with Gradient Centraliz ```python time_callback_no_gc = TimeHistory() +model = make_model() model.compile( loss="binary_crossentropy", optimizer=RMSprop(learning_rate=1e-4), @@ -357,6 +359,7 @@ notice our optimizer is the one using Gradient Centralization this time. ```python time_callback_gc = TimeHistory() +model = make_model() model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"]) model.summary() @@ -472,14 +475,13 @@ print(f"Training Time: {sum(time_callback_gc.times)}")
``` Not using Gradient Centralization -Loss: 0.5345584154129028 -Accuracy: 0.7604166865348816 -Training Time: 112.48799777030945 +Loss: 0.5709779858589172 +Accuracy: 0.7380720376968384 +Training Time: 30.397282361984253 Using Gradient Centralization -Loss: 0.4014038145542145 -Accuracy: 0.8153935074806213 -Training Time: 98.31573963165283 - +Loss: 0.5860965847969055 +Accuracy: 0.7039921879768372 +Training Time: 26.152544498443604 ```
Readers are encouraged to try out Gradient Centralization on different datasets from