diff --git a/_example_download_weights (1).ipynb b/_example_download_weights (1).ipynb new file mode 100644 index 0000000..dc0973e --- /dev/null +++ b/_example_download_weights (1).ipynb @@ -0,0 +1,496 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "is_executing": false + }, + "tags": [], + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1y5Ow9S1ZVgf", + "outputId": "9455ec1e-e7ee-4261-90bf-e80120dcc75f" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/bin/bash: line 1: svn: command not found\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.13.1)\n", + "Collecting resampy\n", + " Downloading resampy-0.4.3-py3-none-any.whl.metadata (3.0 kB)\n", + "Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.17.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (1.16.0)\n", + "Requirement already satisfied: soundfile in /usr/local/lib/python3.10/dist-packages (0.12.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from resampy) (1.26.4)\n", + "Requirement already satisfied: numba>=0.53 in /usr/local/lib/python3.10/dist-packages (from resampy) (0.60.0)\n", + "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25)\n", + "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.6.0)\n", + "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n", + "Requirement already satisfied: h5py>=3.10.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.12.1)\n", + "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1)\n", + "Requirement already satisfied: ml-dtypes<0.5.0,>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.4.1)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.4.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.1)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.32.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (75.1.0)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.5.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.2)\n", + "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n", + "Requirement already satisfied: 
grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.64.1)\n", + "Requirement already satisfied: tensorboard<2.18,>=2.17 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.17.0)\n", + "Requirement already satisfied: keras>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.4.1)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.1)\n", + "Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile) (1.17.1)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.44.0)\n", + "Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile) (2.22)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (13.9.3)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.0.8)\n", + "Requirement already satisfied: optree in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.13.0)\n", + "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.53->resampy) (0.43.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2024.8.30)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.7)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.0.6)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.18,>=2.17->tensorflow) (3.0.2)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (2.18.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->tensorflow) (0.1.2)\n", + "Downloading resampy-0.4.3-py3-none-any.whl (3.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: resampy\n", + "Successfully installed resampy-0.4.3\n", + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 277M 100 
277M 0 0 63.8M 0 0:00:04 0:00:04 --:--:-- 66.9M\n", + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 73020 100 73020 0 0 400k 0 --:--:-- --:--:-- --:--:-- 402k\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "This notebook demonstrates how to replicate converting tensorflow\n", + "weights from tensorflow's vggish to torchvggish\n", + "\"\"\"\n", + "\n", + "# Download the audioset directory using subversion\n", + "# !apt-get -qq install subversion # uncomment if on linux\n", + "!svn checkout https://github.com/tensorflow/models/trunk/research/audioset\n", + "\n", + "# Download audioset requirements\n", + "!pip install numpy scipy\n", + "!pip install resampy tensorflow six soundfile\n", + "\n", + "# grab the VGGish model checkpoints & PCA params\n", + "!curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt\n", + "!curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz" + ] + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/tensorflow/models/" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "NnanY1uvaUeW", + "outputId": "cbda6ff3-d345-43a4-a765-339f3fa5ae20" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'models'...\n", + "remote: Enumerating objects: 98393, done.\u001b[K\n", + "remote: Counting objects: 100% (899/899), done.\u001b[K\n", + "remote: Compressing objects: 100% (489/489), done.\u001b[K\n", + "remote: Total 98393 (delta 482), reused 736 (delta 386), pack-reused 97494 (from 1)\u001b[K\n", + "Receiving objects: 100% (98393/98393), 621.98 MiB | 17.06 MiB/s, done.\n", + "Resolving deltas: 100% (71466/71466), done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "cd models/research/audioset/vggish/" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Sa7k30NMbKzV", + "outputId": "1356d7d8-1d70-4417-c35a-75a872248a35" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/models/research/audioset/vggish\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "pwd" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "hgm9MY9lbRn5", + "outputId": "8c9a4855-0894-45bf-9fd0-2c6a4477ef4b" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'/content/models/research/audioset/vggish'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "pycharm": { + "is_executing": false + }, + "tags": [], + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-HKuu0FSZVgh", + "outputId": "e9f1ed27-7f46-438d-849e-185ab7c30336" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Testing your install of VGGish\n", + "\n", + "Resampling via resampy works!\n", + "Log Mel Spectrogram example: [[-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " ...\n", + " [-4.48313252 -4.27083405 -4.17064267 ... 
-4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]]\n", + "VGGish embedding: [-2.72986382e-01 -1.80314153e-01 5.19921184e-02 -1.43571526e-01\n", + " -1.04673728e-01 -4.96598154e-01 -1.75267965e-01 4.23147976e-01\n", + " -8.22126150e-01 -2.16801405e-01 -1.17509276e-01 -6.70077026e-01\n", + " 1.43174574e-01 -1.44183934e-01 8.73491913e-03 -8.71972442e-02\n", + " -1.84393525e-01 5.96655607e-01 -3.43809605e-01 -5.79104424e-02\n", + " -1.65071294e-01 4.22911644e-02 -2.55293399e-01 -2.36356765e-01\n", + " 1.80295616e-01 3.02612185e-01 1.08356833e-01 -4.48398024e-01\n", + " 1.22757629e-01 -2.99955189e-01 -5.55934191e-01 5.05966544e-01\n", + " 2.05210358e-01 8.87591839e-01 9.03702497e-01 -2.10566416e-01\n", + " -3.27462405e-02 1.38691410e-01 -2.27416530e-01 1.14804000e-01\n", + " 5.95410109e-01 -4.76971269e-01 2.28232622e-01 1.54627025e-01\n", + " 1.64934218e-01 7.19252825e-01 1.24101830e+00 5.61996222e-01\n", + " 2.73531973e-01 3.09788287e-02 2.10977703e-01 -6.09551668e-01\n", + " -3.15282375e-01 1.76392645e-01 -8.96190405e-02 -4.26822364e-01\n", + " 3.12993884e-01 -1.56592295e-01 3.31673503e-01 1.29436389e-01\n", + " 1.66024208e-01 3.01903039e-02 -1.54465199e-01 -4.29332554e-01\n", + " -2.68703818e-01 -1.58071086e-01 4.00485486e-01 -2.55945086e-01\n", + " -2.66429391e-02 8.16181302e-03 2.98492879e-01 3.48756194e-01\n", + " -1.07143626e-01 8.88779089e-02 1.26810491e-01 -3.34817201e-01\n", + " -2.55428016e-01 5.07779241e-01 3.97584617e-01 1.78759634e-01\n", + " -8.04521963e-02 4.84320521e-02 -2.01262981e-01 -2.97957748e-01\n", + " 3.66831303e-01 4.56224501e-01 5.37960529e-01 -2.00488269e-02\n", + " -6.24543577e-02 4.15623039e-01 -1.88741475e-01 -5.36903143e-01\n", + " -1.78362012e-01 3.81366849e-01 3.96645039e-01 3.21936429e-01\n", + " -4.26683240e-02 -1.41018063e-01 -4.53833699e-01 -1.07017279e-01\n", + " -2.21892655e-01 3.51183444e-01 -2.58386552e-01 3.31110060e-01\n", + " -7.28939176e-01 -2.55487382e-01 3.56361002e-01 -3.16188633e-01\n", + " 3.12793672e-01 1.23501822e-01 -1.83649734e-02 -3.99395853e-01\n", + " -5.13507247e-01 -2.74227202e-01 -2.68650651e-01 2.24091530e-01\n", + " 1.09625012e-01 1.30929738e-01 -1.25994891e-01 -1.92615181e-01\n", + " 1.83567405e-04 2.04150438e-01 -1.03096753e-01 2.93378532e-02\n", + " -3.38305712e-01 -2.25749940e-01 -2.46723339e-01 -1.20763183e-01]\n", + "embedding mean/stddev 0.00065699156 0.34301957\n", + "Postprocessed VGGish embedding: [160 53 124 132 154 120 119 105 155 173 129 69 149 93 59 0 52 97\n", + " 157 144 153 194 251 108 48 174 131 190 195 79 59 60 169 93 167 247\n", + " 28 75 255 56 134 169 234 137 232 100 19 80 162 255 0 255 101 0\n", + " 222 252 79 211 64 88 248 0 0 255 246 62 81 255 0 159 22 168\n", + " 70 255 99 135 204 192 255 150 0 0 255 255 67 235 55 255 69 0\n", + " 0 17 241 44 255 224 0 255 40 0 255 0 211 252 62 0 28 218\n", + " 112 0 255 0 81 67 153 0 255 0 129 229 53 255 55 101 0 255\n", + " 0 255]\n", + "postproc embedding mean/stddev 126.359375 89.33878063086252\n", + "\n", + "Looks Good To Me!\n", + "\n" + ] + } + ], + "source": [ + "# Test install\n", + "# !mv audioset/* .\n", + "from vggish_smoke_test import *" + ] + }, + { + "cell_type": "code", + "source": [ + "# import tensorflow as tf\n", + "# import vggish_slim\n", + "\n", + "# vggish_dict = {}\n", + "# # load the model and get info\n", + "# with tf.Graph().as_default(), tf.Session() as sess:\n", + 
"# vggish_slim.define_vggish_slim(training=True)\n", + "# vggish_slim.load_vggish_slim_checkpoint(sess,\"vggish_model.ckpt\")\n", + "\n", + "# tvars = tf.trainable_variables()\n", + "# tvars_vals = sess.run(tvars)\n", + "\n", + "# for var, val in zip(tvars, tvars_vals):\n", + "# print(\"%s\" % (var.name))\n", + "# print(\"\\t\" + str(var.shape))\n", + "# vggish_dict[var.name] = val\n", + "# print(\"values written to vggish_dict\")" + ], + "metadata": { + "id": "3e1KbwzCeV1s" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "pycharm": { + "is_executing": false + }, + "tags": [], + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "InMwNzlSZVgj", + "outputId": "1c1080dd-6f30-4ce8-9d4d-932b7b5c5971" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "vggish/conv1/weights:0\n", + "\t(3, 3, 1, 64)\n", + "vggish/conv1/biases:0\n", + "\t(64,)\n", + "vggish/conv2/weights:0\n", + "\t(3, 3, 64, 128)\n", + "vggish/conv2/biases:0\n", + "\t(128,)\n", + "vggish/conv3/conv3_1/weights:0\n", + "\t(3, 3, 128, 256)\n", + "vggish/conv3/conv3_1/biases:0\n", + "\t(256,)\n", + "vggish/conv3/conv3_2/weights:0\n", + "\t(3, 3, 256, 256)\n", + "vggish/conv3/conv3_2/biases:0\n", + "\t(256,)\n", + "vggish/conv4/conv4_1/weights:0\n", + "\t(3, 3, 256, 512)\n", + "vggish/conv4/conv4_1/biases:0\n", + "\t(512,)\n", + "vggish/conv4/conv4_2/weights:0\n", + "\t(3, 3, 512, 512)\n", + "vggish/conv4/conv4_2/biases:0\n", + "\t(512,)\n", + "vggish/fc1/fc1_1/weights:0\n", + "\t(12288, 4096)\n", + "vggish/fc1/fc1_1/biases:0\n", + "\t(4096,)\n", + "vggish/fc1/fc1_2/weights:0\n", + "\t(4096, 4096)\n", + "vggish/fc1/fc1_2/biases:0\n", + "\t(4096,)\n", + "vggish/fc2/weights:0\n", + "\t(4096, 128)\n", + "vggish/fc2/biases:0\n", + "\t(128,)\n", + "values written to vggish_dict\n" + ] + } + ], + "source": [ + "import tensorflow.compat.v1 as tf\n", + "tf.disable_v2_behavior() # This enables 1.x-style sessions and graph-based execution\n", + "import vggish_slim\n", + "\n", + "vggish_dict = {}\n", + "# load the model and get info\n", + "with tf.Graph().as_default(), tf.Session() as sess:\n", + " vggish_slim.define_vggish_slim(training=True)\n", + " vggish_slim.load_vggish_slim_checkpoint(sess, \"/content/vggish_model.ckpt\")\n", + "\n", + " tvars = tf.trainable_variables()\n", + " tvars_vals = sess.run(tvars)\n", + "\n", + " for var, val in zip(tvars, tvars_vals):\n", + " print(\"%s\" % (var.name))\n", + " print(\"\\t\" + str(var.shape))\n", + " vggish_dict[var.name] = val\n", + " print(\"values written to vggish_dict\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "pycharm": { + "is_executing": false, + "name": "#%%\n" + }, + "tags": [], + "id": "nP0xb46AZVgl" + }, + "outputs": [], + "source": [ + "# Define torch model for vggish\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "\n", + "# From vggish_slim:\n", + "# The VGG stack of alternating convolutions and max-pools.\n", + "# net = slim.conv2d(net, 64, scope='conv1')\n", + "# net = slim.max_pool2d(net, scope='pool1')\n", + "# net = slim.conv2d(net, 128, scope='conv2')\n", + "# net = slim.max_pool2d(net, scope='pool2')\n", + "# net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')\n", + "# net = slim.max_pool2d(net, scope='pool3')\n", + "# net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')\n", + "# net = slim.max_pool2d(net, scope='pool4')\n", + "# # Flatten before entering 
fully-connected layers\n", + "# net = slim.flatten(net)\n", + "# net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')\n", + "# # The embedding layer.\n", + "# net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')\n", + "\n", + "vggish_list = list(vggish_dict.values())\n", + "def param_generator():\n", + " param = vggish_list.pop(0)\n", + " transposed = np.transpose(param)\n", + " to_torch = torch.from_numpy(transposed)\n", + " result = torch.nn.Parameter(to_torch)\n", + " yield result\n", + "\n", + "class VGGish(nn.Module):\n", + " def __init__(self):\n", + " super(VGGish, self).__init__()\n", + " self.features = nn.Sequential(\n", + " nn.Conv2d(1, 64, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(64, 128, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(128, 256, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(256, 256, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(256, 512, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(512, 512, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2))\n", + " self.embeddings = nn.Sequential(\n", + " nn.Linear(512*24, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(4096, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(4096, 128),\n", + " nn.ReLU(inplace=True))\n", + "\n", + " # extract weights from `vggish_list`\n", + " for seq in (self.features, self.embeddings):\n", + " for layer in seq:\n", + " if type(layer).__name__ != \"MaxPool2d\" and type(layer).__name__ != \"ReLU\":\n", + " layer.weight = next(param_generator())\n", + " layer.bias = next(param_generator())\n", + "\n", + " def forward(self, x):\n", + " x = self.features(x)\n", + " x = x.view(x.size(0),-1)\n", + " x = self.embeddings(x)\n", + " return x\n", + "\n", + "net = VGGish()\n", + "net.eval()\n", + "\n", + "# Save weights to disk\n", + "torch.save(net.state_dict(), \"./vggish.pth\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + }, + "colab": { + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/_example_download_weights_v2.ipynb b/_example_download_weights_v2.ipynb new file mode 100644 index 0000000..8b1ee85 --- /dev/null +++ b/_example_download_weights_v2.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/tensorflow/models" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2VRmHhOURs1Q", + "outputId": "85d89491-ee5c-45f4-b5c8-6a28050d4c5b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'models'...\n", + "remote: Enumerating objects: 98107, done.\u001b[K\n", + "remote: Counting objects: 100% (620/620), done.\u001b[K\n", + "remote: Compressing objects: 100% (368/368), done.\u001b[K\n", + "remote: Total 98107 (delta 292), reused 524 (delta 234), pack-reused 97487 (from 1)\u001b[K\n", + "Receiving objects: 100% (98107/98107), 621.63 MiB | 16.85 MiB/s, done.\n", + "Resolving deltas: 100% (71270/71270), done.\n", + "Updating files: 100% (3871/3871), done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "%cd 
models/research/audioset/vggish/" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CZQpPKs-R2YA", + "outputId": "29331aa8-c12b-4e8c-ca12-c750d812e777" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/models/research/audioset/vggish\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": false + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YwhsYT8LQ8Xs", + "outputId": "9dac93d8-22e6-4654-abee-56557171f193" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/bin/bash: line 1: svn: command not found\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)\n", + "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (1.13.1)\n", + "Collecting resampy\n", + " Downloading resampy-0.4.3-py3-none-any.whl.metadata (3.0 kB)\n", + "Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.17.0)\n", + "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (1.16.0)\n", + "Requirement already satisfied: soundfile in /usr/local/lib/python3.10/dist-packages (0.12.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from resampy) (1.26.4)\n", + "Requirement already satisfied: numba>=0.53 in /usr/local/lib/python3.10/dist-packages (from resampy) (0.60.0)\n", + "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25)\n", + "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.6.0)\n", + "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n", + "Requirement already satisfied: h5py>=3.10.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.11.0)\n", + "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1)\n", + "Requirement already satisfied: ml-dtypes<0.5.0,>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.4.0)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.3.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.1)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.32.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (71.0.4)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.4.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.2)\n", + "Requirement already satisfied: 
wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n", + "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.64.1)\n", + "Requirement already satisfied: tensorboard<2.18,>=2.17 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.17.0)\n", + "Requirement already satisfied: keras>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.4.1)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.1)\n", + "Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile) (1.17.1)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.44.0)\n", + "Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile) (2.22)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (13.8.1)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.0.8)\n", + "Requirement already satisfied: optree in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.12.1)\n", + "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.53->resampy) (0.43.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.8)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2024.8.30)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.7)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.0.4)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.18,>=2.17->tensorflow) (2.1.5)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (2.16.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->tensorflow) (0.1.2)\n", + "Downloading resampy-0.4.3-py3-none-any.whl (3.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m26.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: resampy\n", + "Successfully installed resampy-0.4.3\n", + " % Total % 
Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 277M 100 277M 0 0 88.5M 0 0:00:03 0:00:03 --:--:-- 88.5M\n", + " % Total % Received % Xferd Average Speed Time Time Time Current\n", + " Dload Upload Total Spent Left Speed\n", + "100 73020 100 73020 0 0 266k 0 --:--:-- --:--:-- --:--:-- 267k\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "This notebook demonstrates how to replicate converting tensorflow\n", + "weights from tensorflow's vggish to torchvggish\n", + "\"\"\"\n", + "\n", + "# Download the audioset directory using subversion\n", + "# !apt-get -qq install subversion # uncomment if on linux\n", + "!svn checkout https://github.com/tensorflow/models/trunk/research/audioset\n", + "\n", + "# Download audioset requirements\n", + "!pip install numpy scipy\n", + "!pip install resampy tensorflow six soundfile\n", + "\n", + "# grab the VGGish model checkpoints & PCA params\n", + "!curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt\n", + "!curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": false + }, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zHs_phLMQ8Xu", + "outputId": "dc6c9feb-ab23-4fce-80b9-c90c131d54b8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Testing your install of VGGish\n", + "\n", + "Resampling via resampy works!\n", + "Log Mel Spectrogram example: [[-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " ...\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]\n", + " [-4.48313252 -4.27083405 -4.17064267 ... -4.60069383 -4.60098887\n", + " -4.60116305]]\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.\n", + " warnings.warn('`layer.apply` is deprecated and '\n", + "/usr/local/lib/python3.10/dist-packages/tensorflow/python/keras/legacy_tf_layers/core.py:318: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. 
Please use `tf.keras.layers.Flatten` instead.\n", + " warnings.warn('`tf.layers.flatten` is deprecated and '\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "VGGish embedding: [-2.72986382e-01 -1.80314153e-01 5.19921184e-02 -1.43571526e-01\n", + " -1.04673728e-01 -4.96598154e-01 -1.75267965e-01 4.23147976e-01\n", + " -8.22126150e-01 -2.16801405e-01 -1.17509276e-01 -6.70077026e-01\n", + " 1.43174574e-01 -1.44183934e-01 8.73491913e-03 -8.71972442e-02\n", + " -1.84393525e-01 5.96655607e-01 -3.43809605e-01 -5.79104424e-02\n", + " -1.65071294e-01 4.22911644e-02 -2.55293399e-01 -2.36356765e-01\n", + " 1.80295616e-01 3.02612185e-01 1.08356833e-01 -4.48398024e-01\n", + " 1.22757629e-01 -2.99955189e-01 -5.55934191e-01 5.05966544e-01\n", + " 2.05210358e-01 8.87591839e-01 9.03702497e-01 -2.10566416e-01\n", + " -3.27462405e-02 1.38691410e-01 -2.27416530e-01 1.14804000e-01\n", + " 5.95410109e-01 -4.76971269e-01 2.28232622e-01 1.54627025e-01\n", + " 1.64934218e-01 7.19252825e-01 1.24101830e+00 5.61996222e-01\n", + " 2.73531973e-01 3.09788287e-02 2.10977703e-01 -6.09551668e-01\n", + " -3.15282375e-01 1.76392645e-01 -8.96190405e-02 -4.26822364e-01\n", + " 3.12993884e-01 -1.56592295e-01 3.31673503e-01 1.29436389e-01\n", + " 1.66024208e-01 3.01903039e-02 -1.54465199e-01 -4.29332554e-01\n", + " -2.68703818e-01 -1.58071086e-01 4.00485486e-01 -2.55945086e-01\n", + " -2.66429391e-02 8.16181302e-03 2.98492879e-01 3.48756194e-01\n", + " -1.07143626e-01 8.88779089e-02 1.26810491e-01 -3.34817201e-01\n", + " -2.55428016e-01 5.07779241e-01 3.97584617e-01 1.78759634e-01\n", + " -8.04521963e-02 4.84320521e-02 -2.01262981e-01 -2.97957748e-01\n", + " 3.66831303e-01 4.56224501e-01 5.37960529e-01 -2.00488269e-02\n", + " -6.24543577e-02 4.15623039e-01 -1.88741475e-01 -5.36903143e-01\n", + " -1.78362012e-01 3.81366849e-01 3.96645039e-01 3.21936429e-01\n", + " -4.26683240e-02 -1.41018063e-01 -4.53833699e-01 -1.07017279e-01\n", + " -2.21892655e-01 3.51183444e-01 -2.58386552e-01 3.31110060e-01\n", + " -7.28939176e-01 -2.55487382e-01 3.56361002e-01 -3.16188633e-01\n", + " 3.12793672e-01 1.23501822e-01 -1.83649734e-02 -3.99395853e-01\n", + " -5.13507247e-01 -2.74227202e-01 -2.68650651e-01 2.24091530e-01\n", + " 1.09625012e-01 1.30929738e-01 -1.25994891e-01 -1.92615181e-01\n", + " 1.83567405e-04 2.04150438e-01 -1.03096753e-01 2.93378532e-02\n", + " -3.38305712e-01 -2.25749940e-01 -2.46723339e-01 -1.20763183e-01]\n", + "embedding mean/stddev 0.00065699156 0.34301957\n", + "Postprocessed VGGish embedding: [160 53 124 132 154 120 119 105 155 173 129 69 149 93 59 0 52 97\n", + " 157 144 153 194 251 108 48 174 131 190 195 79 59 60 169 93 167 247\n", + " 28 75 255 56 134 169 234 137 232 100 19 80 162 255 0 255 101 0\n", + " 222 252 79 211 64 88 248 0 0 255 246 62 81 255 0 159 22 168\n", + " 70 255 99 135 204 192 255 150 0 0 255 255 67 235 55 255 69 0\n", + " 0 17 241 44 255 224 0 255 40 0 255 0 211 252 62 0 28 218\n", + " 112 0 255 0 81 67 153 0 255 0 129 229 53 255 55 101 0 255\n", + " 0 255]\n", + "postproc embedding mean/stddev 126.359375 89.33878063086252\n", + "\n", + "Looks Good To Me!\n", + "\n" + ] + } + ], + "source": [ + "# Test install\n", + "# !mv audioset/* .\n", + "from vggish_smoke_test import *" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": false + }, + "id": "ekfk36EoQ8Xu", + "outputId": "53372a6d-fd4e-45e0-f181-9afd33750006", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": 
"stream", + "name": "stdout", + "text": [ + "vggish/conv1/weights:0\n", + "\t(3, 3, 1, 64)\n", + "vggish/conv1/biases:0\n", + "\t(64,)\n", + "vggish/conv2/weights:0\n", + "\t(3, 3, 64, 128)\n", + "vggish/conv2/biases:0\n", + "\t(128,)\n", + "vggish/conv3/conv3_1/weights:0\n", + "\t(3, 3, 128, 256)\n", + "vggish/conv3/conv3_1/biases:0\n", + "\t(256,)\n", + "vggish/conv3/conv3_2/weights:0\n", + "\t(3, 3, 256, 256)\n", + "vggish/conv3/conv3_2/biases:0\n", + "\t(256,)\n", + "vggish/conv4/conv4_1/weights:0\n", + "\t(3, 3, 256, 512)\n", + "vggish/conv4/conv4_1/biases:0\n", + "\t(512,)\n", + "vggish/conv4/conv4_2/weights:0\n", + "\t(3, 3, 512, 512)\n", + "vggish/conv4/conv4_2/biases:0\n", + "\t(512,)\n", + "vggish/fc1/fc1_1/weights:0\n", + "\t(12288, 4096)\n", + "vggish/fc1/fc1_1/biases:0\n", + "\t(4096,)\n", + "vggish/fc1/fc1_2/weights:0\n", + "\t(4096, 4096)\n", + "vggish/fc1/fc1_2/biases:0\n", + "\t(4096,)\n", + "vggish/fc2/weights:0\n", + "\t(4096, 128)\n", + "vggish/fc2/biases:0\n", + "\t(128,)\n", + "values written to vggish_dict\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "import vggish_slim\n", + "\n", + "vggish_dict = {}\n", + "# load the model and get info\n", + "with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n", + " vggish_slim.define_vggish_slim(training=True)\n", + " vggish_slim.load_vggish_slim_checkpoint(sess,\"vggish_model.ckpt\")\n", + "\n", + " tvars = tf.compat.v1.trainable_variables()\n", + " tvars_vals = sess.run(tvars)\n", + "\n", + " for var, val in zip(tvars, tvars_vals):\n", + " print(\"%s\" % (var.name))\n", + " print(\"\\t\" + str(var.shape))\n", + " vggish_dict[var.name] = val\n", + " print(\"values written to vggish_dict\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "pycharm": { + "is_executing": false, + "name": "#%%\n" + }, + "id": "4C6cVDhfQ8Xv" + }, + "outputs": [], + "source": [ + "# Define torch model for vggish\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "\n", + "# From vggish_slim:\n", + "# The VGG stack of alternating convolutions and max-pools.\n", + "# net = slim.conv2d(net, 64, scope='conv1')\n", + "# net = slim.max_pool2d(net, scope='pool1')\n", + "# net = slim.conv2d(net, 128, scope='conv2')\n", + "# net = slim.max_pool2d(net, scope='pool2')\n", + "# net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')\n", + "# net = slim.max_pool2d(net, scope='pool3')\n", + "# net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')\n", + "# net = slim.max_pool2d(net, scope='pool4')\n", + "# # Flatten before entering fully-connected layers\n", + "# net = slim.flatten(net)\n", + "# net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')\n", + "# # The embedding layer.\n", + "# net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')\n", + "\n", + "vggish_list = list(vggish_dict.values())\n", + "def param_generator():\n", + " param = vggish_list.pop(0)\n", + " transposed = np.transpose(param)\n", + " to_torch = torch.from_numpy(transposed)\n", + " result = torch.nn.Parameter(to_torch)\n", + " yield result\n", + "\n", + "class VGGish(nn.Module):\n", + " def __init__(self):\n", + " super(VGGish, self).__init__()\n", + " self.features = nn.Sequential(\n", + " nn.Conv2d(1, 64, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(64, 128, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(128, 256, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(256, 
256, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2),\n", + " nn.Conv2d(256, 512, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(512, 512, 3, 1, 1),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool2d(2, 2))\n", + " self.embeddings = nn.Sequential(\n", + " nn.Linear(512*24, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(4096, 4096),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(4096, 128),\n", + " nn.ReLU(inplace=True))\n", + "\n", + " # extract weights from `vggish_list`\n", + " for seq in (self.features, self.embeddings):\n", + " for layer in seq:\n", + " if type(layer).__name__ != \"MaxPool2d\" and type(layer).__name__ != \"ReLU\":\n", + " layer.weight = next(param_generator())\n", + " layer.bias = next(param_generator())\n", + "\n", + " def forward(self, x):\n", + " x = self.features(x)\n", + " x = x.view(x.size(0),-1)\n", + " x = self.embeddings(x)\n", + " return x\n", + "\n", + "net = VGGish()\n", + "net.eval()\n", + "\n", + "# Save weights to disk\n", + "torch.save(net.state_dict(), \"./vggish.pth\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "source": [], + "metadata": { + "collapsed": false + } + } + }, + "colab": { + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file
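
A quick, optional sanity check on the exported weights (not part of either notebook above): once the final cell has written ./vggish.pth, the state_dict can be reloaded and its tensor shapes compared against the layouts PyTorch expects. This is only a sketch under stated assumptions: it assumes the file is in the current working directory and that inputs are the standard 96x64 log-mel patches VGGish uses, which is where the 512*24 input size of the first linear layer comes from (512 channels x 6 x 4 after four 2x2 max-pools).

# Hypothetical sanity check for the converted weights; assumes ./vggish.pth
# was produced by the last cell of either notebook above.
import torch

state = torch.load("./vggish.pth", map_location="cpu")

# List every converted tensor and its shape.
for name, tensor in state.items():
    print(f"{name}: {tuple(tensor.shape)}")

# The first conv should be (out_channels, in_channels, kH, kW) = (64, 1, 3, 3),
# and the first linear should be (4096, 12288), i.e. 512 channels * 6 * 4
# after four 2x2 max-pools over an assumed 96x64 log-mel input patch.
assert state["features.0.weight"].shape == (64, 1, 3, 3)
assert state["embeddings.0.weight"].shape == (4096, 512 * 24)

Note that re-instantiating the notebook's VGGish class outside the original session would fail, since its constructor pops weights from vggish_list; loading this state_dict into a module with the same features/embeddings layout is the intended way to reuse the exported weights.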