diff --git a/notebooks/regression/hetero_data_comparison.ipynb b/notebooks/regression/hetero_data_comparison.ipynb index f036424..9e66eaa 100644 --- a/notebooks/regression/hetero_data_comparison.ipynb +++ b/notebooks/regression/hetero_data_comparison.ipynb @@ -12,29 +12,21 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 18, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "from models.gaussian_mlp import gmlp\n", "from models.mlp import mlp\n", "from utilities.fits import fit\n", "from utilities.gmm import gmm_mean_var\n", "from utilities.predict import predict\n", - "from utilities import plot,errors\n" + "from utilities import plot,errors" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -70,8 +62,7 @@ "source": [ "os.environ['LATEXIFY']=str(1)\n", "os.environ['FIG_DIR']='figures/hetero'\n", - "latexify(width_scale_factor=2.4, fig_height=2)\n", - "\n" + "latexify(width_scale_factor=2.4, fig_height=2)" ] }, { @@ -81,7 +72,7 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAALoAAACWCAYAAABpYJK8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAS0ElEQVR4nO2dTWzbVrbH/0eOjUQGHgTT3ZqGvOvqjeO3mmVloMjywUlRKAU6wKsCFRjMzkHhdRDEu2KAzESLADMh8VDHi9kOoi5nVUVv16UaZx5QYOIPvUGSZtI65y1IKrTES93LT1G8P+DCFkXxUuRfh+eec3hJzAyNZt6p5L0DGk0WaKFrSoEWuqYUaKFrSoEWuqYUaKFrSkGiQieiGhE1iGiHiHZ9y1tEtElEO2HLgvj4448ZgG66yTQhSVv0GwBOmfkQwCeu8HcB9Ji5D2AbAIKWiTg+Pk54FzVlJFGhM3PHFa/3eghgA8DQXTQkok3BMo0mENsG1teBSsX5a9vq27iU9E4BjlsC4IuAt2oyy9zPtwBgbW0tyV3TFAzbBlot4PVr5/XRkfMaAJpN+e0kPhglogaALoCBu+gpfGJ2LX7QMvhed5h5i5m3Pvjgg6R3UVMg9vbei9zj9WtnuQqJWnRX5PcAnAJYAXCVmTtEtEtEKwCeAI6Qx5fNIsPhEMfHx/j555/z3pVCsLi4iNXVVdRqtcS2+fy52nIRiQqdmbsArgYs35dZNmv8+OOPWF9fx+XLl0FEee/OTMPMePPmDZ49exZb6LbtWOznzx2//Px8ch1Vj3Yu4uhJDFZEXLlyRYtcAiLClStXYm/H88mPjgDmYJFXq8CdO2rbLbzQxw+MN1hJUuxZsr+/j8PDQxweHmJjYwPdbhe3bt3CYDCY+tnBYID9/ekXyn6/j42NDRweHqLb7aLT6aDf74d+Rqb/JAjyyQHAb2si/Z6Yeabb1atXOQzTZHYkfrGZZujHpPj++++l17Usi03TZCJi0zTZsqxIfT5+/JiZmc/OzrjRaIz+f/r0aaTtifC27bGzs8NnZ2eB656dnfHu7q7UdlWOWRBEwedzvFWrzAGHWKijQlt023YseBCqg5V4+2Gj1Wrh6OgIzIyjoyO0Wi3YES4rOzuTieJarYbBYIDr169jf38f/X4fw+EQnU4Hh4eHI2vc7/dx69YtAEC328X29jb6/T5u3749td/t7W0cHBwEbrfX66Hf76Pb7QJA4DpJIet7q0ZeCit0z2URkWX4fW9vD6/HrrevX7/GnmoMLATvB7C7u4vNzff5tXq9jrt37wLAheWNRgOnp6fY3NyEYRjKgvRvt9FooF6vo9FoCNdJijt3HB9cBhVjVlihi3w5INpgJQ7PBUdctDwqKysro/87nQ7q9fpI3EE+dL1el972kydP0Gg0Qrfr/S/Td1SaTaDTAUzT8ctNEzCM4HVVjFlhhR6moU5HLWsWF1H2NmpWdzgc4uDgAIPBYOQudLtdDAaDkajq9ToGgwEODw+xsrKC4XCIfr+PXq83+t9b/7vvvkOv1xtt33vPPxi9desW6vV64HaB9+6TqO8kaTaBZ8+AR4+c1ycnFwejQARjFubAz0ILGoxaFvPCQnqDUA/ZgZVlWVytVi9U0lWr1cgD0iITdzDqYVnOgFN0jgWHVqij3IU8rY0LPewACEbikckj6lJ0khK6KJpGFHqO50foogOwsJCsyJmTO2llIqljFhZmDLlqC3VUOB9d5JufnwOffZZ8ZlSTD75x9wSikHIYhRN62PiO3czoZ58BX36Z3T5pkudf/xK/R6RuzAondJk4KzPwhz8Aq6vp1L9o0sW2gZcvxe8zq5fpFk7owGSoScTJSfHqX+LWuowzy3UtImRErJyiCHPgZ6H5B6NhEReZphp6zGMwGrXWJaweJa26liCSOGYy9S6Cczkfg9GwbKgMada/JFUqLKp1qdVqo/oSzzr3+310Oh0Mh8OJepQwZqWuRUTYQBRwrujXrqlts1BCjyvUtOpfsigVvn37Nra2trC5uYnT01P0+31888032NraQq1WE9ajyJBXXUsQtg3885/h6zADf/qT2vFNRehEVEtju3GEmmb9S1L3NYaxsrKCWq2Ger2OGzduYGdnB1999RUePHgQWI8SxqzUtQSxtwfI3LmYe/Wie9/o47FlT4nosfuet0xqAiM/d+7ID0T9GEa69S9J3dfoEVTrcu/evZHL4tWZeKW7XrGXvx7FY9brWsZdPpUYudLxDXPgozYAj8deb4693vWWAXgQtq3xzGgWg1AP2YFVmjd/FA21sonEgwu5D0Yb7lR1XgV55AmMTFO987RvwgiK7WddKlxE4gQXVI9vJkJn5n12Zgi4HvB2bXyB69b0iKj34sWLC+9FEU/aN2EE1VBnXSpcROIYoDdvgL/9TX791IXuita7A8ALHEWewKjZBC5fVtuHLCyrV0P97p3zV4t8OnEM0Lt3TvZbttQjrcHops8dOQBQdweddwFHyHDdGShOYGTbzq9Zfn9Utq7JEpXb5kR0OnLrJT73ouuibPheD+FMUTe+XqQJjH73O9X9AT7/3Pk/ipX96aef9ARGEjA7Exip4J2Pvb1oFYlA8LwvQRBz6LTSubO1tcX+28Ci6m1hAfjlF7XP6Cnp1IgzJd2lS/Ki9TN2XoXqSGU23Vnk/BxoNACJDPkIL/WuSRfbjiZyIHwmCD+FKgEAgOXl6J/99ttiVDCWiS+/dO4fiEK7Ddy/L7duoYRu28Dbt/G2kWRaXhMP2wb++EdnHKUKEfDrXyusXyQfXTVFLGLGv3JpiHs+TdMJ5foQ+uiFsuhJZThXV7ULMwvENVpzO1NXUhnOk5Pi3HE0ryRx7KfVrfsplNCTSDB4JF1Gq5Fn2ryZsqiE7QsldH9NSRJkOeOu5j1x7xTzePVK/spQqMGonyQSlQGDGU0GVCrJBQTGzuF8DEb9iGZYVUGX0eZDktWkslflwgr966/jff6jj3SFYV4kOdaSHZAWVujNZrQsaaXiZNRUSgE0yeKNteJkuVUppNC9+wxfvVL/7Lt3wMGBDi3mTbPpzMb14YfxtnN6Krde4YTun1oiiIrENzo50fMzzgK2Hc1Y+ZH19wsndFFoyjCApSXHYsvA7NRZaMuePbbtZKdv3oyXHV1akg8oFE7oolH2yYl6wRdHmKxSEw+vWvHkJP62Hj6UDygUagIjIPkbnXXSKDviVCuOY5pqUbOsJjCamKwoygRGQLKhKSDbxzSWnb295BJFuT8i3b1ndOi9JqJdAD33Tv9t0TJZkiwD0HOvZEtSV0/DUM+BSAmdiL4gon+PsE9A8GRFkScwAt5PLRG3DEDPvZItSVw9iaIlC2UtehfAfxDRARH9lYj+U70rAAGTFQUtC5vACHgfR49zGdQ39WdPElfPyOc8bL46rwFYB/Bv7v+/AvARgP8KWf+x7/8WxuZZDFomaiqPX1RtS0vJP8lOE45hxD9vIY/ZjD334nUA94joGwA3mPlbAGdBK45PYMQBkxUFLZMlqRJPwAlH6vBidtg2kMRkvFHuJZAq0yWiXwEYMPP/+ZZ95Ao+VcbLdJMs8QQcF0Y2yaSJjm0Dv/mN3NznMgjOW7wyXWb+H7/I3WWpizyIpMOBlcr/wtbp0dSRneBfFlUdFC4zGhRHX1yMujXG+flf0Gq1tNgTxrZtrK+vo1KpYH19HUdHyd7gozywDXPgZ6GND0aZnYGIaTpPLzNN53X0Qc5LBj5ls4yz9qeEZVlcrVYZwKgRHSUSQACccy1AqKPchTytBQk9+ODGicb8wEQk1Y9mOqZpXhC50z5l4DzNiAtziI4K57qI8DKm0VjDmq4FSIzngSnQ/0bIWFFIpeJkQuM+YGFuhA44B2BhIcon3+HatWsXfErts0dHbDTUZhJdWAD+/Gfg+DiBByyEmftZaLKui0e7HeWS+G7Cp6xWq2zpbFIkgnx0p6m5Lu22f5uT47IAhDrKXcjTmqrQLYt5eVlV6M8CTgr0ADUGlmWNfPWFhQX3mP4jkj8eNP4S+OrlEHq0Aek7Bn4fKHQ9QI3HRcv+KQM/KV9tTVPp8ZblELrogExvP2iLnjCWZfksOdxjrH5uiJwmem8MoY7majAavd55cvBERLiji9WnMp4Ysm0btm2j1Wrh/MJjLKJFtdbWxFlQlUlGc7fY01o2Fv0fgRZdE07QoLNarbJhGAHHU92i+330xcXJ9wOqT4U6yl3I05pKwigoO1qtMn/4obrQtdsyneDEkKiF++im6URZRJEVUeZ77DTNt9BFg1DDcA5ekDW42M4vnBQi4na7zaZpMhGxaZo61BgAESkIXRx1kbEpkn76fAs9bFQu5878cOGEXBxE6bi6CJFFNwxDKY4uE9ySjLzMt9DDfu2i9977eT+zYfyWiYgNwwgUuXZnghH56JZlcbvdDrD4wRY9pEjL15dULH2+hR7Volcq5yORm6YpGETpuHoYXmLIO4aey6fiuiwvS2U9ZbKj8y30sF+7aMR+6dIvvLj4eaiwRVZduzDBiFP/aiUAUyoUw5hvoTOH/9rHIzKGwWwYv1UWufbXw5kehflBSuje1TgC+QodwFM4s3c1fMtaADYB7IR9VrXWRRb1iIH21z3G3RXvRz/9mH7KwCspoUf0EnMX+ubY611EnO4iDn6rv7Dwd/fAT0YMpvnqZfbXwwag0yy6ZVkTV17J+LgsuQt9F0ADQMt9/QBA3f3/3vgPwd+SEnpwrP3lBbH7XRLZZEjZfHbRcfGOg8iqi66ACpWJMuQr9FFnwBOeFPqDAIvfAtAD0FtbW4v0jccRRV8WFv4emBRScW3K5LOHHRfRj2Da8ZGsNZchP6G7ovVE/dS3LFPXRSaz5vc9w+LpZfbZZa903g8i4yueUEeXkD4HALbcmbvuAgAzd4hol4hWoDhTV1TW1oKfruBVxnkVd6/dacAuVt5NJ/g+yfnCtm0cHx9LrcvMME0Tz2blQa5hv4JZaGn66ETvb9cSWSpZF2beLbplWby4uKh0lcthwJ6f6xK3JRl1abcnXRhv4CMSNBFNFfvS0tLc++hqlYqTP35RSDJhtNCZp5UKBJ9I0zSnnuTl5eXE9nFWGBemqsiJaCTmsJBkwmihM4cPSMNOxvTUttMMw5gLyx5ckKUm8rbvFv4wI5IwWujM00s9wy6vk/dAzqcb0263IwscAFcqlZGIp2VMU/DhtdCZ4ycnLMuSOtmGTN3pDBKW8InSpmVMtUVPSejM8ZMTMuUBAApp1aP44tOaZ9m1j56x0OMia9WLGG5M0pqPuyc66lIwoTPLW/Ui3XNqWdbIv07aomeIFnqSyEZhMrhUJ0KUZBDgjEW8H7NhGLy0tJT3d9ZCTxqZstQwgcyS6OPGyT0yck/CEOpI6mFdeTL+sK5ZZHV1FScnJ0qfWVpawsOHD9GcgSf6UsSHrs6gduI9rEsTzunpqfJn3r59i70cn/3oTSU3TeQLggnnzSSeUZ8hWugJEPVpGUdHR5k/cMC2bayuruLmzZs4Cirn9LG0tIRWq4Xq2NPRqtVq8ealDPNrZqHNqo/uJ+pgzmv+dHla++f54SohRH+tSkEiSHowmjaWZUmHHYNaWgPUKBEiZB8WTAot9KyZhRBknNj4DFvtMLTQ8yDqrXlBhVEqfca5snhXl4KihZ43cf34MNcmCXF7reDVl7MldEhOXsRzJHRmR5DLy8uJCDKNNmuJrAgIdZR5eJGIdgH0mLkPYDvr/vOk2Wzi5cuXMAwj710ZYRgGLMsCM+P4+HgmElhpkEccfQPA0P1/6M4OUCq+/vrridh01hAR2u32XIvbT94Jo1rQQiJqEVGPiHovXrzIeJfSp9lsotPpCLOOaeFlQU3TxKNHj3D//v1M+8+VML8mjQaFyYt4znz0caLGuFXbHPjessyOj87MHQANImogo8mLZhXPspumCSKCYRiJ+e9l8b2lCfsVzEKbZ4sug0rosESWW4RQR7pMVzNPCEsxZ17oRPQCQFCZ3SoAuYkA06HM/c/qdz9m5o+DPjDzQhdBRD1m3tL9l6vvqP3nHV7UaDJBC11TCoos9I7uv5R9R+q/sD66RqNCkS26RiPNXAidiDaJqJVDvzUiahDRjluVmUWfLff77mTR31jfmX9fwX4on++5EDqAT3Lq9waAU2Y+BPAJEdXS7GwGSpwz/b4hKJ/vwgs9z5oZZu64ovNeD1PuMtcS5xy+7wRRz3cWT6WLjeAydQBgBcAATrlvPev+vRPtvv9FWv0LqGXc34icvi+IqI6I57sQQmen4nECIvKyY5sANoioloaVEfXv7kMDQBeA+nRd6jyFT+B+65oVGX/fcTxxK5/vQghdBDN3AYCIcrklzz3p9+Cc9BUAV9Psj3N4PqufrL/vOHHOt46ja0pB4QejGo0MWuiaUqCFrikFWuiaUqCFrikFWuiaUqCFrikFhU4YlQ23iGoLTmawD2CbmW/nulMFQSeMCoSX8iaix8x8Pe/9KRLadSkQ43UdOZbJFg5t0QuEW48+gFPc1IdTmz7MdacKgha6phRo10VTCrTQNaVAC11TCrTQNaVAC11TCrTQNaVAC11TCrTQNaXg/wGraiWrEoqn6QAAAABJRU5ErkJggg==", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAALoAAACWCAYAAABpYJK8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAS50lEQVR4nO2dTWzbSJbH/0+KjUS+CKFzNT3qW58Wjn2a244DNHJdJ4OG00DPYKGGGhjMLUGvz0E2vgULpMdaIEBPSCwS+zDHBqw57GFOcevYt1XizAINTGxHO8jXpNv99kBRoWQWVUWW+CHWDyjYovgl8s/HV+89FomZYTDMOpWsd8BgSAMjdEMpMEI3lAIjdEMpMEI3lAIjdEMpOKdzZURUB9DwGzNvD6Y3AfQA1Jl5TzQtjE8++YS//fZbnbtpmF1I9IVui34dQH8g3F8TUZ2IbgI4YOYOgCsAEDZNxNHRkeZdNJQRrUJn5jYz9wKf+wDWAPT9aUS0IphmMITiusDyMlCpeH9dV30dWl0Xn4HFvib4uj9p2sCtaQLA0tKSzl0zFAzXBZpN4M0b7/PhofcZADY35dejvTNKROsAgj73EwB1/8PA4odNQ+Bzm5lXmXn10qVLunfRUCC2tj6I3OfNG2+6Cro7o+sA7gI4AXARwGUAbQDXiQgA9gezhk3LHf1+H0dHR/jxxx+z3pVCMDc3h8XFRdTrdW3rfP5cbboIrUIfdC4vj03rwxM2AHRF0/LIDz/8gOXlZZw/fx6Di9IggJnx7t07PHv2LLHQXdez2M+fe3756enZeVQ92pmIo+vorIi4cOFCqiLf3t5Gu91Gu93G5cuX0el08MUXX6DX601eWJJut4uPPvoIe3t72Nvbw/b2duL1ExEuXLiQeN98n/zwEGAOF3mtBty+rbhiZs51u3z5MkfhOMy1GrN3WLxWq3nTk/L9998nX4ki+/v7zMz88uVLXl9fH/7/3XffTVz25cuXvLOzI7Udf90+GxsbkfPv7u5KrTfpMbPt0XPpN6IP/1uW8PwKdVR4i66rs5IU13WxvLyMSqWC5eVluDFvK6urq2em1et1NBqNicvW63U0/ZCEImtra9jbC8/b9Xo97O+n05US+d7BxyaOjz2rr3KICy101/VucWGodlaS7YeLZrOJw8NDMDMODw/RbDZjiV3k3x4cHODatWsjbobvdvT7fQAYujn+/1euXEG328WtW7ektitab6/XQ6/XQ6fTGc4/Po8uZH1vVWNWWKH7vpyINMPvW1tbeDN2W3nz5g22NN5W1tfXAQA3b95Eo9FAr9dDvV7HxsYG7ty5M5zn5ORkZP6VlRVYloVuN7rP3+/3I9fbaDSG6wybRxe3b3s+uAwqxqywQg9zWXxidVYS8FxwxEXT43Lx4sXh/71eD7u7u0NrG4ZK9GN/fx8bGxuR6/U/y2w7LpubQLsN2DZA5P21rPB5VYxZYYUepaF2Wy1rlhRR9jZuVrff7+Px48cj7kKn00Gv1xtxJ65c8cqELl68iG63i263O5zH/7/X6+HJkyc4ODgYrt//bm9vD51OB9vb27h7965wvcHlJs2jg81N4Nkz4OFD7/PxsSf6IMrGLKqnmocWFnVxHOZqNbx3btuK3fwIZCMIjuNwrVZjAMNWq9XY0RH6KRi6IlVh0bTgOVaNumQu5EltXOhRB0BXWNFH5aQ5jsO2bTMRsW3bpRQ5sz6hR4UZIw7t7AhddACqVb0iZ84mjl50dB2zYNxc4a4t1FHhfHSRb356Cnz2mf7MqCEbAv3uM4hCylEUTuhR/Ttm7yB89hnw5Zfp7ZNBP//4h/g7InVjVjihy8RZmYGvvwYWF6dT/zJNdNe6TKOuZdq4LvDqlfh75hiZ7yi/Jg8trDO6sCD230QtTke1aLUuonqUadW1hKHjmIn6YeOd0hBmw0f3s6GvX6svm0X9Sxzi1rqo1KPkpa5FhEyebabLdKOyoTJMs/5FV6mwKJt5cnKCdrs9TPJ0Oh10u1202230+/3QepSobeShrkVEVEcU8Hz0q1fV1lkooScV6rTqX8ZrqP3nGnX2C27duoVGo4FGo4Hd3V10u108evQIjUYD9Xr9TD1KFHmpawnDdYG//z16Hmbgm29muHoxiVCnWf+SVqnw6uoqVlZWsLOzg42NDXz11VfY398f6VzKdDTzUtcSxtYWIPPkYi6rF4moEfwbl9u3z9Y8yGBZ061/0fVco09Yrcvdu3dx586dYU3L3t4eDg4OsLa2NlLsNV53kve6lnGXTyVGrnR8o3qqcRqAdQB3A5/r8B6A3oE3epc/vTmYdyNqfeNRF9VoS5L6F9kIgihKoLPupiiolU2Iyzlintf0oi7sPSA9brlvMfMXPBjWQmWkrnFsW32fpv0QRlhsP+1S4SKSJLigenzT8tFXiWh9MDARMGGkLiJqEtEBER28ePFiZEVxxDPthzDCaqjTLhUuIkkM0Lt3wF/+orBAlLmP2wDsCqbv+99j4MZgzKUZb2FluufPq93i4hZ7maIudVSOmUxiaFJrtUZWmZ7rMs7AOtfHJkeO1BWF63pXs/z25ec1pIvKY3Mi2u3J8wDTG5KuEXBHHg8+r8Oz5IA3eNHqYB6lNNzvf6+2P8zA55/Hj2m/ffvWvxsZImBmvH37VmmZoMsXl7BxX8KgvJ/E1dVVDj4GFtdCV6vATz+pLWOGpFMjyZB0587JizbI2HkVqmMqo+nmkdNTYH0dkMiQD6nX61rHETSE47rxRA5EjwQRpFCZUQBYWIi/7J//XJxy3bLw5Zfe8wNxaLWA+/fl5i2U0F0XeP8+2TqKUMFYFlwX+MMfvH6UKkTAL3+pMH+RfHTVFLGInP/k0pD0fNq2NyxGgNTeYTRVdGU4FxeNC5MHkhqtmR2pS1eGM84glQa96Dj2k+rWgxRK6DoSDD5FeeJoFpk0bqYsKonDQgldR4IhSJoj7ho+kPRJMZ/Xr+XvDIXqjAbRkdoP6cwYUqBS0RcQGDuHs9EZDSIaYVUFU0abDTqrSWXvyoUV+r17yZb/1a9MGW1W6OxryXZICyv0zc14WdJKxcuoqZQCGPTi97WSZLlVKaTQ/ecM44zv8vPPwOPHJrSYNZub3mhcH3+cbD2DF3xMpHBCDw4tEUZF4hcdH5vxGfOA68YzVkFk/f3CCV0UmrIsYH7es9gyMHt1Fsayp4/retnpGzeSZUfn5+UDCoUTuqiXfXysXvDFcQarNCTCr1Y8Pk6+rgcP5AMKhRO67gedTdIoPZJUK45j22pRs8IJXWdoCkj3NY1lZ2tLX6JINQcylWdGieju2LTmYPpG1DQZdJYBmLFX0kXX3dOy1HMgUkInon8hon+SmZfHBjAKG6woyQBGwIfX8yUtAzBjr6SLjrsnUbxkoaxF7wJYI6LHRPSIiP5ZYRthgxXFHsAI+BBHT3IbNMNgpI+Ou2fccy4r9GNm/k9mvg7g3wEQEf1rvE1+ELhoGjO3mXmVmVcvXbo0MuOkOLoszMBvf2vCi2myuamnRinOswSyQv83IvqaiB4BuA5vAKKnksuGDVYUewAjXSWegBeONOHF9HBdQMe7BOI8SyA73MUjAD1m/j8AIKJfiGYMDmDEzF14gxVdJ89X8AcrCpsmhe5woAkvpoPrAr/5TfxhLcZRPW+Fq0fX9YC0T7X6v/jmm//GpumVThXd503wLMHs1KOHxdHn5uKujXF6+ic0m024xlnXiuu6WF5eRqVSwfLyMg4P9RpU5Y5t1AikeWhho+k6jjcSK5H313GYLSvuiKyvGPiU7TKO2j8lHMfhWq3GAIaN6DDxyLl+syzhpoU6ylzIk1qY0MMPbpK3JzxlEry40qCObdsjIvfapwycJhb5hPfFCnVUONdFhJ8xjccSlkwtgDaeh/YU/wsRLrSQSsULSSZ9wcLMCB3wDkC1GmfJn3H16tURn9L47PERGw21kEu1Cvzxj8DRkVd+/exZgkx2lLnPQ5N1XXxarTi3xJ/P+JS1Wo2duK/KKDlhPrrX1FyX4NsswvplIQh1lLmQJzVVoTsO88KCqtCfhZwUmA5qAhzHGfrq1Wp1cEz/FssfD+t/CXz1cgg9Xof0Zwb+I1TopoOajFHL/ikDb5Xvtrat9HrLcgg9/sufnhqLrhnHcQKWHINjrH5uiLwm+m4MoY5mqjMaP51/tvNERLhtitUnMp4Ycl0Xruui2WzidCTfHy+qtbQkLu9VGWQ0c4s9qaVj0f8WatEN0YR1Omu1GluWFXI81S160Eefmzv7/fz8GT9dqKPMhTypqSSMwrKjtRrzxx+rC924LZMJTwyJWrSPbttelEUUWRFlvsdO02wLXdQJtSzv4IVZg9F2OnJSiIhbrRbbts1ExLZtm1BjCESkIHRx1EXGpkj66bMt9KheuZw783TkhIx2okxcXYTIoluWpRRHlwluSUZeZlvoUVe76LsPft6PbFm/YyJiy7JCRW7cmXBEPrrjONxqtUIsfrhFjyjSCmxLKpY+20KPa9ErldOhyG3bFnSiTFw9Cj8x5B9D3+VTcV0WFqSynjLZ0dkWetTVLuqxnzv3E8/NfR4pbJFVNy5MOOLUv1oJwIQKxSiyFTqARvCvSlOJuoiu9vGIjGUxW9bvlEVu/PVoJkdhnkoJ3b8bxyA7ocN7CHofwE5Q6ACaANYBbEQtr1rrIot6xMD46z7j7op/0U8+pp8y8FpK6DG9xMyFvjI27aY/DcBO1PI6hR60+tXqXwcH/mzEYJKvXmZ/PaoDOsmiO45z5s4rGR+XJXOh+9a7OZi2G3BndsYvhGDTJfTwWPurEbEHXRLZZEjZfHbRcfGPg8iqi+6ACpWJMmQn9JGNAfscLvTG2HxNAAcADpaWlmL94nFE0Zdq9a+hSSEV16ZMPnvUcRFdBJOOj2StuQyZWvQmgDqPCj1110Umsxb0PaPi6WX22WXvdP4FkfIdT6gj2QGMkvAY3oBGF+FZciDBAEZxWVoKH1fEr4zzK+7eDIYBO1UcaSf8OcnZwnVdHB0dSc3LzLBtG8/y8iLXqKsgD22aPjrRh8e1RJZK1oWZdYvuOA7Pzc0p3eUy6LBn57okbTqjLq3WWRfG7/iIBE1EE8U+Pz8/8z66WqXi2YtfFJLUjBE686RSgfATadv2xJO8sLCgbR/zwrgwVUVOREMxR4UkNWOEzhzdIY06GZNT216zLGsmLHt4QZaayFuBR/ijjIhmjNCZJ5d6Rt1ezz4DOZtuTKvVii1wAFypVIYinpQxnYIPb4TOnDw54TiO1Mm2ZOpOc0hUwidOm5QxNRZ9SkJnTp6ckCkPAFBIqx7HF5/UfMtufPSUhZ4UWatexHCjTms+7p6YqEvBhM4sb9WL9Myp4zhD/1q3RU8RI3SdyEZhUrhVayFOMgjw+iL+xWxZFs/Pz2f9m43QdSNTlholkDyJPmmc3Ccl9yQKoY4K9w6jPLK4uIjj42OlZebn5/HgwYNcvDuJYr50NYfamZ13GOWRk5MT5WXev3+PrQzf/egPJTdJ5FXBgPO2jnfUp4gRugbivi3j8PAw9RcOuK6LxcVF3LhxA4cTXhM3Pz+PZrOJ2tjb0Wq1WvHGpYzya/LQ8uqjB4nbmfNbMF0+rf3z/XCVEGKwVqUgESTTGZ02juNIhx3D2rQ6qHEiREg/LKgLI/S0yUMIMklsPMdWOwoj9CyI+2heWGGUyjaT3Fn8u0tBMULPmqR+fJRro0Pcfit49WW+hA7JwYt4hoTO7AlyYWFBiyCn0fKWyIqBUEephxeJ6CaAA2buALiS9vazZHNzE69evYJlWVnvyhDLsuA4DpgZR0dHuUhgTYMs4uhrAPr+ByJayWAfMuXevXtnYtNpQ0RotVozLe4geUgY9ccnEFGTiA6I6ODFixcZ7NJ02dzcRLvdFmYdp4WfBbVtGw8fPsT9+/dT3X6mRPk102hQGLyIZ8xHHydujFu1zYDvLUt+fHR4gxetDlyWVAYvyiu+ZbdtG0QEy7K0+e9l8b2liboK8tBm2aLLoBI6LJHlFiHUkSnTNcwSwlLM3AudiF4ACCuzWwQgNxDgdCjz9vP624+Y+ZOwBXIvdBFEdMDMq2b75dp23O3nIbxoMEwdI3RDKSiy0Ntm+6XcdqztF9ZHNxhUKLJFNxikSePVLlNnkGVtMPNeytutA2j4jZm3U9hmE0AP3nuhZv73CvZD+XzPikVfB3Axg+1eB9AfHPBfD4QwNXJQ4pzq741A+XwXXuhEtA6gk8W2mbnNzL3A5/6UN5lpiXMGv/cMcc93IVwXItoImdyBd/vqTPuEi7bvn+iBpb02zX0Q0M9gm5n9XiJaiXu+CyF0kS9GRP4VvgLgIyKqT8PKRPmCg+2n5Ss/gfcmbgBA0LqmRcq/V7R95fNdCKGLYOYuABBRJo/kDQ76XQAn8HzGy1PeZOrvZw2Swe8dIcn5NnF0QykofGfUYJDBCN1QCozQDaXACN1QCozQDaXACN1QCozQDaWg0AmjsjEoolqFlxnsAVhj5luZ7lRBMAmjAuGnvIlol5mzqK0pLMZ1KRbXB1b9BBim5A0SGKEXjwaA/yGijUFdukEC47oYSoGx6IZSYIRuKAVG6IZSYIRuKAVG6IZSYIRuKAVG6IZSYIRuKAX/D60B3lhChgjCAAAAAElFTkSuQmCC", "text/plain": [ "
" ] @@ -1106,6 +1097,127 @@ "savefig('MCMC.pdf')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VI" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "from utilities.vi_helper import vi_model,vi_predict\n", + "\n", + "params = [[64,32,1],[nn.relu]*2]\n", + "mlp_model_vi, vi_model, results = vi_model(params,x_train, y_train.flatten())\n", + "\n", + "mean_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).mean(axis = 0)\n", + "sigma_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).std(axis = 0)\n", + "mean_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).mean(axis = 0)\n", + "sigma_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).std(axis = 0)\n", + "mean_vi_test = vi_predict(vi_model, results,mlp_model_vi,x_test).mean(axis = 0)\n", + "sigma_vi_test = vi_predict(vi_model, results,mlp_model_vi,x_test).std(axis = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/hetero/MLP_VI.pdf to figures/hetero/MLP_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/hetero/MLP_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAus0lEQVR4nO2deXhT55nof0erZXnDK7vBrCEkgDFkB5IACWmk6XQgZDJtk/ZeZ+HS2/bpTTPt0+matE07na1Mk4Zpm96baUtopqnUJoF4SiBkARyWJBhsVgMG75ZlS7LW7/7xHW22vLHYMpzf88iSj47O+c753vd7l285ihACDY1rHd1oF0BDIx3QFEFDA00RNDQATRE0NABNEQZFUZSvKoryqPr6QFGUlYqi/FxRlLLRLlsqRqO8iqLkKYqSd6WOPxIYRrsAY4D9QogqtaLXqZ+rgUEFS/3NA0KIF4ZyIkVRyoEyIcTvU3z3KHASyEv8PsVvLqW8a1OdO/od4ALKhRA/6vX1SuAxRVEA8oFK9T0PoL9jphOaRRic6t4bhBAupFAOiBDCNVQlUFmJFKAkFEX5KlAthKgCVg3ym4sqr2oxeh87+t1KIF89fyqh3i+EWCWEWAU8pW6LKuzXBjpvuqApwiCoQpSKCkVRtqquSBnIVlP9P0/9f6WiKD9P+PymoijliqI82/tgqrBV9XOuJcjWOLpveX+/GaC8+aq7tFYty0q1LI+q5S0DytRj9mZVwvn6fC+EOBm9fiFElRBiP1ClWrEf9FOetEJThItEbR0RQvxICHFSVQYXssX8WsI++b323w+0RYUZpGBHvx8irov4zbNIq3ASWAeUA+uBk6rlqlI/pzpmXsJ3qwaIB5ZEP6gKWaWeI+3RFOHSaE/4XIYUsDL698dd/R0oobVdnELQ9qH625DUAg/0m1RUCyH2CyEeQyrsD5CCHStvP0H1B4MdONGSRK2kWs6ysRBIa4owBKJBLwmug/qeWMllwJvq53bV7SiP7pPwuQzZclZEj68KZxVQ0Ou8j6ofX0C6YuXRc/T3m/7Ki/Tdv5ZQrrVqGfaRrNDl9EKNc8rVY70Zdb8SygdSUU+on38P5KkB9pYB3LW0QdHGGmloaBZBQwPQFEFDA9AUQUMD0BRBQwPQFEFDA9AUQUMDGAOKcO+99wrgkl4rVqwQK1asuOTjaK+Re13mOhuUtB992traesnHuPvuuy9DSTRGkpGus7TvUKuoqBDV1X0GVGpoDAdlsB3S3jXS0BgJrglFWLNmDWvWrBntYmgMg5Gus7SPES4HPp9vtIugMUxGus6uCYugoTEYmiJoaKApgoYGcI3ECPfff/9oF0FjmIx0nWn9CBrXAlo/gobGULgmFGHFihWsWLFitIuhMQxGus6uCUXQ0BgMTRE0NNAUQUMDuMzp08SlA5EL0/5I3d5nAdv+FrXV0BgNLnc/wgNAlRDi9+qS5C8Aj6rb9qvrgP5eXdQ2adtlLkdyoR544EoeXuMKMNJ1dlkVoffKz0IIl6IoS0gQdHWVtT7b1DVBrwgbNmy4UofWuEKMdJ1dkRhBbfHX9fO1a7Bt6grN1YqiVLe0tFxyebxeL16v95KPozFyXGqdORxQWSnfh8JlVwR1fcxEVyfVArYpF7VN+P8FIUSFEKKiqKjokst03333cd99913ycTRGjkutM6cTrFb5PhQud7C8Ern8eDtyOfTFyAVsH1CfphJdJDfVNg2Ny4bNJpXAZhva/pc7RqhCCn/iNhdS8AH297dNQ+NicTjiQm+3y212e/zzUBiT/QjD9f80rm6G6walYkwqwuW4cI2rB5sNPJ6+blC0wVQUBnWQxsx8hETzN1z/75FHHrmiZdO4/Aynzvpzg6INJmAHBmw2x8x8hMpKeVEeD2zePNql0hgLRBvP//gP7EIMrAhjxjXqz/wNhdbW1suyYp7GyHE56sxul43mYEoAY0ARXC7p54G8qOFkAqKsXbuWtWvXXtZyaVxZLrbOvvENmDNHvg+HtFeEzk4tMNYYOlu3QmamfB8Oaa8IubkX7xJpXHssWAC1tRAMDi+9nvaKkJcbYfMLIuYSORywerV8af0IGlGiqVKXC8rKwGIZnheR9opA0M83vuJmzpwI3/gGbNoE+/bB0aN9L1TraLt2SUiVMnEiTJo0PC8i7fsR6s8ZefE/LRQVBtm61YjFoiMQSO0uJXa0JQbVTzzxxMgWWuOSGW6dRfuWNm7slVAJh+Clb5Xw8DNNA/0+7fsRSgoXigkl7+Pzwbr7Pew9ksf5C3omToTt25P3TTXmROPqp796d/whiHNLN6aW2mf+/b9vHjCPlPYWwdejUFwYYemiHpoaFZbObaNpSj62T/Yten89jGfPngVgypQpV7q4GpeJ4dRZSk/A78W5pQerJcy2I2UVgx0j7RXBbBKYzYKf/jKX2WVB3Hth3X1tOP+Qz969RpqaBrcAn/nMZwB46623RqbQGpfMUOvM4YD6evl540Z128senK/0UFIcoanNxPI5R/dC8YDHSftgOTc7TN1JI7PLgvJ9Zpitr+fReMbDT/8tTGOj1sdwrZAqGeJ0wrx5UFoKIKj8rI9N/xbGmqWjqd3M5n9s5vlHHHsHO3baK0JeToQf/0M7C68P8IXPdzK+KMyCeX527cuiOD9A3dEQNlt6xzkal4feo46j1qCmBmz3R3C+7KGxwc/h4xnUHDdTUhSi8v8U8/iL9qWDHTvtFaGhycATTxVQfchE3QkDu/ZkcOyUidllIZrbjaxb3YL9tgaIREa7qBpXmMTxZg4HbNgg0+iIMPalp7Eta6budCblNwYpnRyiqcVAY4uO/3xvwecGO3baxwjNbQYyLTo8Ph0nzxgpmxKiuVWPokButmDvxzk4XmnEudODbV0m9k/qR7vIGleIxGRIZSX09IDbHSE/OwA9Hux2K5jbcW7PxLZaTvx/8rvjyMv0NUDGgMdO+/SpNbNcBEP7MBsh0xIhFFaYOimEEBAKKRgMgtZ2PTolwtyZIbb/xQgGU9IxnKottWnjNMYMg9WZwwFPPB7G3xMhwyT42bPt8nfbMykpktbAttoLIsxTXzduO3K++N6Bzpf2ilCUv1AsuuFdjp4w4u7SkZUZYe39XmrqDJRODlN/zsDR4waaWqQlmDcrwHe+p8f+KdMgR9YYi8T6DO52g7uFJ5+dTE62wN2to3RSiHmzg7y+w8KaO314vAqbf9wEb75o4+Fn/jTQcdM+RkAIli7wEQ5BTnaEwvwwHq/Cxs93sfknrWz8vJu5M4MYDGAywamzRpwvd4OvO3aI2tpaamtrR/EiNIZLqjpzOODJJyMcrPbz5NdNYMrgx9/swN2tIycrwuFaIzV1Rtbd301NnYH6cwYc2zOHdL60VwQdYZrqu3numQbuWe7ju191sfknrdjvkT6g/R4v27c08eVHO8nJijB9aoiS4giVj/hxbHED8Nhjj/HYY4+N5mVoDJNUdeb8Y5jZpT4OHzHgC+jY9Ks87Pd4WXd/N3UnjUyaEKZ0coin/95F6eQw82YHcW7PGtL50l4RIhEF222N0NIAAR8kuHKObZlUfqWQb/wwj6YWA88928aBqvM0tZmwZuvY9FOo/KyXtrZRvACNy0PQT4nVRd1xPeOLw2RbBdEnQjW1GJhdFqT2hJFON1R+pZCSohA1tQbqT0f41+03PzrY4dM+a1Q63oP9bherN5TT0Gxi3yFwbrNgu7cH5/ZMDh42cehwFgvnBwBpITrd8LIzh4K8MPNmdtLaHKagIO11XqMfHK/04Nzqpb7BxOyZYfZ/qGfm9BBLF/VQ+ZVCOt1Qd9LInBlBDh3JYM2dPpqadBDwcb7JwK/OlZd/cZBzpL0iJKLodJxtNBLa42ffgWyKS6CmzkhebpiPjxrJy4ng2JbJoSMZMs3apsMTMFOYF4QAciSifkxdskZXO86XBdYsheZ2I3UnjVgzBSBoajFgzRTs2mNh2c091J00su4TXTRdiGBbfJxNJ8oQigkhlEEzQmNGKjY+eIZNv5vKhVYTbo+RiFuwYnE7za2F+Px6Jo0Pqj5hJuvu72brn7L43INdPP33LlZ8CoiE4fxxGD8djObRvhyNoRAKQHM9tjXFON/MprgwwrlGQWeXjmOnjOpOCuvu76apxUDl33ZgX1QLXe1gzgSTCeeOQhbkHjosV3Tpn7RPn1bMKRPVv3kOgMrvzaOxzUTdGSvr7r4gc8XLWrCvteDYVRTrSIkG0lGqdr0DwMqbyyEiYHwZZAwtm6AxCoRDVP3X78DXzcpVK0GRbq1jWyYb/r6AHr+Cr0fhrtt7GF8UZvNPWnH8MYLzzwZst53Hfrebbzw3g61VE1gwsxN3s+foGx/Nvm6gUw7JIiiK8jfACSHEwUu+yEvAtrwZ585iKj91DvsydamPYAAa/NBtAJHJ3gOmPgqxctlt8YOEAnD+GJRMA2vuyF+ExsAE/dBUz8qKBbKxUuKPSI7W55PfyycnK0LdSSML5vZQuSGD+rMK82b4cL43BfvKGrZWTSDTHGHbniLumuMKDXbaIVkERVGmAyuBVYAAfi6E+MtFXuqwmDlpjrhzSRW25c1x4e9NJMLqDYtoaM+k05PBWpsv1uFmW+1l6iT5wPKF8+fJ/cMh6PFCwUTILUq62RqjSI8HGk+BonDw2Ckgoc4ScGzLlI3dilacTh1Wc5C3DhVxptGCosDdS1upq7fS6jJz6w3tl88iAG1CiM3AZkVRFgH5iqL8TyHEfwzzUodNZ7cJqyWMc2dxkiI4dhWy6XdTQYGl17s4fDoXvS5CYXY3HlcYkIHUpl9m89HRZykcF+ajt16SP9YbwJIFbQ3SQuRPBJ2WVRpVutqh5SyYzGAw8aVvPg3AW//1GyBB+Fd7sa/swl5eB64m6CzFuXsCxflB6hszQQj+9HYJ0yd5mVfWzcvPHIDa996GZwY8/VAV4euKouQin3lwEvjBRV/wMMnNCuDx6bEtb07a7txZzNF6K11eAzUnsiif66bujJX7l7XS1KJn6Uw3Td0FgIJeJ2jt6DUYT6eDzBxwt0n3qniqllEaDSIR6GiEzmbIsIIuuZ6iClB/zoDZLHjyO3ns3eGmqW0Ktjszsd/Zhv3ODhy7CvnmczNodZnJyw6hoEjfZYgMtea3ACeFEJ0Qc5VSoj4sZJUQ4qmEbRf9VM28rCCb/6EGkFbAubMY2/JmbMubef2dQrIzQ+TnBBhfEGDBbDdbqyYSFoLzzRnMmerhE6s6+eioQuG4cKrCQmY2+DxaRmk0CIeg+Qz43GDJ7uOitrXrYsGx368QiQiuL3WxdccEcqwhXt4xiS+sr+fpDSexL2uNeQyJcjJUhqQIQogDvf4/BZzqZ98qRVFifeOpnqA5rKdqhkPSd8yw4txZHHOTpHLUxC7YvqyVyu/NY/ZUD9veL8RsjHDqgoWmM+3MmdrdZ0RqEhar7LVuOCaVIcM6lNuicSkEeqDpFIRC0jKnoLVDj9EgcHXq5LB7a4hOn5l1Kxv56ZZS8nODbK2awNL5bpw7iykp6KGpLQPb8uZY40lwaMUZCcd4CQkPC0x4qmbvbST8H3+YoMsNZ2rgwnFstzYkuUn2Za2xTJJjVyG25c2MLwjwyRWN5FhDTJ/Yg+3ODhkH+L0yQO4PkwUMBmkZujsu5/Vr9MbjhoY66bpY+m90CseFmVvaTa41SG5WkAyz4MdfquPpDSf5wvp6TAbBupUXYg3k1qoJsYYyRjgIMOiD+K5IP4KiKFuFEOuin4GnhBAn1db/WfWVtK33AwWjzCwpE3fO/wO2RUexLz4OecU4Di/AuXtCTAmslnBMQRJNYrSV+PDYx9w2v5WnHhonA+O8EtD1kymKZpTyJ0BeUSyHrXEZEAI6W6DtvEyNDhCTvfvuu7y9w8/x03dSUhSmqT0j1uIP9m5b3gyRCM6qHGzX7ce+4ONWHn5mQGUYieiwzxM0FUUZ8KmaiXR6M7CaAzgPzMW+9DS4mnC+ZsSa62bTb6dIv1KBjevPsGnLVBpazOyrycbnNzC71MOuA+NYc6uR4+f1YP4I2hvA2wnF02SGojd6A2RmQft5mdMunNQngNO4CCJhaG2Q2SFLVv9ZukgEXM3cWpzBt/es4uiZLDw+A1NKfLS6Clk0163WaStbqyYwe6qs4x9/qVbGCJEwuFqo/MGNWE0unAfnYl/w8aD9CFfq8bJlCe7OC0CF+v+bA2xLSW5mDx6/CVvFMTnY0GDEVnEcj8sPPd3Mm9xK6XifvAkCFBRaO8zMnuqhrt7KupUXONlwgJlT35QCbcmRAn6mBjrbkkazxi9CzSh5XHDhpHStNC6eoB/On5D3MzO7fyXw+6TL1H6eZ7d08v7hIzQ0Z9DuNvDR8WxaO4wcOJrDupUX8Pj0rFt5gQNHc/D5dWzaMhU6m3Fs9VD5/RsoyfPgCWZgW3xsSEVM/yEWZRNF9bOpxw469pbhrJ6B7aZT2NcEcXwwi28+P4OzTRamju/hO48fx76slRWPPklbp5ub51+HbflN2JfdAuEw+D1gzYOiKWDsJ5j2+wB1WIZZG5YxbHzdMihWdGC2pN4nEgFXi7TCej2YLEy459t0uI0EQzuICOnGmo0Rli3qYPvP4g9iXb1hEeebTUzMbmf7135D5eZPYM0I4fGb2Pz4G9F6buThZyYMVMwx7QDbl55k84Y32VtbzJy/XclLr1g5dT6TnKwQ3h59LJ3W1mnkxDkvje15OHfukT/W62XKrqcbzh6Bblfqk5gtoDPIjJIWRA+daDxw4bjM2PWjBI4/v0blhv+Fw/FHuY9J7leYF8BgiGAwgEEnMBsjFOYG2fjgGfm7nYVUfmsmSycf55ZpJ9l4T7X0FpaciHsQw+Cq6EHaumcemaYg2w6UMntCM3WNhXxh7WlA5pTr6q1ERBEHjuj52ddviv9QUWQrHw7JCssugqJJfYM4o0ma86Z6mfYbV6IF0QMxlHhAtQLOV1/FarHgfP8Q9lV3xr52ezrp8TcwbdIWfL6/5voZ3WxcfwYEVH5rJvVn9cyb0EST18LmDfFFcO0VJ7BXnBh2ka+K2lx3y1G8ASP3LDjFOGuAm8oaWFp0EC6cwLmjAEUBIQqZN+OfpVvUG71Bxg7dHTJ28Hal3iczCzqaZCdQeND469ok0DN4PBCLBRqwrbgVTyCEbflNSbu0dLgxGBRa2rfx3Ndr2P7v+7EvOYXzNRPWkAyKPcEMbEuGL/SpuCoswtMP7ubpB3cDUPn8vWqW6Trsi/9Eie4sQggyzBHZovTCses9nDv3xGOHUEBWUl4J5I9Ptg6KDqw54HXL/oaSaWAaeL2cawpPJzTXx8dy9UbNCMlYwAiWbOwrbsO+4rY+uxaNy6Glw80X1i/EvuQ0jv8y4Xy3jJI8F03uHDbed2DAlt9RPQNn9Sxs5bXY5x8atOhj3iI4qmdQ+fy9OKpnAGCrOIbHb6Ikr5vpX3yC7//xVsBOpj6TvXu/SeV3/wnHrvdiv3fu3IPVYo7HDgaTjB06W+BsbdJqGDEsWSAiUmE87hG4yjQnEoH2C9B0UjYMqRqHHi+cq5VKYM7ss49j13tUfu9fYnXz5b/7Lvfe9CuWTpgJZ4/gfHcyVkuYJncOmx9/A3vFiT51n4izepZsED+YNaRLSH9F6PRAZwphVIldcLW8YHvFCTY//gZNrizOt8uu+0ColgzzJLbu3IM10IXzzR0QCePY9R71F5qpOXUm2TQrSrxFa6iT/m6411ilaIU3npTukrhGl5wMBWRWyNUkG5De8VU4LDvQzh2Vny3ZKftlkhqkgI/DHy5gqqUQ5zuTZRC8tG8Q3LvuE4k2iENNn6a/a9Tuhse/DwvnwPJyWHwdGOPFtlUckyawV5bAVnGMqo+mcL49G6N+PF7/bm6etRyP/wwlOUYqn/o29W4v82bNwOMLpI4djCY57MLVLLNKJaXJJj8WN1yQY5UKJ19bI1h93dB0Wn5ONV7I1w1N9Tjeegfn+wewLb8l6T4nuqW25Tfh3PEOtgUz4fRhpuTu4/3jU3l8dQGQOgjur+4B7IuPYx+3E3buh/krBr2U9O9HmFwsqhfPg7cPQIcbrBa4bYFUiplThjSpZsKjW9HrBGsW/pLNj79B5fOvYDUZqDnXTGlxAba7l2FffffAxwoGIOSXscO48TL9mojPI7eVTO8/X361ICLSdWy/oI7RMiZ/Hw7J71zNYMyg8oc/w2ox4/H52fwPX4rtVvm9f5HbPV42P/YpOSQewGBgxbd/DsBb3x7melTN7bD7oJSXhhbCegOfu2mP9/++Uz7gSMr0b75MBvj0GnjoHvjwOOz8AHZUw/b3uZA5jY6KW5n34FQozEv6maN6BpteL1cP8TsCIX2s5bBVXIez+ggb77sNe/kc6eOeOSw71iw5qRXCaJKtfX/WwWKVPagNdVA0GbLyr86Zb6EAtJyTmbXeqVEhZMDcckamUNWh1bblN8Va/kRst1XgrNqB7YZp4G6VAbR6y9q6LLR2WXBUzxg8Hdrthfc/hrf3w5HTctt10+ETd/Dlmi9iNQYGbe3T3yKk6ln29vDrHwpub/ojMzoOSoGbP0NaiaXzIcNE5fP38l7dRISAds8a5kxsH7B1cez9mE1vvEdzt4/iwkI2PvTXqd0lkNYh2JM6sxQJS+uQky8H+F1NrpLXLVPH0Hfxg2AAWs/JFLTZIoW6P4KB+GQcIXAcOIbzgyPYKuRsSmf1EV7Z006mOTdmxfsQCsH+Wtnyf3AEQmGYVAR3lMPtC6A4HwDH3un8/t2ZV4FFSEVmBuPsM/hh9V+zfuZuVna8Crv2w6aXwfwq3DyfR8b7ONP8aYSiQ6fzDXpI5/5azru6aGh3EwwGcb6+DfvSG1PPTTCapDvQ2SIrvrhUplVBBoKZ2dJq9Hjk4L6x7ipFwjIh0NkEpsxkVygioKtVKoGiSznBJkZIVQCXVAD0BlAUnB8cwWo24qw+AoDVbATa8PoLqW/JjlsFIaDujBT+dz+UliA3C1bfDMsWwfRJfc5tX3wc+/xDbihPUaA4Y1MR6B08rYS1d8PR09J1eu8jbvPtZ1vRv8Id5az40A3IynNU1+Cslq2PvSI+MdxWcR31LR0Y9DqKc7Kw3ThDdq5Zx8neZlMvYY5mlqKrYmTnQ8EkqSTR76KuUsEkyCkYm65SoEdagaBPFfIEV8jnkW5QwAsma9+4KUovBXAcqGPTGzJNunHNrTFXNdEiTMo309oVYt7kNvbsMmE/+aZUgKZ2QgYTHxTchbi3nJv/ZoDzDoOx6RoNhj8A+2pkxuDDY9QKAdMmMGfVzXzx6GmULAsef5DNj//NgIdxVNfg3HsY2+K52FfcKle96K0QIFsqvzrpp3CKdIuiQh+JyOyJNVcO6R5oplw6ISLgbpcLHBiMyXn/UAg6zkNnK+hNqYezg3SBXH0tQOXzr/BeXT1CwK1zSlPWQ+3R09S+1sz8mmOUuT+Ku793LGLjx1/CYDXGB9YNxBAH3Y1ZizAgZhPcvlC+2t3M2X1AKsXmP/BPeh0HCnIJ3bFI3qQUrYmjuoZNr79LzbkmFk2fhPNAHfbFc6GrI+77JwqGokgXKhyWPatdrTLwNmfKYNKaI92kc3Vye2Y/AXm6EPTLvhOvWyYBonn/iJDjh9oapLuUkZXyOhx/eRtn1U5sN5ZiL58bU4AoUesb/RzD2wN7D8M7B5nz4XHmCAGlE8C+RtZlvlyHanXW+X7TphfL1WkRVKLd7CV5v+WmmRew5efAW/vhnYPQpfqXty+U/uW0ibHKirZYbq8fi9nIjz9zX9yNCqljjHIK5Cy2VL2ogR45RTC3ODmYDoek5cjKh4IJAweUo4GIyJintUEKf2Js4/NA61lZfpMldRIg0AMdjVT+42asJiOeQIjNT6wd+JzBEBw4CrsPyaA3GILifJwzJsF107Hde+ulXdM1bRFUoj2Pv/hLFbuPtmP79mNQNhk+ex8cqJVW4o334M+7YVIx3LEQbl8Ya7EmjpM+rL1iXlJsAcRdpjtvUxUiQWhMGSBMMpjuapcuUXa+OgYnW6YYvW7ZAWfNTQ/rEOiRLb23K9kKBAPQdkFaOYNZlr/Pb33qPu2AwFZxvcwCLUlenCt2D8vnYM/MkPn+9z+WliA3C+5eKhumWVP4yXdegPc/unRFGCJXtSJEex4Ls3tljQwGWHK9fHV5OLTlTSx7Pmb277bD77Zjn1OK/faFcMsNOOrOUPn8K9S3dDBvcnE8s2Exx10mdztkq/Oho62oopMBczgkh293tkjBt2RJQQuHZK+sNVfGHqO1jEwkLHP47Y1SUaPZr3BIxgAdFwAldTbI1y07zjwu+b+aTbIvmYe9txLsO8yLv3LyGZ3CHbsPQiAIGSaZ7r59Idww47IEvRfLVa0I0czSDV+B2vP5qTtnsq1sCoWx3nojZlc3z06fKLMTv/gjvOhkWl42d0wupikSoeacXBRg6awpNLm6pXWIphK7OuTLmistRLSzTW+Q6dSAXw46y47GGGa53e+Bs0elC5VdMHL9DkJIq9TWIN24DKuMZyIJ7lEkrE5M0qf43Xk5qQkldg+isRXELSnnW+CdQyz9827s3h4CQNvMKWC7A8rnynguDRgDiqDI1kkI2SLp9MN2JVq7LOh1Amf1rJS9lNH03V13LISKefDJFVB/Ad4+yKwd+7jx0DEeMhqozs/h0JQS9rd3sXlDL9/XoN5KT6dsIS1ZsqWP9lSbzDK16nHLPobcIsgrlgF1JCJb5M4WqSTWvCu3BGU0w9V+Xvr9ZovqyglZrtYGGSybLaBPcPciYfl9m7qogTp/PBFn9RHOd7iZEgwR+v1/w8tvwukLoCjop5Tw68wMilffxJrbFw1YREd1DbXnWynMvgxTY+VgyEFNTforgtEEk+fIrIu3S7agoSDqbBspMIpuQOUozPbR2mXpN8tgr5iX1KcAyGxF6QQsD90DR09jePsg5bsPcnNTOy69npPuLso+dRfMmpp87qhC9Hjg3DHp8hRMhKw8qcQZquB3tshXXolcNiYzW1357SwYG+V4Jmvu5VtBQwhZJleTvI9Gk3SDhIDuTrm6R6BHljczIQ4IB1UXqVGWT6cDg6Fvf0xTG18Oh/lqayezevwAHM+x0nNXBfPXraKoIJeHh1hUZ/UR9DqF1q4B1qFKRSQiBT/aaAohlTXgqxvsp+mfNUp4PkKMUFB24vR4wdclW7ZIWLZSKZTjbKsLgCm9xiMNm2CITT98kdtaOri+uR1TRNBsNLA7x0rRvbdwh21Z35Y8EpZpR71eCndOQbwvIRKJ9z/kFsmXyaxeX49UgrxiqRAXG0OEg1LwO1tkUGuQT6MkEpHWK7psjcGcvIBBj1cqTZe60odeH39OQXUNT/6/17g9P4eb2zpZ5fExLSq0s6bALTfy1PGz+HOz5MDGonF9OjAHwlFdw293H2TljbP4H3ctSb1TKqE3miAjW44INmXKBEYkArXvbebhZwZ8jtrYVITeCCGFx+8Dv/po2Z5uKYBJyqGPDeq6WKIt4VRrBl27DrDK28OyYAgz4MvKxHLHQrjlRpg9tdeAtIhM5SmKjBNyi6VfriiysgI+uU9WXvw7EZHXJIR0VbLy1Ukt5v4tRSQiR8n6vbKl93XJazbKVaYJ+mUs42qGSCi+HaTSejpl69/jkdsSBsI59h2metcBpp0+z+1dXmZ7ewCoz81iT9E4HvjyQ1A0Luk+RZMM0Q7M/nr2B6Rfoc+S1stkkYKfKtgOBq8hRUiFELGlHrdsewsCPtYvmXPZlKPy+VdodLnZf+o82ULwP/Oyubvbx6LObpkLH5cNN82Hm2+AOaXxShKoyxCqccO4EsgaJwVOCDmYLxySLXReEWTmqlYiIFOZCHkMg1HuEz1uJCL3CQXiAmMwSiGPhGXj0NkqF9xFUfsC9GrM4JPp0c4WeX90SkInWkSO79lXQ1PVHkp80u05lW2ltnQ8vwmGaDQZkxIIiQLeW/Arn38Fq9nYf8++KvRb3v0QFIX1N8/v1dIPIPSpGKIipH+McLEoimztjGae2yYzGesf/FvZIgZ8ssXzdcmWMzx85YgG2M9VyqHFzuojzK64jkXXz5AdQ+9/BP+9T/ZTZGXCwtlyUtGCWfJ/UGd31UPTGen+5BbKXmeTRSpDawNwTl5H9jgwZ0m3Rm8ABESCEA5IxVAUWfYMqxT8QI8auHeq002FVLZob3DQH49TgrJlR28AgwI+Pxyuheoj8lo6u0GvR5lawksmI0V3LeGeOyv4/vOvkG82YvYHaXJ1xwbOJSpC7/graVxR1G0ksaXPAIuV53Z9CIqO9X/32RFJq169ipAKRYlPscySJly2wkNRjuSAvHcFRzvdKn/9J0rystjr89M1Pp/7FB0P51iZeuiY7EDS6WBuKZRfB+VzZEeeoshlKD2d8rM1V7pPmdlSeMNB6GgG0aReh051aYzxLJoIS/cw6E+eVmowqSNo1Za/44KcABNQhV+nk8c40wgH6+BQHRytV6dVmmHRHKoLcnn6xDm8ej0b19zKPep1R4W6JC+LvcfOAjJtmpJIGCIR7AtnY180Ww1kVffGorb0ianaaK/7CPUtXFuKkIpBlcMr3YlBlCMaQM6eUMiuI6cw6nU0tLtxj8vhbUVBVzqe4tYO7g0LPtHiIvel1+Cl1yAvG+aVwfwyuH4GlBTI9KvHpcYGGdI9igqL0SwtgAj3egC7Il0agwmMOqk8gR7ZJ+B1y2sQIr7/uRY5WvdoPRw+EZ8XXjoePnE7LJwFc6aB0cDPn3+F424PQpDU4kcbg8rnX4nFAfbF18WEPimbZrJIoc+wxlO2abSmrKYIqeijHJMSYg6fusR8l1QSNZXr3HuY2RPyqbvQwrpbbmTvsbMY9DJYrjnXhE6nY4/Pzwfjctgxp5TNa++WLXDNSSmI76pLjuRlw4zJMGOSHA4ypRhyvWrLTzwlaFBnzEWFSQjpToVUdykqiKEQNLXD2RbZN3KqQfr8qq9Pfo4c1blgtnTbxiXPPXZU1yQPT08cJCcERMLYFs3Cub8WW/lcGdSaM5OF3mhO+wXRNEUYKgkxB1l58e2hAAR82O68HefuvVSuugX7gpmquwKVL/yBRnc3+0+dp7RI/m7fibOsfu73svf1fz8oBepCK3x8QgrpiXOw/2i89TYZYXwBlORDbjbkWiE7U05j1alWyR+EHj94/bJ1b3XFX0F1oKBeB5OL4Y5FMoCfO01OcVUUGdRueTPm5jS7pYVodXtYNH0SS8oms7nyr0jq4DSYIDML+z2rsNv+Slovgyk9xk4Nk6s3a5RAq6sTgMK83MtRpMGJBquBHhw738G5+wNsi2ZhL59L5eZX5chWnzqy9dP3YV9yfd9jeHtk632uWSpJY5ucmO72yJGz/dWbTieVpGicFPKicTC1RHYQTiqOrQCSKpvT6HKz68hpci1munrkCuDZGSZZzkc+iX3ZTWp+PiOedbpCXLY6u6qyRlEzf5EtzYgpQBSdXroFGVbsNjt2m111XYLYVrmpd/8Z16lzzJ6kDuKLhKVrsXiuHL+PAhlGGTtc33fxKjkeyAuBkPwcicgxO5YMMBv7v09C4Nj7Mc4PjlLf0oHZqOeJzX9g0+vvsLRsEruOnGT2xGLOd7gpnVAE6CguGMfG9X+FfXk/87evECNdZ6NiEYb6IEGAirkzRPUv/lEdMyK4mKT/i6/vAOCRT6xUf67Il5LwQol/dwlKN1Sia/qUFOSxteptZk8ejz/gjwn2xjW3YS+fLX3+aL9AIqLPB2L3JrqrmpJ0VB9h07b3ae7y0Or2MqEgl/PtbhRFQafTk51l4dYb5mFbcUvy8pejhRC86NwOQvDI/asAtTNNEJeD2Gf6t44gOw1Pffg8Dz/zxECnHHFF6P0gQSHEgAvXVNxwvah+0yn/iYj4jYioNwd1W0QAkfg+0QyJCLPiM0+AELz1y3+V2ZboPtHshogk/C8uSen6Klev/3ttq/zev9DY1s7+o8fVrgA9IAiHIxSNy4sdtnhcDkvnzmBvjRw2s/GTq7DfEh28JsvqeHc/m159k2aXm+JxuSy9bhZ7j5yg5vRZdHodXV4f2ZkWOru9LCu/Ab9ffdKeIti43j404Y9lnlRhHEhIL0q25LWs+MK3QIG3nvuB7NfR6eKpXkV91+nkkv3RDJ5OJ+9vNG5SdLIcu7as4+FnBmxwR8M1WkLCUzQVRSkXQuwfYP84OoXYQMLhuKfRIQRFk4f+mySFGuw9UbFU5Up6DycrXlT5ANvNC3C+U8310yZjNhnZX3sSBdDrddRfaIwVJxgMsPVCM0a9DrfHyxP//Cu++atXkorc6nLHBD4YjrD1rfcx6nXoFAiHwpQWF6oKMpOmjk5st1Vgvz1hLE+qdV6TUBuHqNBFOx/1CUIa26aPC3B/Qtrfu06JD2MvTRE/DYdIGKBnsN3SIUZw9d6guk6PAkydOOAMuytHotJdKSIC+/QbsX/68zi2V+Hc/hee+/zngAibfvUbmtvb5X4CigvGsfTG69l76DCu4yfRm3TUN7cnHS47M5NwJEzppEnx/T88zEQUNn7mAex3L092sZKsFsSEvLdFiwpxVEivQkZDEQZ9kKAQ4gXkc9aouOH69E5rXQoJyma/717sa+6JWRf7PauTLQ4i9tlRtYNNv/4tze0dSYcrLshn4yMPYV95V7JLpuh6CbhGb0ZDEV4AHlBkhQz4IMExT8xlSnCHIuG47xwdXyN7yohNPIq+DCbpYuiS3Qz7+r/Dvv7TiSeK++WRiBoHqa+w+h7yx88dU4aENibmeyf44deQ0oy4IgghXKitPTC02OASee2lX1z+g6q9qkmv2HcKKKpw6g3qM8TU4c56ozpGSJcg9CMkfKlimXAo3iMdHb0aCiTPCowFxkp8ZOoVLvMVqbMBSIcY4YqTmXkRSy5GIjL1lhjgRitcCHXgmzrJJTqpRW+MD3uIteRpNLRAUVTLMoTYJ1HRw+H4vQj61cF96urg4ZAaQydYNki2bNGAeRhcVJ1dAteEIvzsxZcA2PBIgjuRKOiJozWF+kdvkMMpzFZ1aIVJpur0hriwX82ug6LEr3Wg5ZeSFEa1LuFgfP5E0A89PiACqK5gNDBPtCy97mfKOruCXL2KEJ0RFgnz8qtOQLBhvTpWRiArwJShCnpGfNn3REHXGJwkhRlgOmk41FdZgv64sgR6iLtg8PKrDkBhw0Nrky3LFWLsKkLvligSfXRTNPBUBT1DfZCFooMJMzVBHy0GUxYhpIUOq/VpzJCNWUZWXFFidZwQ8Pd2Qy/SSo8NRfD7EjIeEOvxNZrVMfoZcjpj1EePvqJEO9RSLfGukR4oilp/0Qk5av0VT43vk2RRQnFrEvT3ildIGBggYGDnDhgLiqDXy2HPxoSsS1TQr2YfXaMvURepP6sSjfsSM2FyGmrtYIdOf0UwmOTzBTQ0BkOnA52p79L7Dz9TM9hP034+gqIoXQxBo0eQQqB1tAuRQLqVB9KvTBlCiPkD7ZD+FgFqhRAVo12IKIqiVGvlGZh0K5OiKNWD7ZNGvT0aGqOHpggaGowNRXhh8F1GFK08g5NuZRq0PGkfLGtojARjwSJoaFxxxkLWKIaiKOVA2WAT/kegHHlAWfQlhPjRKJVjyIsgjEBZ8kiDe5KKocjNWLMIK4H80S4E8ADgUm/selUIRhR1EYRqIUQVsGqkz5+CUb8nAzCo3IwZRVAUZSVQNdrlADmVNHGKqTrZaKRZQsJ8b7XVGzXS5J70Yahyk1aukaIoqR7KW4U0a1UjXdn9lSdayWqrvG4kyzQArtEuAKTXPVFXSBmS3KSVIvTnwymKEtXscmCGoih5I9HiDORTquUZTb980EUQRpo0uCd9GKrcpJUi9Ed03SNFUdLBF47e3GeBdqTvuXgUipFWiyCkyT1JYjhyo/UjaGgwhoJlDY0riaYIGhpoiqChAWiKoKEBaIqgoQFoiqChAWiKoKEBjJEONY2+qIPaKpC9pieBJUKIp0a1UGMYrUNtjBIdLqAoylYhRFqM7RnLaK7R2OUB1Sq0Q2yIg8ZFoinC2KYMOKEoylp1XoLGRaK5RhoaaBZBQwPQFEFDA9AUQUMD0BRBQwPQFEFDA9AUQUMD0BRBQwPQFEFDA4D/D/1Eszqq/c5WAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "train_vi_loss = errors.loss(mean_vi_train, sigma_vi_train, y_train)\n", + "test_vi_loss = errors.loss(mean_vi_test, sigma_vi_test, y_test)\n", + "\n", + "ax= plot.plot_prediction_reg(\n", + " x_train,\n", + " y_train,\n", + " x_test,\n", + " y_test,\n", + " x_linspace_test,\n", + " mean_vi,\n", + " sigma_vi,\n", + " f\"Train {train_vi_loss:.2f} Test {test_vi_loss:.2f}\",y_min=-30,y_max=250,\n", + ")\n", + "savefig(\"MLP_VI.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/hetero/Calibration_VI.pdf to figures/hetero/Calibration_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/hetero/Calibration_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAlIElEQVR4nO2deXhURda438pCSMISEAjisAVZVBSMLCrqpxDk55dRHgFFhFFQZBlFRccMOnwaRhwEiWw6SkQBFxQR0Jlhk4AIiEggKCKgE8IiWxKWBLKSdJ/fH/d20kl6ud3pTjpw3+e5T3ffW+fWSbpO36o6p04pEcHE5HInqLYVMDEJBExDMDHBNAQTE8A0BBMTwDQEhyilEpRSY/Rjl1IqTik1XykVU9u6OaK29LXdP1D/Lx4hIuZR6QDi9NcoYL3d+1gDslHAGIP1jAHigCHO9ABi7a/r56b7UF+HddvVNcRRGVtdwHwgxlafXj6htr9DTw/zieCYnZVPiEgOkOFOUERyRCTZXTmlVAKwU0RSgP4OrscBUSKShtbQbPdPsf9cHX31X/IqdevXYvX6vwBedHKLv4rIWBHJAB4EcvTyQ5VSUa7qDjRMQ3CA3ogc0UMptUzviti6BUP0z1H65zil1Hy79+uVUrFKqemV7tUTKKtHb3j2OqQALyqllgEpXurbVO8uDdF1idN1GaPrGwPE6EZX+Z5pQIpSagwwzcX/I04pNUZEknWDcKdTQGIaggfojRMRmSEiGbox5ABlv5p6maaVyqcBZyo39krk2H/Q7/1XtO5HZSMyynS0p0IG8ABa12UokKE/uVL09w4NTW/MKbpMlWt640/R723TO8H+c13BNATPOWv3PgbtS4+hanfFRo6T86lo/WoA7H9NdYaISIrezVpWjQHpThFJE5GxaAY7Dehvfz9H97Y99XS9Yip3deyeKvbn4vQ66hymIThB/5IfxK7roL/aN4oYtF9sgLN6tyPWVsbufQxaV6iHXRXJaF2LWLt7oHdFAL7QuzSxwFmbodjpUOHp4khftCfKi3Z6DdF1SKWiQTt6Un0BRCmlhgBLbV0dO/0+t6trmf46HW3wvMzB/QIapc8AmJhc1phPBBMTTEMwMQFMQzAxAUxDMDEB/GAIuoPF6by3Pu0Wp89GmJgEBD43BCchAID7sAITk9qiprtGLsMK9KfFTqXUzuuuu04A8zAPQ8fZgrPSJLGFqClK1BQleEhtjxFy7D/oLvseItIjPDy8llQyqWv859f/cNWbV5GjsmEbcNHze4T4XCvXuAsrMDExjNVqZdRXo/hwz4coi4JPIfhQAhZWve3pvfwyWKZSCICdW95hWIGJiacczT1Kuznt+HDPh0SFRtH046Z0VNdjsfwdWbf3KU/v55fBsojcpEdc2s4l66+2iMU0PW7dxMRj3k97nw5zO/D7+d+5O+Zu4n6M4/yx8wwb9iEQ5tU9a7prZGLiNaXWUu5dci9rD64lWAXzbvy7NE5vzLBlw3jttdc4erQ7jRp5d2/TEEzqBHsy93DX4rs4W3iWqxpexXePfUdoQShd7+7KzTffTEJCAr17Q48e7u/liNqeNTIxccu0LdPo/m53zhaeZfj1wzn67FHaNG7D6NGjKSoqYvHixZSWhrBnD/Ts6V0dPn8i6APjDMrXu9pfi0JbsJ7m6LqJiT15F/Pou7gvqSdSCQsO49PBn3L/NfcD8N5777FmzRrmzp1Lp06d+OEHKC313hB8mgkASEDPnADMd3B9OuUZF5a5utdNN90kJpcvK/evlLBXw4RE5Jq3rpHs/OyyaxkZGdKgQQPp27evWCwWERGZN08ERI4cEZEAyGLh0nNciahLIh+Oic954l9PcP/S+ym2FKNQbB61mWYRzQDNdzBy5EiCgoJYuHAhQUFaE05NhehoaN3auzr9PUbIqfR5GmBbNti0cmH7EIvs7Gw/q2YSaGTlZdFpXicW7F5Qdk4Q1h1cV/Z59uzZbN68mTlz5tCmTZuy86mpWrdIKe/q9rUhuPQci7buNVk0H0OGg+tlIRbNmzf3sWomgcznv3xO69mt+e/Z/9Lrql4E6U0zJCiEAR0GALBv3z5eeukl7rvvPh599NEy2QsX4MCBaowP8P1gORl4UGlmWWFBuogk612hOKXUWZznyjG5jLBarQxdPpQv9n1BkApiZv+ZPH/r85wuOM26g+sY0GEAzSKaUVJSwiOPPELDhg1JTk5G2f3079oFItUzhFpPtefsMAfLlz7pZ9Kl5cyWQiLSfEZz2Ze1z2nZxMREAeSLL76ocm3GDG2gnF0+nq71wbKJiSHe2vEWnd/qzKm8UwzsPJATz5/gmubXOCy7a9cupk6dyvDhwxk8eHCV66mp0K4dNGvmvT6mZ9mkRrlYepEBHw9g05FNhAaFsmjgIkZ0G+G0fFFREY888gjR0dHMmzfPYRnbQLk6mIZgUmOkHk+l/0f9yS3OJSYqhi2jttCqUSuXMpMnT2bfvn2sXbuWJk2aVLmenQ2HD8Of/1w93cyukUmN8LeNf6P3gt7kFufyROwTHHzmoFsj2Lx5M2+++Sbjxo1jwIABDsvs1POAB9wTwVWIhX49Di3dYIyj6yaXFuln0+n3YT+O5h4lPCSclUNXMuBqx43anszMTB5++GHatWvHG2+84bRcaqrmO7jppurp6dMngrc5/00uTZbuXUrHeR05mnsUgL3j9xoygry8PK6++mqOHz9OcXFxmffYEamp0KULNGxYPV1rNMRC3OT8Nz3LlwZWq5XHvnqMh5Y/VOH898e/dytbWlrKQw89RF5eHgC5ubkcOHDAYVkR3wyUoYZDLNzl/BfTs1znOXb+GDFzY1j440IahTVy6CF2RnFxMUOHDmXVqlU0adKEyMhIoqOj6dKli+O6jkFmpm8MwdAYQSnVSETOK6XaoaUoP++kqJGc/zP0e2KXf9/kEmDxj4sZ/e/RlFpL6de+H6sfXs35i+creIidUVBQwKBBg1i3bh1z5sxh9OjRHDhwgC5duhAREeFQJjVVe/WFIRgNr34C6Au8DgxyUS4Kbb1B5Q3wxuivMWg5+itcd3SYnuW6Q4mlROI/iRcSkeApwfL2jrc9ks/NzZU77rhDlFLy/vvvG5abNEkkJESksLDKJc+XEBgqBDeixQY1Bvp6U5Gnh2kIdYO9mXul2YxmQiJy5cwrJf1MukfyZ86ckZ49e0pISIh89tlnHsn26ycSG+vwkt9CLJqiTXk2Bao5UWVyqTDjuxnc8O4NnC44zbCuwzg28RgdmnYwLJ+Zmcmdd97Jnj17WLFiBUOHVtmqzSlWq+ZD8Em3CINjBBHZAGzQPzqf1DW5LCi4WEC/D/ux/fh26gXXY8mgJQy+tmoMkCt+//13+vXrx/Hjx1m1ahX9+vXzSD49HXJza9gQlFJ9RWSjPliOFZEV3lRm23tM6tjWoyblbD68mfhP48m7mEeXK7qw5bEtLgfBjkhPTycuLo6cnBzWr1/Prbfe6rEePh0o48YQlFKD0RxjMUqpBwCFlnTVqSG48SzHAWP1WPKmwBNilwjMJDCxrQ3YfGQzybu0vdSf7f0ss/7fLI/v9csvvxAXF0dJSQkbN24kNtbVal7npKZCeDhce61X4lVwaQgislwpZfMC23ZhdDpG0D3LKSKSpm+6XdkQ0kSkv142zjSCwOd0wWmi34jGihWAyNBIVj+8mjva3eHxvXbt2sWAAQOoV68emzdv5tpqtOLUVIiNhRAfBQm5HSyLyCG0p0I/tF90V/sauPMs27ZIHSIONrk2PcuBx9TNU8uMAGDOPXO8MoLvvvuOvn370qBBA7Zs2VItIygpgbQ033WLwHjQ3Wci8iOAPk4wSo6T8z1xsDG1aDlSkwF69OjhcY57E99htVoZvmI4n/3yWdm5kKAQBnYe6PG9UlJSGDhwIH/4wx9ISUmhtbepJnR++QWKimrHEGYopQTIBdqjNWRHuE37brcZtkmAcujcIW5beBsnLpzgivArWDF0Bb+f/92td9gRy5YtY/jw4XTu3JmUlBSio6OrrZ9toNyrV7VvVYZRQ5iuT6GilLrRRTmXi/f1j1HAQc9VNakJ3kl9hwlrJmARC/Ed4/nyoS8JCfK8I/7bb7/x8ssvs3TpUoKCgrhw4QINqxsiqpOaCk2aQAfjLgv3eOqBAxp547nz9DA9yzVLcUmx9FvcT0hEQv4eIh+kfeDVfTIyMmTkyJESHBws9evXl9DQUAEkMjJSdu3a5RNdu3cX6d/fZRHfepaVUqP113eVUkuVUp9T7lgzuURIO5lGdFI0Gw5toG3jthx65hCjbhzl0T1+//13xo4dS6dOnfj000+ZMGEC+/bto3Xr1m6jSD2hsBB+/tm34wPA9RMBaKy/9rM7d6M3FufpYT4RaoZXNr4iKlEJicioL0eV5RI1yokTJ2TChAlSr149CQ0NlT//+c9y7Nixsuv5+fmya9cuyc/P94m+27aJgMjKlS6Led7TMVwQugPdvanEm8M0BP+SW5grN757o5CI1J9aX1b9tsoj+aysLHn++eelfv36EhwcLKNHj5bDhw/7Sdty5szRWq2drTnC4/ZmNMTiCaCD/r6HiCxwUdbdmuUhaNOqsaKvTTCpWdYfXM/AzwZSWFrI9S2uZ/OozUTVjzIke/bsWWbOnMncuXMpLCxkxIgRvPzyy3Tw6cjVOampcOWVcNVVPr6xEWuhYteon4ty7tLCx2G3NsFVneYTwfdk52dL38V9hUREJSp5MeVFQ3L5+fmyadMmeemll6RRo0ailJKHHnpI9u/f72eNq9K5s8h997kt5p8nAlqskT57S3sX5So4ypRSsVIxjKI/cNC28ya688zE/+zN2sv171xf9nnN8DWGF9K3a9eOM2fOADBw4ECmTp1K165d/aarM3Jz4ddfYYTzfGBeY3Q9wufADOA9HCy6d0FOpc9RaFmwU4D+tmhUG2aIhX/45OdP6P5u9wrnTheediuXnp7OnXfeWWYE9evX5+WXX64VIwAt2S/4YcYIg4YgIrkiMk5EhorIYRdF3XmWd7mpx1y870OsViv3f3Y/I1aMQBDDC+ktFguzZs3ihhtuID09nWbNmhEREUGrVq18MgXqLTaPsrcbBrrESP8JrTv0ObAUaOeiXBQu1ixL+Tgizv6co8McI1SPfVn7pPmM5kIi0nJmS/nt9G+SnZ8tH+/5uMI2TJXZv3+/3HLLLQJIfHy8HDt2zOdToN4yeLBITIyhov6ZPgVe0I2hPfAXbyry9DANwXuStiVJ0JQgIREZ8vkQQ76BkpISef311yUsLEyaNGkiH330kVit1hrQ1jht2ogMHWqoqN8Moa/d++76q19DLUxD8JzCkkLp834fIRGp9/d68tnPxhbD79mzR/RoXxk0aJCcPHnSz5p6Tmam1lpnzjRU3OP2ZnTWaJLdCrX2SqlD+tPB/bSDSY3w3dHvuOeTe7hw8QIdm3Zk66ittGjQwqVMSUkJ06ZNY+rUqURFRfH5558zZMiQCrvRBAq+XppZBSPWggPfgaNzvjzMJ4JxXvj6hbIwiadWPWVIJi0tTbp16yaADBs2TLKznY8bAoFXXhEJChK5cMFQcf88EUQPwXZ3zgi27HZmlrvqc7bgLLcvvJ19p/cRGRrJv4b9i77t+zotX1BQwJ49e1i5ciVJSUm0aNGCL7/8koEDPV9sU9OkpsI110CDBv65v1NDsKV59PSGrkIsdL/BfKVUBg5yn5oY43TBaf6x5R/8M/WfFFuK6dmqJxsf3UiDes5bSUFBAZ06deLEiROICCNGjGDu3LkON98INERP9hsf7786XPkRXrS9UUp1t3vfzpmAu7TwOn8VkbHm08A7svKyaPFGC2Ztn0WxpZgXb3uRHU/scGkEAG+99RbHjx9HRAgLC2PixIl1wggAjh7Vdsbx2/gA1yvUdiql3tXfxyilzqEPlnG+VNNdiAVAD6VUU7RYowohFvrTZAxQYTNpE40jOUeITY5FKF/OfV2L69zKzZo1i0mTJhEWFkZwcDAtW7asVceYp/h9oAzGc586eu+g3DL0YDpgPi4C64D1ruo0B8sVmb9zvgRPCS4LmLOtJHPlHCstLZVnnnlGABk8eLCcPn06IBxjnpKQIBIaKlJUZFjEb4Pl3bbukYjsdlHUZYiF/ov/uZiZ7gxTai0l/pN4vs74mpCgEBb8cQEDuwx0m2q9sLCQ4cOHs3LlSiZOnMjMmTMJCgriiiuuqOG/oPqkpkK3bhAW5sdKjFgLWlr41/VjtItyUbhOCx+lXzNDLAyw+8RuafJ6EyERaf1mazmSc8SQXFZWltx8882ilJLZs2f7WUv/YrGINGokMn68R2IePxE89iPgZ/+B7bjcDWHKpillXaBHVz5qeAnlf//7X7n66qulfv36snz5cj9r6X/279da6cKFHon5p2uE8fUIJtUk72Iedy66k10nd1E/uD5LH1jKfZ3vMyS7fft27r33XkSEDRs2eJVcN9CokYEyxvMa2dYjNEHbA83ED2zI2MB9n91HQUkBXZt35duR39I0oqkh2S+//JKHH36YVq1asWbNGjp27OhnbWuG1FSIjNR2zvQr3jxGauK43LpG4/8zvmxGKOHrBI9k586dK0op6dWrl2RmZvpJw9rh5ptF7rjDYzG/dY0M427xvl4mFnPDcQBO5Z3itg9u4+C5gzSs15B1I9ZxS+tbDMlarVYSEhJISkpi4MCBLFmyxOnGe3WRkhLYvRueesr/dRneXlYp1d3ew+ykjBHPMmizRsae+ZcopwtO89Tqp2gzqw0Hzx3k9ja3k/VClmEjKCoq4qGHHiIpKYknn3yS5cuXX1JGAFoir+Ji/48PwPiOOUbTubj1LOsL91MAfyy4qxNk5WXRMqllmYf41bteZfIdkw3LHzt2jHvvvZcff/yRmTNn8txzzwVk6HR1qamBMhgfLGeIyHsASilPNrvKsf+gG0ZK5X0T7K5f8iEWv57+lV4LelUIk2jfxPhE3Pr164mPj6ekpIQWLVowfvz4S9IIQDOEK66A9jUwT2m0axSjlGqklGqElobFGUbTwscBN1XOYiGX+OL92dtnc+0/r+V88XkUWuM1siM9aLlFhw0bxt13301paSkA+fn5HDhwwK861yapqdpC/Zqwc0/SuUxHS+fyg4tyyWhBdbFUSgsPICJp+vih7vn5q0FRaRF3LLyDiesmEqyCWTJoCVkvZPHxoI85+fxJl3sOFBUVMXXqVLp06cKXX37JSy+9RNu2bX2aWDcQKSjQNgSpiW4R4DYJ8Gj99V00Y/gcSPVmesrT41KZPt3++3ZpNK2RkIh0mNNBTl4wth7YarXKypUrpX379mVBc4cOHRIR3yfWDUS2btU8yl995ZW4x+3NnSE01l/NbNheMGn9pLIwiXH/HmdYbt++fdK/f38B5LrrrpMNGzb4UcvAZPp0rXWmp3sl7ltDqM2jLhvCucJz0vXtrkIiEvFahHyd/rUxuXPn5Nlnn5WQkBCJioqSuXPnSklJiZ+1DTzy80UiI0WU0vIYefHg848hoKdzAdoBg7ypyNOjrhrCvw/8W+pPrS8kIjfNv0kuFLtfbW6xWGTBggXSvHlzUUrJ2LFjJSsrqwa0DUz+7/+0lgmaQXix0Y7H7c3ohuPtjW44frlitVp57F+PsfinxSgUif+TyCt3vuK0fEFBAQcOHCA3N5eEhAR27txJnz59WLt2rdebcF8KJCfDq69qm4krBdHRNRBnBO6fCGjRpoZDr9H8AHHYrUeodD0OGOLsuu2oK0+E7Pxsmb19tlyVdJWQiDR5vYnsOuH6Jyw/P1/atGkjISEhAsiVV14pn3zyScBllqtp3n5bewrEx4ucPas9CbycD/BP16iCgItdc3C/P0LZgh1gl6t66oIhZOdnS1CillqRROR/Fv6PFJcUu5QpKiqSp59+WtCerBIaGipbtmypIY0Dl7lztdZ4330eLcl0hsft2pAfQXemva6UWgqMdVG0J3be5MoeZNHCLVJ0v8I0B/XUmbTwpdZS4j6Mq7Ar/RM3PUG9kHpOZVatWkXXrl2ZO3cuERERhIeH07p168u6KwQwezY8/TTcfz8sW+bnJZlOMOpQG4O2GD8Zu1giA+RUPiHaeuUUYKiDa3XCs7w3cy8tZ7bkp8yfys658hD/9ttvxMfH88c//pHg4GDWrl1LdnY2W7du5eeff77kguU8ISkJJk6EwYNh6VKo5/x3xL8YeWwAg4F+wDu4mDXCfdcogfIsF7vQQrXrVNfoH5v/UZZp+uEvHpbMC5lOU63n5ubKCy+8IKGhodKwYUNJSkqS4mLXXafLiddf17pDDz4ocvGiT2/tvzEC+qAZeMJFmShcL96PsV0DElzVF2iGkF+cL72SewmJSNirYbJi3wqnZS0WiyxevFhatmwpgIwaNSogM0zXJq+9prW+YcNE/OAq8Y8h4GJzEH8dgWQI3xz6RiJfixQSkS5vdXGZS2jHjh3Su3dvAaR3797yww8/1KCmdYMpU7SWN3y4X4xAxI+G8Be7937dF8F2BIohPL366bJZoYlrJzotd+rUKXnssccEkOjoaFm0aJHHm3df6litIi+/rLW6Rx4RKS31W1V+MwTbtlHvAuu8qcjTo7YNIfNCpnSa10lIRBr8o4FsOeJ4ivPcuXMyceJEadSokYSGhspf/vIXyc3NrWFtAx+rVeRvf9Na3KhRfjUCET8awmWV12jZL8uk3qv1hETklgW3SH6xY6/Ot99+K6GhoQJIeHi47N69u2YVrSNYrSKTJmmtbfRoLWmXn/GPIXh0QxeeZcoz3QXkYNliscjQZUOFRCRoSpDM2DrDYbmCggJ54YUXJCgoSJRSAkhERITs8iIo5lInL0/kT3/SWtq4cTViBCK1bQgGpk/HBOr06Y5jOyTq9SghEWk+o7nszdzrsNzWrVulU6dOZbNB7dq1k8jISImJibmk1wd4w8mTWrpG0F7z8mqsao/bruEsFgZx51lOFrvlmxIgyYCnfzedXgt6kVOUg0KxZ/yeKunW8/PzeeaZZ7j99tspLi5m/fr1fPDBB/zyyy9s3rz5sneM2SgthTVrYNgwaNcOzp8vP//rr7Wqmmu8sR5nBwbTwmPnWKt0fgywE9jZpk0b//xW2FFcUix3LbqrbFbIdny85+MK5TZu3Fi2UuzJJ5+UCwY38rpcsFpF0tJEJk4UiY7WngBNm4qMGSPSqpVIRITX6wq8JbC7Rvr5ODTHmtO9E6QGukY7ju2QxtMaC4lImzfblAXP2e85cP78eRk3bpwA0qFDB/n222/9qlNd49gxkRkzRLp21VpSaKjI/feLrFhRHjiXn1+tKFJvqXVDiMK1ZzkObWywnlqMPp28YXLZEsrRX40Wi8VSZVf6tWvXSps2bUQpJc8995zZ/9fJyxP56COR/v21FWSgpWX85z9FTp+ube3KqF1D8OXhD0PILcyVbu90ExKR8Knhsvq31VXKnDt3rswx1qVLF9m2bZvP9ahrXLgg8s47Ig8/rK0YA5F27bSVZL/+WtvaOcQ0BGes/m21hE8NFxKRbu90k9zCik6v/Px8mTVrllx55ZUSHBwskyZNksLCQp/qUNc4eVIkMVEkJERrKUqJjBwpsnlzjU2DeotpCJWxWCzy+FePl2WanrxhcpUyx48flwYNGggg9erVu6wXylgsIl9/LTJ4cLkBBAVprxERXq0frg08bm8+z4YdSBw7f4zbPriNI7lHaBzWmJRHUujRqmLK1e3btzN48GDy8vIACAkJuSynQbOyYOFCbc1wRoaWavGZZ+BPf4JBgyAzswbXD9cG3lhPTRzVfSIs/nGxhPw9REhE+i7qW2UJpcVikWnTpklwcLC0bdtWWrVqddk5xqxWkQ0btPUAoaHar/4dd4h88omIfa+wlmZ+qkPtd40wtnh/urv7eGMI2fnZsujHRXL3R3cLiUjwlGB5e8fbVcqdPHlS4uLiBJAHHnhAzp07d1lkj7M16CNHRN54Q6RjR60FNGki8uyzIvv21baGPqN2DQEDfgT92jJ39/LUECovpI9+I1rSz1RNk7ZmzRpp3ry5hIeHS3Jy8mWTOSIzU6RlS5HgYCnLGdSnj8iHH4oUFNS2dj7H47br6zGC2/0RXFGdtPAT1kyosJD+jf5v0KFph7LPFy9e5KWXXiIpKYmuXbuyadMmrr32Wo/qqEsUFsL338M338DGjfDDD2CxaNdCQmDJEnjggdrVMZDw92A5x5PCIpKMliCAHj16iJviABRcLCDuozi+P/Z92bmQoBDu6XhP2ef09HSGDRvGzp07GT9+PElJSYSHh3uiWsBz8SLs2KE1+m++0YyguBiCg7XU6hMnwscfa7E/LVtCfHxtaxxY+NoQ3O6P4Eu2Ht3KPZ/cQ97FPDpf0ZmVQ1eSdiqtwo70S5YsYdy4cQQHB7N8+XIGDRrkT5VqhIIC2LsXiopg2zat4W/dqp1XCrp31/Ydu+suuP12aNRIk5syBQ4c0GZ+LsOJMdd4059yduAmxEKkQphFrKt7uRsjTFw7sWw88PTqp6tcv3DhgowcOVIA6dOnjxw+fNiDLmbgYLGIHD2qze68847IhAki4eHl/XwQue46kaee0mJ8zpypbY0DgtodLPvycGYI2fnZ0uWtLkIiEvlapHx7qGog3O7du6Vz586ilJLJkycHZEZp+ylJq1Xk1CmRLVtEPvhAW801eLDI9ddXbfRhYeUxPmFhmvPLpAq1Plj2Kyv3r2TY8mEUW4rpfVVvNj6ykYh65c/4/Px8pkyZwpw5c2jWrBkbNmzgrrvu8qtOZ89qffNWrbTPhYVaF6WwsOJ7+3O5ubBoEeTna334+vXL4/ZBG8x26AAdO0JcHHTqpL3v1AmaNIFu3codXH36+PXPu2yoE4ZgtVoZsXIEn+79lCAVxLR+05h02yRAa/wbNmzgq6++YvHixVgsFiIiIti2bRtt27b1mQ4icPIk/PSTduzZAz/+CPv3e3YfpbSUhkVF5efuuQduvbW8sbdtqxmDM37+2ezr+5qAN4RD5w5x+8LbOX7hOFeEX8GmRzcRnhfOvHnzWLVqFZs2baK4uLhCWIQInDlzxiNDKCgob1zBwVoDt2/0P/0Ep0+Xl2/bVjvS07WNscPCtHTmXbtqKc3Dw7VGav8aHq6VKyyE668v/1X/4APPGnREBFzm6VJ9jhIxNEtp/IaaLyADbT1ylTyp7q7b6NGjhzz+zuNMWDMBi9VC74je3HLqFtasXsOv+pq/zp07Exf3v8TGxhMe3pMRI17Dag0lKKghkydPJDLSWCLNixfhzTe17klwMFit2tJC0Bpv165ad+SGG8pfo6I047Fv0D//bLxB2xue+avuczzeh9OnhqCUSgBSRCRNKTVfRMZ6ct2eeq3CpGTMRVS+ImRFPUoyigkOrkd09J00bBhPaen/kpV1NRcu+Ex9QDOERx+Fu+/WGn3Hjto5Z5gNOiDx2BBq2rNs2PNcwkXYC/KfaEqK7wXisVj6AQ1o3Bj+8IeKR7Nm8PjjcOYMtGih7dHrya9zz56Qna39ss+bZ1zW7KZcGtS2Z7nCdfsQC4KB7wB1JhveO6pt8QwnTmjHjh2ObhcUBCEtDx8uPdW8udXqqIRzNNmMjNJTkZGeytIMOO22lO/kaqPOuqTrXhHp6olATXuWXV4XuxALpdROOSkVFw8YQCm1U8RzuerI1rRcbdRZ13T1VMbXeY2SgR56PqP1tpP6L73T6yYmtY1PnwiiJexK1j+m2Z1PdnXdxKS28fUTwZckuy/iU7naqNPUNUDkfO5HMDGpiwTyE8HEIEqpGPtXP9cVpZSK8nc9NU1AhFh46402IBcH9BeRv3pSp/5Fx1CemnKGh3VGAXiiq14mVq+vsj7udJ2vlMoAphuV068PQZvCjvXgb4wDxiqlAJqi7amXZlDW9v85W/nvNPA3jkEbV1ZuA06/YyP6lOFNyKovD4ztxFnlujs5u3JV1kcbqNNh+noDcg43VPdA1wTs1m4YrDMKB2s7DMjFYbfJowdy9mXjvKjT9v9J8EBuuq0uJ9+nwzXwRv/vIr5PC+8NLlPJu7juTs7rOsV5+npvN1R3q6v+y5biqa46PZRScXbT1Ebk+tvVG2dUzvZ/UUoNEZHK+rqTTQFeVEoto+Lf6sl3GeVBF9DwfQPBECqT4+V1d3Ie31OPjXK1xL2KnLjYUN2ZnB5q4sgI3MqKSI5uuCke6hoFZOhy/V30+yvL2ejpTtHKsnoD/iuaD2m6IwEndU4DYvWG3NRAvUbvW0YgGIK33ujqrI92K6v/UlbuU7qUU0olKKVi9PMxdo3LaH1xwE2VGqW7Osc4acTu6tzlQMYTXb2RHSIiKaL5lZbZ/bK7izjIAZL1J26GB9+18Tbiqt9UEwfuU8k7u+5Szq5PWmV9tIE6HaavNyDncEN1I7ra9YXnY7ellsH/Tyx2fX4P/j8JXsoNqay7B/+fIZ5+l7rcGJtspTqrfMfu2o6jw/QjmJgQGF0jE5NaxzQEExNMQzAxAUxDCEiUUus9DZfQ/Qjz/aXTpY5pCIGJobUaetzPGChzVlVnjv2yJiBijUy8Qyqu7zCpBqYhBAh2gWUp6EkO9O6RLUgtR0RSdG93mn6uKfCAVM0WUkWuxv6QOorZNQocpgNf2Lynducy9OMBmxHoDbuHi+5QBTm/a34JYBpC4LNTRNL0X/2e6EaihykYlTNxg+lZDhD0rtGDwE60X3RbhOZYYKleLAety5OB1u0BLc9NP7QwBNv7pvZy4sGuRZcrpiGYmGB2jUxMANMQTEwA0xBMTADTEExMANMQTEwA0xBMTADTEExMANMQTEwA0xBMTAD4/3U3k+JS5EOnAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1)\n", + "_, df_test = plot.calibration_regression(\n", + " mean_vi_test, sigma_vi_test, y_test, \"Test\", \"blue\", ax\n", + ")\n", + "_, df_train = plot.calibration_regression(\n", + " mean_vi_train, sigma_vi_train, y_train, \"Train\", \"black\", ax\n", + ")\n", + "ax.set_title(f\"Train {errors.ace(df_train):.2f} Test {errors.ace(df_test):.2f}\")\n", + "savefig(\"Calibration_VI.pdf\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1130,7 +1242,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.13" }, "vscode": { "interpreter": { diff --git a/notebooks/regression/motor_data_comparison.ipynb b/notebooks/regression/motor_data_comparison.ipynb index 9dab523..1df0881 100644 --- a/notebooks/regression/motor_data_comparison.ipynb +++ b/notebooks/regression/motor_data_comparison.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -17,29 +17,21 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "from models.gaussian_mlp import gmlp\n", "from models.mlp import mlp\n", "from utilities.fits import fit\n", "from utilities.gmm import gmm_mean_var\n", "from utilities.predict import predict\n", - "from utilities import plot,errors\n" + "from utilities import plot,errors" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -1070,6 +1062,116 @@ "savefig('MCMC.pdf')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VI" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "from utilities.vi_helper import vi_model,vi_predict\n", + "params = [[16,16,1],[nn.relu]*2]\n", + "\n", + "mlp_model_vi, vi_model, results = vi_model(params,x_train, y_train.flatten())\n", + "\n", + "mean_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).mean(axis = 0)\n", + "sigma_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).std(axis = 0)\n", + "mean_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).mean(axis = 0)\n", + "sigma_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).std(axis = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/motorhelmet/MLP_VI.pdf to figures/motorhelmet/MLP_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/motorhelmet/MLP_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAz0klEQVR4nO2deZhcVZnwf+fce2vtfU13OntYE7YkuCKLhAhKZHQUxlFnXInADM7ojI4jKs6oqJ/LpyNgcBZH5psx4DKCgCOBEdxAIMgmW8hGpzvpvbv2uvee8/1xqju9V1V3dac7qd/z1NNdt+5y6tR973nPeTehtaZMmeMdebQbUKbMQqAsCGXKUBaEMmWAsiCUKQOUBWFeEUJ8TAhxZe71mBBisxBiuxBi9dFu2/GOfbQbcJyxS2u9UwhRA7w99/+jQF5ByB1zudb6lkIuJITYAKzWWv9gks82AzUAWusf5M69eviltf5yYV/n2KE8Iswvj47foLUeAPbkO1BrPVCoEOTYDNSN35gTkJqcgHwit/lyYCC37YqcYBxXlAVhHsnd9JOxSQhxe051Wg0ghHhb7n1N7v1mIcT2Uf/fK4TYIIT40viT5Z74O6dowy5gpxDiSuCG3LZbtNZ7Ru0zVTuPWcqCsADQWu/M/f2y1npPThgGgJGndm6funH77wJ6c095wDzxhz+f5noDGEG5YvR2IcTHgLeX5lstLsqCsHDoG/X/aswNOay3T8bAVCfKjQibgY3j1ZzhUSc3AqwePeJgBO+4pCwI88zwpBdzE27ObdvMqJsSc/Pfm/u/L6cCbRjeZ9T/q4GzgU3D59da78qNCPXjrntl7t8fADVCiLcBO7TWA7nrfwnYDtxe8i+9CBBlX6MyZcojQpkyQFkQypQByoJQpgxQFoQyZYCyIJQpA5QFoUwZYAE73W0+c5O+45vbZ3y8TqVBCAInrkQECvuaWmtULImzshUZCubd//zzzwfgF7/4xYzbeaxTTB9p38fd14GwJcJxir6W9n2y+zrQA0PI6ioAVCzO4fdc94o13b98ZLpjF6wg9MaGZnW8CIdQiQTZve0E1i5DWFb+Y4RABh28w304y5cghJh2/wsvvHBWbTweKKaP/K4+0LpkQlAMC1YQSoGMRlFDMdwDnTgrl+a9sQFEMICKJVCxBFZVxbT7fupTnypVU49ZCu0jfyiOPxDDqp6+zydjtkIAx8EcQVZV4vcO4HV0F35MJITf1Yf2vDlsWZlhdCaLd6gXWREu/ljfJ7t3dkIAx4EgQE4YDnbhdvXl3xlG1Ci/d3Da/S655BIuueSSWbfvWCZfH2mlcDu6kY5VkPo65thhIRicnRDAMa4aDSOkhOpKvP0diICDXVOZ/5hwEL9/CFlVgQxPPnFOpVKlbuoxR74+8rv7wXURFZGizqt9n+yeDnQsNmshgONkRAAQlkRWRvF2H0Al8t/AQghkKIB3uJeyY+Lc4McS+H2DiGhxKpH2fLIvtRshqMr/UCuE40YQAIRtIyIhsi/uR6Uy+fcPOOhMBjUYm4fWHV/orIt3qAdZES5oEWPkONcnu/sAOpEsmRDAcSYIACIQQFg27osH0Bk37/4yGsbr6kO75YlzqdBK4R7sQtrFzQt0xiX74j50Oo2sLH51aTqOiznCeEQ4iE4kyb50gMAJKxHO1D+GkBIhJX53H3Zr05jPLr300rlu6qJnsj6aybxAZbK4Lx4A30dWlFYI4DgVBAARjaBicbJ7Xs5rcJOREP5QAlGZwKqMjmz/m7/5m/lo6qJmfB/5Q3H8/iGsqugUR0xEJdNGCIRARIubVBfKcacajUZWVpiRYV8HWqnp942E8A/3lm0Ls0ClMngd3UXZC1QiRfb5vWBJRCQ0Z207rgUBjI1B9Q/hvnxo2tUhYVsgwO8ZGNl2/vnnj/jSlJmc4T7SrofX0YUMB8xydgF4g3Gyz+1FBAOIAny/ZsNxLwgAsroSv6sP72D3tMIgwyH8gRh+LDGPrTs2cDuNZb9QPyKvZwD3hX2IaBgRCMxl04CyIAA5m0FNFX5HF14e67OMhPAO9ZRXkYpAZ13IZKc0TI7HPdyLu7cdWVWBsOdnGjtvk+Vi82vqrItOphCR4v1PZoIQAmoq8fd3ghA4TROyJZr9bAuRFXiHe+alXYsd7Xloz0cWYDTTWuMd7Mbv6EJWVyKs+XtOz+eIUFR+TZ3OkLjvIdJPv4BOzo8rg5ASkXPF8Pqm9jOSkRAqnkR7/ry0a7HixxKQ9Qq6obXv4+7vwDvUhaitmlchgHkcEcYnsJ0sv2YuCdWVAEsrqrFqq/D2deLtPYi9ehnB1W2I8NytHEDOFaOqAveldpBySr8kWRHhjy/cglVXPaftWawMrxC97c2XQZ7JsXZ9snvb0bE4srqqKEtzqZj3OcJ0+TVzyWg3aa031YciICys2pxA7G0nsfO3pJ/ZbaLP5rKNloWsjODtPoA/NPnEWEjJVe99Hx9881vR/sIaGa655hps2+aaa66ZdttcoTNZ3PZDyFCAq973fq56z3un3FdlsmRf3GtcJo6SEMA8Z7rLpRbcAzA6+/JknNHYqu//yD+O3ah9/IEYaI29qs2MEHM4h9Cui06kcE5aOcaQNkwylUQlM1Q01GEtaThqP+J4bNvG930sy8LL2T0m2zYXaNfDPdCJsCQi4JBMJQGIhCcawlQihbv7ACCKdrwrlEJDNedtRChJfk1hYdXW5FSmjpE5RCHepDO6nOMgIiHcF/aj4skJn2999zu5bNv78Afj+AOzCy0tJdu2bcOyLLZt2zZmG4BSas5GBe15uO2HQBiHRTB9tPXd75ywrzcQI/vcXrCtOROCUeS9z+dNELTWO7XWG7XWF2mtN87qZKNVpn2dJO77LeknnkclJt6ss0UEAohwkOzz+6YUOFkRxj/ch0rOrcpWKDfeeCOe53HjjTeO2WZZFlprtm+feVKEqTBC0AW+mjbxgdbaLI++uN/YCIJzYyjTWuP39pN6+EmAN+bbf3HbEXICYdfV4LUfMiPErmdRsXhpLzMsDM/tnXRkEFIiw0Hcg4fRmWxJr11KJhspSsGIEHgecho3iJGVof2dyKrKObERaN/DPXiY9AO/I/nrx/EHElDAotCx4XQnJFZNNWiFd7gbr70Tq7WJwNoVWAVEoxV0iZx1M/v8PpwTVkxwGhOOjdQat/0wzvIWhLPwuvbGG28cM0qUAu15uAe7wZ9eCFQmi7v3oJkU15Z+UqxTadyDh3Ff3I/KeljRCHZ93fBDcXpHMo4VQRhGSKyqKtAK1d1H8mAX9pIGAieuwKqd/TKnCARACNwX9qHXLp/kcwedyhhhWNY8b1bRo4V2PdyDXUYIplnW1p5P9g97zMhZgrDK0fgDMTPKHOhAI7Aro8iq4tPBHJu/lJDIqiqk1qi+QZIPPobdWEfgpJXIuupZPY2E40ClwH1xP++6+FLkuNUkGQ4at+H2LpylTUdtZLjmmmvYvn0727ZtK/koAMby7x48DEpPKQRaa975hjehegcQQadkPkPa9/C7+nFfOoDXN4iwbKzaKhDFBf+PZsEWCpl0+XSmaI1KJNHpDFZ9Nc7Jq7Hqa2YlENr30YNxrLYm7JbGCecaDgW1W5sK9rEpJXO5XKqSRg2RtoUITn5z64xL9kAnanAIWVkad4kR9Wf3AVTWRYbDyDzL5yoWZ+g7P/jcmu5fTptgaXFPlgtFCGRFFKuhDpXKkPr146QfeASvs2vGxjBhWfQKn8PPvoi7v2PCeWQ4iLAl7oFOvJ7+eTe6zdXE2B+M4b58CBlwphQCbyBG5tmX0Mkk/ehZZS3USuH39JN69GniO39L9tm9yFAIu74urxAUw/EhCKOQkQh2Qx3a80n/7mmS//s73Jc70X7xT813X/93vPurn0P1DpCdJAZaOA6yImziHfYexOvuQ8WTqHgSr6sXr3egZBkyxluOh5dQIeddK+WsLM3a9/EO9eB19iCjoUlVPu36ZPd34L2wHxEMIKNR3vWZj/Guz3ys6O+j0xncve0k73+Y5K8fR/UMYNdWY9XVwAxSQubj+FCNpkFnMviJBNK2cU5YgbOsZcon3Xgu/rCpz/ezb9yCTiTRaJzVyya1Qmul0Oks2ldGjbIk2nWxWxqxqme/sjWVKjS8HZixpVkl03id3aD1lCtD/lACb+9BtPIRFdERVXF0H+VDa20eGvs78NoPARpZEUUEZq5allWjAhHBIHZdHTIcJvvsSyTu/S2ZP7xUtLVaRCMIxyH73F7cjokql5ASGQlhVUaQFWFkOIiMhvF7SjMqTKUKDb8XQkywNOdTnbTn4R3uxTtgMlRPJgQ665Hdd9CEUzoWsrKi6LmXzmRxD3SQuv9hUr96DP9QtzGY1tXNSgiK4bgXhBEcB6uuDqs6irun3Virdz1blOuECARMtFtHN9nnJze+jdnfstCeXxIj3GhVaLyKpLVGKTXB0jze+jyM9n28/kGyew+iYglEZXRCZJlWCq+7n8wzu1H9Q8iaqqJWhbTW+H2DpH//HImf/4bME8+BEFj1dciqKhDze2uWBWE80h6xVvuHukk+8CipX+3C6+otaMIrpETWVIGCzLN7yOx+2cQuTPHUF1Kg0vmTjeVDpTPorMv27dvxfX9SN4p88wLt+/gDZj6jugfMqBUJjXnCa62Nn9Af9uAe6EBEgkWNAjqVxt3bTuq+3NO/owurttI8/efI3aIQjk07QikQxvgjAZVMkn7oCUQoROCE5ditjYhgkA9c9rapDw8FsUJBdDJtAtBDQWRDDTIaRgYC4FgmEMix0fEk1Mzc0KRdD3d/B8KyuOqDV3Lzd26ZVOUZLSSjRwLtefixBNdf+xFu27GDt/7pn/C5L3xh7DWUwh9K4HccRifSiGgYqwDj2Acuexva8/E6u3H3HcTr6QfAqohg1U8eBXg0KAtCAchIBCIRtJsl8/SLZJ7ejd3WzFvOfCWydvqbQURCiEgI7br4nd34SoMAEIjaSpy2JZDKoLWesV1DJVMjmSG+dt1n+OZNN056rm3bto0Y2bRS6FQGNRQ3OZuk4N+//194yuffb70VLQSf//zn0VkPb2AIv7MHnc0aQS7ASq99D79viDfVtuEf7CL1yFNm2bO2BhaIu/pojvtVoxmhfVQsQXtvDyIcYuVZ67GbGwuKyx05hdbooRhWYx1WbRXOqrYR1+VicQ8cAu0jHAc/lsCqrxljMNRag+ejXQ+VyaDjKaOOKYVwbJMuRQg++clP8r3vfQ8LQZXl8MjOX6AHY8atOhLO6zKiM1n8vkHcji78zh5QPh2pODIaZmlt44y+22wpdNWoPCLMBGEhq6r4ix3fBqX5UWCbGSXqqrFXtGDV1eYVCiEEVFXid/ebESPrzkgQtFKodBorlz5RVkTwewdQ/UNgWaCUSV6mMTe0EMZBcJzurzJZPvvRj1ETy/DgT+9hy0UXQSqDqJpa/9daoxNJ/N5BvPbD+H1mBUwGAlg1FSAsrvnOPwNwxweLtyXMJ2VBmC1SYNfVgNbodJrM758DbdQpq6UBq64aWRVFhIIT0koKIdBo/EQalc4gi6wRAIDrIbQ2jm17D4KUxscp6IBSIGzkuJhh7flGLcpkTZmswQQ6aybs1175Ia79y2snvfm11uhkCjUYx+vpx+/oRmWzZrUnHDKOjfO82lMqyoJQKoRAhMNYYTMSaDeLu78Dd88BNAIhBFZFGGqqsRtqsRtrRzK46XgSnZzZypF2PRDSWKkHYwjHJvvMbkRdNTIc5P986cvc8eP/5q2XvZkPX3WNWap1vZyero26Ewogw2PnOtpzjbAk0qhYAr9vwFQQyq2cCduMKqXOSn20KAvCHCGcAFb1qHV1rUwM9KEeMi93kLUsgpvWYzXWoQeGUInEjCbMKpVGS/AP9SIqIwjLNiNBLIk/GONXP/kp1VjsvONuPrztKkQoBGHA88ycIesaQcxk8Yfi6KEkO++4g+effIZT153Kueeea1bQQgFjMZcz9/BcyJQFYb4QEhEIjlhKdTZD+uGniLz+FQgp0amMeVIXOU/QqSz/+Jl/YNcP7+CCSy7m/Re9CZ1ImvmBr3j3K87l97t2seHMs0g9+BjKddGez7C4aQQCs2AibAvhOPzuySfwlM+vn3mCVNBm167HqG9opLenmw0bNvLGS940pg1333MXu3Y9NulnfX19xGIx7r7nrgmfLSTKgjALrj5ny/Q7KIVIJxHZLFpKsG10OGrUqEAQZIbsiwcIrF46YhArZsKstUZnMvzo9ttYRoDOex9EhKsYuvcnRHu7qLQE70LzLiHQjxyCR39mdHghRl7accB20LkXtsNHnRS9sRiVtXV0//bnrNbgtfeRFYLMQ/cTqg6hg0devY/8hmYNex99GHHeeehACHIrTCemQdkV7Nr1WFkQjjlcFxkb4NJwFDnUj/zVz5FDA0de8SFEMo5IJRDjlqdVpILMWa8i9fqtWNUVeC8fwlm1FD+eRmWyxU2YPR+05pSVqxH7O3m75VP7XzdRO/L5kV2LUbhacy96DrJysh1++l9j3v7V6DefM++0ZaMDQf5DQJ/rYkdtKm77DipcgY5UoCJRdKQCHYmiIke2EQgeFTtDWRAmQ2tEIo7V14Xs68bq7Ub2dWH1dWP1dSNj05edHY0KR9BOEKEVZLPIZJzwr3fivPgHhrb9nQkrjSVNNF0yBfU1hTcz66KB3v3tvAGLVw4dRAvB020n8LPuIVadfhYXbN4CQnDvfTt58snfc8bpZ3DR5otAaVA+3/rG17CUT0AIrnzv+0wlG88Fz/wVnmtWprwsIpNBZNOj/qZh+H3GvB/53PcQKY9KoBKg/7B55ftOlm2Eo7IaVVWDqqo1r+rh/81fHQqXVGCOb0FQytzo3Z1YXZ1YXR3Y3Z3I7sPIzNTep1paqMpqnooP0eH7RK0og0JSt/oETnvd+Tzw+O955LnnOGnDJrSQI/ozWtP16EO8Xyao6ergxRs+wb5NF7J5aTOB6FJ0rLh0NMp1QWvevOUNXHzPT5BA6pwttJx/KX+WTBn3DScASvPUU89gC4cXnnqOC845P3cGQW19C4f7+nDq6/DaVoKQ3H3PXTz66JF8WJs2nT2pWjPl3EBr8DxEJs1V3/sG1b7PVy56CyIZRybjiGRi7N9UIvdZAuFmEbFB87DpODD1b+AEUNV1+LX1qNoG/LoGVG0jfm0DqrYeHakoSlCOD0HwPazuQ1hdHeaG7x5+HUZMEZCjgmFUfSN+XSOqrhG/vin3fxOquhak5NLrrgEp+eNAjTno5cPsfWEPjz77PCB45LHHRs43cmMJm5t0BR8jxtlegvsff5hvPvU01Re8kquuvw7teYUH/acy4Cs+8Ja3Unn3bWghSG46D+26hF+7ARkNj6hmNU+9ljvvvos3vfFNRC96Ddp1URmXB7/1FSq0INXXgz8whNaavbuepEo4ZLRPBjWlfr9r12MoNfHzm2/5Nt1dh2lsaqbdCdDugHvy6YV9JzeLTMSRsVGq5lC/+TvYP/JeZDNYPYeweg5NehodCOLXNeBX1jJETd7LHnOCINJJrM527M4DWJ0vY3e+jHW4Y8ob3q+uw29qMa/GIy8dzf9EqaysJDY6DFEIdu16bOoDcqimVh5pH+TVfpJX+Al+aUf45b072fHzu3nVe97BN7bfXNB3/dTHP8Ed39/B1085jS3Kx12xFk/YhM44GSGNBVmEHLSCqz98LVd/6CpUxhjSQCNDAc6+ZDM/uesutl68meglr0On0tQ98xiP3v8A9ZZDrQiwft06/P5BE5k2Kj5gw4aNR0a7UXR3HT7yd0mRRkIngKqpQ9VM75An0inkQB+yvxurvwfZ34PV34vs70H29SAzKexDB7EPHYSGc/JedlH7GolkHLt9n3l17MfqfBmrf/K6BX5dI37zUvymFrzGFvymVvzGJRCceXbtN3/ny/T19bHFDYPWbNp0NsAYtQIgGg6j01lsJApNRmhOrwjyZ4dfoFc63BRaxlNejC7f5XGSDHr5jWtaKU4MVRFRcDOa1/hJEhdsJXXexQTWnYCzqhVZGSVnNkNohUaAFIAAz0fFE6iBIbN0K4Wxfo+LKdCuaxzzegbwXj6MShmVUYZDxiYxycPi5u03j4wI90hTo3peXSy0RqQSyP4eREc7Pb/cfQz5GmXS2B0HsNv3mhv/4D6svu4Ju2nbxm9uw2tZhteyDL91Gd6SZbO64aejrq6OT4/7kd94yZtMwuKhBNr3EJbFPQ/9mvse+g0BYXHxGRs57/zzSX3tE9S7GV65po3W+kZ+9NsHuPJP/7ywC7seV1x+Bb+4/cecQQp8SC9bi2ysx6qpIHjSqiM5RZU64nSXTKHiKVA+Vk0l9pJ68BSf++R1/OKHd/BHWy7mgx+60jjZWcauYNXXYtXXEjhpFSqR5N+/+S32/PJhXr3udM4993XIaHRMLMFV264a+f+e70xbD2ZuEAIdqcCPVKCqGuCXu/MesmAFQXguwYf+d+TGt7o7JyxFaidgbvi2VXhLV+C3LsdvWGKMSfPARy+YpM6y8vAH46DBXt6Cs2wJsrqCt7/xvJFc+CqRJPm/j/C4dngNGXjuCS77xOe5/NqrCJx2QkHX1q7Hpz/9adoO9xPd+WOylo3fuIRAOIhz4sqxy7CWBZbFX3zkr0fcsL/19f+Ln0ii+obQnse/3b4DV/n8089+woeu/3v8rl606yOCQcSodDQyGuE/778XrRS///1vuOjq9+PuPYjf28cvfvUrHnr695w1avI8aR8tQAoSBCHE67XW98/2YrlCIHuAmlzlnKkbNtBDxR3/b+S9lhZey1K8pSvx2lbita3Cb2qdt5t+Ms5be+qRNyMp68FZs4zAqqVTFjWR0QjB9Wth5cmw+1FOiwRQiSRK+ehYYa4WKpNF+z59D/wvAAe0xa5f/4pbv/VVVr7nrfzTtyfOM8YH5tiBanRNFTqR4p3v+FN+9P3v85Z3/AlOSyN2cz0qlsI71GXmB8HgSHnXrVsv5c47f8pFW9+Is6oNe+VSVP8Q9930TVqlw55dT8JFW8B2xvbRAqbQEWFL7od5VGs9oyQ1uQIhO7XWu4QQ24FpBUFbNpnTzz5y47csn5M0HrPhqY4DoDXrKmtRroezuo3gmuUFVfWx25o44eKt8K1HaYz1csPNN/HkPyU4480X88n//FfIkyFPp9KgNW9oa4G9A/RHa/jV73fR4SW555+/M6kgjA7MGUYIwV98/G/Z/p/f5S/e+37+4e+uQ8VTiEgQqzqKVb3K1DHo6Eb1DyIiIa699sNce+2Hx5zDqqtm1Rsv4Od33s0V573eqGBejD/E+hBOgNNaJ6bIXEgUNVkWQtwG/A74gdZ6X1EXEuJ24ONa6z05Qdiutd41bp/RpaM2PvH3Xy3mEvPOm7d/Ee0r7rz2OoLr1iArJq8mr32FTqXAN7lohW0hohHSTz5P1Wc/it1ziH8NtXCfUrxMll/0d0ybUBcgu/sAKpVGXnE5gWef4PZgM/dowc+9Pi770AeKSvM4OrWLm83iD8ZMnIQUY5Jo+bEE3suHUMmUmYhbUwur9j28jm7e+Mm/RvuKO9730TmvlTwZJU3nIoR4VAixA7hBa/0VoF8IceYs2zgwfsOE0lELFeXj9+Vib6sqCL/y9EmFQGuNisXR8QRWQx3OiSsInLQSwiFUPI7d2kS2bTUAS/0UldLmsi2XoPPkGdKuxz9c/1nO2XA2avezAERPXk9awuXvfU/RuU5Hp3YRUmLXVhNY1YaMRvAH4yNFEz/9xS9w0iUX8pXv34pOZsx3myopgWXjLGtB1lTl8jxp/J4+dAkSFcwFhUZR3KC1vkJr/Xju/Wag2MjrR+CIZSNf6aiFikok8PoHCZy0CllTiQhM/lTUrosaGELWVRM87USctiasyiiyIkJgVRtogVUZwWtZAUCrynDeWZu48s//3LhaTIN2PW67bQcrNITcLCoSZeOFW/j2v/8bX7n5W0V/p8lSuwjHxm5pxG5rNgE8qQz/8R//D1/5/PPtOwiuW4usrUL3D00ruEKACDqEL3gFobPXg9Z4PX3ozMISiIIEQWv9w/HvZzB5vgXYJITYANxb5LFHH+Xj9fYhgwGi551N4MSVU4cwpjLoVAbnxJUEVrROEBbhWFgtDSYG4NR1AJwaDXHha88xo8gUBQxHmpLNcvnll7Mpd3lvyTJkZRRZW4Us8TzKqowSWLkUHJv3/sk7kNLiXe96JyJgE1i5FPvEFeb75qlWJKTEbm0ifMHZhM9eh/YVXm8fuO60x80X81ledgAjDAC7ptl1waFTKVQiRfCU1Thr2qbXjRNJEJLAKavHZMHWmayJJpMCEQ5h1VThtx9Grl+HCoSwhvqRsSFUOouKJadfOUqm+dR116GTafjxf+E2tOSSHNdOvv8sEY6N09bM9V/5Mp/65HXIiiPzBrumEnnqatx9B1EDQ4jqymlXvIRlYbc2YzXX47V3kf3DbnQsbtJeTtOvc83iDDCdL7TG7x9AA+HXbTSjwKgf6/oPXsP1HzySLEslEuDYpg5DTgi0a+wKIuCYFPFVFcbbNGTKUYmqCrwlbQC8cP89XLvtQ3zti180LtZToJJpEyS/fy8AXkMzMhrCqsufZ2imZWaFlNjN9VjNdSaibVSyMxkMEDhhBVZzg1GVRrmzjO+jkfNZNs6KViKbX03g5NUmZ1L/IOi8xW3mhLIgTIXn4vX0YbU0ETl346SFxV+1/gxetf4MAFQ8jggGCZyw0gTOY0YBlXVxlrdgL21GVkSwm+pNhux0Bllfg7At9uS0A3ffi4QV3HvnXSMT1PFo1zOrUFkP2b4fAL9hCSIcmnLVajTTZcIrBLu2GmtpMyqRHtNGISXOsmbstcvQseTIHGB0H02GcBwCJ6wgeuErsdua8XoG5qQoZD7KgjAJKpnCH4oT2nAKoQ2nTJnT86Gnn+Chp59AJRImC96a5QjHGPhU2iTtCqxonZDaxaqrQbs+VkUEIQS/7DUrUC1+hkrb4bLNb4ApJqA664IAPTCA1X0YLQRufRMyHCoor1Ip6iZYlVGcZUtQqYxR90Zh11UTOGU1eAqdSI70UT5EOETozJOJnLcRGQ6YFSZ3/gozLlgXi6OC1vgDg8hQiPC5m5B50rVf/50b0b7iZ1+7aYwQ6KxrLMxTFBUU4SDCtsCSoDWP51Yg21SGj131l4ResQ4/kZxQlgrMRFlIAc8+j9CKoVCUr9xyC80vvYaPvj1P6CilKygoo2GcZUtM0RAhzPcZ9Vng5FVk97bzmZu/ibCtgtLCA1i11YReuwGvo4vsUy+g4ilTEHKO08SUR4RhfC+nCjUQPndDXiEAYyhDCAJrVxwRAt9HZT3saeqnCSGML5CvEBURTrlwM0kkETRyaADteVOuHOlEyqxJPvsHANpdTcxzue2eu0bSPs4XMhLCaWs2I8P4NPhBh8Da5WDnMn4XY7iVEqdtCZHXvwp7RYtJWDbH6lJZEDDVWfy+QYKnn0Bow7qC0psPG4ZkODiyPKq1RiXS2C0N0xbdBvPU1J6PrK3iz6/8IPayVQDY3R2otDtlSnmdzvLZz3yWn9+Ysxc0LyUjYcvlby3065YUGQ1jtzaiEymTUW8UwraQoSDCsY27typuIiyCAUKnnUj4vE3IoIPX0wfe3Cy3HveqkRoaAiEIv24jVoHxwtp10VnXVJMc9RTWiTRWXdWkFXPGI4IBEOapKm0Lv3U5vPwSdncHbipj8hWNi1bTnpko//i223ibNrr50o2v5Ot/9BYiF7+uuC9eQqzKKLq5Hv9Qr8nqN3r5VAhEMIDd2oR/sAtqKoseuayaKkLnbMA70EnmmZdACqyqypLGLB+/I4JW+H39yMoKIuedXbgQ+D4qnsJZu9zo+MPbM1kIOAWv5QvHzvn7B0BrHhk0ASx9zz5pKoDGUxNWjrTrg4B3XPZWluUEwa1rQkbCEwqgzzd2bTVWfbWJdRiPEDhLm7CWL0EPxIxKWSTCsnBWtRF9/SuwGmqMdbqEk+njc0TwXLz+GM7qpQTXrZnWQDYarTV6MI6zailWVZSvXm8i6LRSKNcnsKKpqKedjEZQySQiHOTuPXs5Fwj2daFiKeNd6nowSsVS6TRCCD7ynvcQ/cH30JaFampFVEWRM8ykXUqsxrpc8E96xGlwuI8AnCUNICXe/g5kVcWEXLCFICJhQmefht/ZTeaJ51CJlKnTMMvR4bgbEXQqjTcYI7ThZEKnn1SwEADooThWSyN2o3nqn7l+PWeuX4+KJ7GbagsuQjiMjIbA85HVVSw/7wIU0Kg99NAQOmuKd4xGDSXBlvz3DV8EYDAYRVRVYEUjed225wMhBPaSBrCskSpAw300jNNUh7OqDTUUn3lpXyGwW5uIXPBKY6Humb2rxnElCGpoCO25RM7ZiLO8tahjdSKJqIhgLz2S5/++Bx/g3vt2IqORglaZJpC7eWVlhHdf9SFUfTMSsHoPmSo2vYMmd5HWprhhKgVa0PfwbwF4OePz6Rv+kS9tn7wwyNFAWBbO0iZj9HM97nvwAe578IEx+9gNNbMWBjBViUIb1xHatB4/kcQfmnk95+NDELQ284Fo2MwHJrEST3t4NgsIAqvaxqg+X/jG17nhW9/Ebq6f0Y0oAo7JOhp0ELbEbzGuFnZ3pylF6+YqVr64H6/9ECIcQmdcXrfUCHG3DJDyff7l9u8Xfe25RAQcnNYmVCrDF77xdb7wzf87YZ+SCYMQOG3NRC94BVZtNV53H/jFjw7HviAo39gHWpsJvebMgqLHRqN9hU6mcNYum+BFqv1cxZkZ6udCSmQohJCC7f/8L9z13IsAON2dqHQaRK7WWEXEBMI4NiqZZF3O6a1bBsgoxdvf82czuv5cIqNhrKa6aSfGpRIGyM0dXnEaoTNPxB+Mo+LTe/BOaO+srr7QcV283gGCp64hdNbJCLv4G1YNxrCWT3ST0JmsKQZYaDKuKRAVYfB8fvizu9id+zns3kOmvNMkme9ULInsaAegVzokLc3Xbiw+BmE+sGqrEJYcicybDLuhBmfl0pwwzM7hTkiJs7KNyPlnI0MB/N4+CrXjHbOCoFNpvKEY4bPXG6/RGVhd1VAM2VAzMjkeObfWqIw7ZVBOMchQEO0rtrz1MvbmJu52VydqMIZOjC1Lq30f3d2D1d+DLwTdwuG1b9iyICbKkyGEGBkt9TSTWbuxFntFK3owVrTRbTJkZQWhc87CWbsCv28AIO8qxsLswVmi4glQPpFzJvcaLQSdySCcAIFlLRP0f500hjNK4NJgPFUFf//Z60n/8Z+i3vVHyFQC0ddj5gkZFxEyv6OKpxH795k2NCzh6qs+QvTS8xbMRHlShEAEHVQqi7SsKR9ITlOdmcvt75yR0W3CZS2b4KlrIBwk+aOd+/Ltf8wJgj84hAwGCL3qLGR0ZnHPJtg+Q+CU1SM+RCOfeX6uQnzNjF2ZRyMsyzw1lUYGA3iNLQTa92D3HEKlM3idXVh1tYDG29+JPLAPAK+uCVERwZpFfeb5YLiPrMY6/O6+aa3uTnM9eD5eRxeypqokAu401AA8mm+/Y0cQckE0Vm01obPXzaqKuxqKYS9vmdStWaXS2C2NCMvipJNOmk2LR5AVYdRgHGwb1bYc2vew/+HfcP1P7+B1F13ElR/8AAAiEDgSjFPfZCLdCnDnOJoM99FIIcJUZkzk3njs1kazdNzdP69CfmzMEbSJf7VaGgm96rTZCUEigaytxm6amJtAZ11EKDTiHn3nnXdy5513zvhaw8hwyKxAVUTwV5qsFrrzABVa8sP/uRtZXYmsrjTu23tNzgOvfomJciuitvPRYLiPhBDYzQ2m0Pk00XdCCJMdsK4aNRSbt3Yu/hFB+Xg9/TgnrCB4yuoZme2HMfYCCCyfOC8A43Fqr1g68tlXv2ryLm3dunXG1wSz7i4EyMow3so1ACwPWjQSZN3FF43Z1zpoagZ4tQ040UheL9ejzeg+Eo6N09KI234YWRmZUvURUhJY0ULW9VDxOLJi7it3Lu4Rwcstj552AsF1a2cnBEqhEimcVRPtBWDihGVVxbTD+oxx7NykMohevRaA2kySv9t2DX951dVH2tDbh+zrxgcefOo5ZG0V2IuryqWsiGDVVpnEA9MgLIvA6mWIQMDEYMx1u+b8CnOEzmbw+4cIbVpnAmNmObHSsTh2S9OkXpxaa7Sv5i5LhBCIcBApJf9yx08YFDbCc5H93bj7O1CxOP5ADO/B3yCAHmHz2AvPY5VoQjnfWA01IOW0S6pg0t4E1phUkXOdGGxRqkY6lUalUoRecyZ2U31JzifCIeyWyc+lUxms2soZW5ALQUTDqFSan937P1wsHKq1RzDRz/9s/zeeeeZZ1q07lS31ZhWsWwY49czTsWoWZ7FvYVnYLY24+zuQVfb06V+CDs6JK8g+t8ekwikgaGomLLoRQSWSaDdL+JwNpREC30dnXZxVbZOqVloptFIztkcUigyYWmcXvPES9kojcHbny/z6mSfpUxl+/cwTBHo6AVhxzgVc8pY/mjSmebEgIyGshhp0YnoVCUwUoLN2OTqZHpMqpqTtmZOzzhFjoslqS3NjqsG4WSqdQvfXyTRWQ+2krhS33nort956a0naIRzzZPzwJz7O6640hTbs/bvZsGEjUko2nLUBZ7eJU3ZbV0IkwMc/8+kZ5SiaT6brI6u+BhzLJDvIg1UZxV7dhhpKlMT6PJ5FIwj+wCAiHCRyzlkF5e8pBJVIIOuqjc46Cdr3QUqs6slVkGXLlrFs2bKStGVkwhwJ4Z96GtqysF/ew9bT1/Pq9euxf/egca2oqsVdshwrGOLb3/3XWeUomg+m6yMhJfaSRlQ6W1Bwv11Xjb28BTUQKyoZQCEsCkHw+gawqisIv7p479Gp0K5JuTLVUimASmawGmunXI3asWMHO3bsKEl7RG7VSEjJrXf9lGcJIbSm5sZ/5PKH7+Yd2T4AHq9fyi233MK//WAH773yg7POUTTX5OsjGQ4aFSnPKtIwdlOdybZXYhvDwp4s56zFdlO9SbRVogS3WmtUPIlz4sqps1l7vsnCMI0efvPNphjHFVdcUZJ2iUgQ0inuvn8nqUAddsZnbdAi67oM+YqDy09kx+EhAkh23vdzHsj0ceNNN5Xk2nNFIX1k1VWjYgm06+b9jYcNbtl0tqQ2hnkbEYQQNUKIDUKIt+Wq5+TF683FEWw6tWRCAKDjcaymOuwpVB4wrhRWY+285gqSoSACyesvfgOHpeAnS9fxUVnP7a9+E9bnv83ybR9h4xlnkUFx4WWzM+ItJExe1QaTVbsAlUdIaWwMlmMqgpaA+VSNLgcGcrXTrhBC1Ey7txA4q9pMHEEJsyTrbBZhB3CWNk+9j+cjbHtsQb55QDg2aMXVf/tRPvmlG0j09Y8U9B7mgteewyc+/zk+/rnPzmvb5hoZCSFrqwtWkYRj4axdhs5m8xZWKYT5TAt/y7j3A+P3GV06qq2ugeBpJ5bUYDRsPQ6esnpMisLxjDjWzXPmOBwbjQnq33HHTxj+5sFgkM99/h+ob2hE9wzQ+JqzeH8B6R0XG1ZDDSqeHFFL8yHDQew1y3Bf3G98sWbxe5X8l86pPuNfNaM+/xiMVFodw+jSUQ01tSW3muqYKdc03ZNeux7CcY7KGr2Q0vgd2Q4/vf8+FBobQSqVQilFsruXIe3yswcfHFPy9VhBWJap5lngqACmPoO9bMmsHfRKPiJMVzZWCLGZPNU05wqdziDCQVNgexpUKo29tLkgIfzBD0r/VWQ4jEpn8SQ87yY5xYlSV1tLf/8AXtDmsVgPa9uWmyx7i4Bi+0hWRLCqomNyI+XDbq5HJVKooTiycmaT5/mcLG8GvgRsB26fr+tCzjqcTuOsWDqtY552PUQwUPDcoKGhgYaGhlI1EwARDiCkYMsbLqadLC+1VPLjrn0Et7ya++JdJLXPno72OXX3KCUz6SOrsW7Eol8IQohciS7HJEabAfMmCFrrnVrrjVrri7TWG+frupBLzLW0OW/9AJXOGCtygSrZd7/7Xb773e+WoIVHEI4DQvChv76W//npT3npwAHSyuOOu+5i69ZLCUiL8y55w6w8beeTmfSRcGzspjqTy6nQY2wLZ/UyU0hlBpPnRWFQmw06nUFEgpMG2ozZz/UQAaeolaK5EYRc0q+aSnTWZevWSxFSsnXrpVx77Yf57+9/n7/6zN+X9JpzyUz7SFZXGrf0AtwvRo4JB40bRqx4N4yFbVCbJcMqUeCUNXmfoMXMDeYS4dgm31E0jO+rMVXutVKgVMn8rBYyJqKt3nioOtN7qI7GrqlEL20ysd7VhffTsS0IsThWa1NelciMBoXPDeYaEQqA5yEqo6iBoSMJbrVGLmlAHOXM1/OFDAex6qpQg/GiQlLtJQ2oRAYVjxd+zEwauBjQmQwiGMRuzu+qrdIZ7Namoz4aDCPCIXT/IIG1y1DJtFkfzyUTU+lMyWspL2Ss+hrjcVqgbQFGhXo+txeVKmyecUzOEbRSqGQaZ9X0q0Rw9KzI02GSfmmEZWFVRpHRsKnM41hmcFhk4ZmzQVgWVpG2BQARsHHWtKGMC0beB/4xOSLoeAK7Jb9KBKOsyDMYDe6+++6ZNC8vJjZh4natNRoWzdIplKaPZEUEGc2V5C0iWYGMhnFWtADkrShyzI0IxpfIyWs4g9mPBpFIhEik9COJCDjgOCOlW7Xvm4J8yTRWVcX8u37MglL0kRACu6kOnfWKjkPIxZk/nW+/xdOjBTBcR8Be2VqQPqlSmVl5mN50003cNEdu0FZjLSqTxY8n0a6PVtqENxZY4mqhUKo+EsFAUXELo45kTfcv87qoHlOqkY4nsJrqCqonNhJvMIu5wW233QbA1VdfnWfP4rEqIsi1y03U2gKZxM+EUvaRVWtWkIqZOBfKMTMiaM8zrtstTQXtr1IZrIaaBa1mCCkXtRCUGmFZWE11RU+cC2Hh3gVFoocS2MtbCkrVrn0fYclFnQXieEVWRBCRkKliWsrzlvRsRwmdSCJqq7BqC0saq1NZZH31gh4NykzO8MRZZdySBvAv+jtB+yaprLNsSUFqhFYKLcCqWpzJscoYO4tVUzVjT9PJEKVOi1EqLCGSQeSeAvazM1plXPTs6osWRgPQMw/XKZSF1h5YmG0Kaa3XT7fDgl01UvCHpPY3He12jEYI8ajWesG0aaG1BxZum/Lts+hVozJlSkFZEMqUYWELwi35d5l3FlqbFlp7YJG2acFOlsuUmU8W8ohQpsy8sWBXjUYjhNgArJ4uVcw8tqUGWD380lp/+Si140pgD1BT7pfpKeT+WSwjwmZg+uj7+aO41JVzQC5J2qNa653ARfn2nyeOer9MQ977Z8ELQi4f0s6j3Y5hctn49ox6P3AUmnE2MHLd3BPvqLJA+mUChd4/C0I1EkK8bZLNOzHD2c6j8UNP1abhH3i61JVHgYGj3YBhFlK/CCE2FHr/LAhBmEp3E8YXfzOwAVgjhKiZryfNQk1dmeMRoGb4zegn8dFkAfTLBAq9fxaEIEyF1noXgBBioejBo1NX9mH0znnN2pfjFuDynJPhvUfh+hNYIP0yhmLun7IdoUwZFsFkuUyZ+aAsCGXKUBaEMmWAsiCUKQOUBWHRkatM+thwZVIhxJdyr5qj3LRFTXnVaBEihFgNbNdaXySEeNtC8DVa7JRHhEVIzoC2SwixnQXkfrKYKY8Ii5ScKvSY1nrN0W7LsUBZEBYpo3yhztZaf/yoNuYYoKwaLUJyc4Q9ubnBhtz7MrOgLAiLjJxPz+2YoBwwvka3l4VhdpRVozJlKI8IZcoAZUEoUwYoC0KZMkBZEMqUAcqCUKYMUBaEMmWAsiCUKQOUBaFMGQD+P/u2JwqFPxYNAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "train_loss = errors.loss(mean_vi_train, sigma_vi_train, y_train)\n", + "\n", + "ax =plot.plot_prediction_regression_without_test(x_train,y_train,x_linspace_test,mean_vi,sigma_vi,y_min=-3,y_max=3,\n", + "title=f\"Train {errors.loss(mean_vi_train, sigma_vi_train,y_train):.2f}\")\n", + "ax.set_xlim(-4,4)\n", + "savefig(\"MLP_VI.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/motorhelmet/Calibration_VI.pdf to figures/motorhelmet/Calibration_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/motorhelmet/Calibration_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAf8klEQVR4nO2deXiU1dm475ONJJgQIspSZYkVE0WtIanlIxSEgMW2IhCgQF2KCMhPLfYrQam9DP1KEeqCC2gAUVCURTQuLEoQkGCphLiBgELYtxBCwGQSkpl5fn/MmziEycw7w2Synfu63mtm3rNm8j5zlmc5SkTQaJo7QfXdAY2mIaAFQaNBC4JGA2hB0GgALQgBQSmVrpQaZ1zblVKpSqlMpVRcA+hbnPNrc0ULQmDIE5F5wHKgSESygSlAjKeCSqkYpdQ4M40YgpaqlEqrpZ5053SlVAyQqZTK9OJvaZJoQQgMuTVviEgxkO+poIgUG0LkFqVUOpBrCFl/F1kexyGQ2cAIp/tTRGS8iHjsS1NGC0IAMB56VyQppVYYv9RVU5Q043OM8Tm16hfbeL9OKZWolJpZo65koLodpVSimy7FOE2Fkox6TY06TRUtCPWI8euMiMwSkXzj4SwG3sHxC16VJ7ZG/jzgtIeHvbjG5xlAolGmqr5iEZln1DvMX39XY0QLQv1T5PQ+DscDGWdcriiu5f42nNYcNac6xqg0zxCifEPwxlWNPM2dkPruQHPBeOCGA3FKqVQRyVZKpRqfY4wHNQ5YZxQpcvrFjzPKxxnv4/hpKpRn5JkHDFdK4VQHSqlxIjLPKJOqlCrCMTqAY/Eep5SKBVbUwZ/daFDa1kij0VMjjQbQgqDRAFoQNBpAC4JGA9SBIBjKmZrKHuf0Ws0ANJr6wu+CYChnXO6BmzAD0GjqhUBPjdyaARijRa5SKveGG24QQF/6MnUVWYrk5y/8XNQ0JWqaErykvtcIxc4fDHV/kogkRURE1FOXNI2Nj/Z8xM+e/Rn7zuwjLDgMoNTbOgItCG7NADQab7Db7dz73r38funvOW87z7Q+0zj68FH4HK8NCP1uYuFkNpBo2LVUq/mpxQxAo/GWQ2cPkbIwhcPnDtM6vDWf3vMpXVt15frrr4eDzAPe8qY+vwuCsRDuXuPePOO1GIcwwE82MhqNV7ya9yoTVk3AarcyIG4AH478kLCQMPLy8jh69ChAS2/r1EZ3mkaD1W7l92/9nrX71hKsgnnlt68wPml8dXpZWRlWqxV8WCNoQdA0Cr45+Q23LbqNorIifhb1M7aM2UKnmE7V6TabjUmTJtG2bVtOnjz5a2/rr+9dI43GIzM2z+AXr/yCorIiRt84mkOTDl0gBACZmZnk5uby3HPPVTkueYXfzbANl798IEZE3qmRFgOMw7E+uCjdmaSkJMnNvcjVV9OMKKkooe+ivmw7to0WwS14e+jbDE4YfFG+EydOEB8fT1JSEuvWrUMZuzHe4NcR4RIcyDWaC1i7dy1X/vtKth3bRkKbBI785YhLIQCYPHkyZWVlzJ07Fx9kAPD/1MhXB3KNppoHP3qQgUsGUmYtIzosmtxxubSJbOMy74YNG3jzzTdJT0+na9euPrdZ12uE4hqfL3Igd8bZxOLUqVN13DVNQ6OgpICuL3blle2vVN+ziY3dhbtd5q+oqGDixInExcUxderUS2rb37tGHh3IlVLzjNd8F+lVSjeSkpL8u3jRNGiW71zO3e/dTYWtgh5X9eD4ueOcKjtF25ZtiW8T77LMM888w+7du1m9ejWXbJIjIn67cAjBOCARSHO6P854jTPS04BEd3V1795dNE0fm80macvThAwkaFqQPL3laRERKa0ole3HtktpRanLcvn5+RIRESFDhgxxlez1s9tgnff1rlHTZ1/RPlJeS+FEyQmuiLyCTfdtIuGKBI/lRIQ777yTDRs2sGvXLq6++uqaWep310ijMctLX7zEdS9dx4mSEwy6bhDH/veYKSEAeP/99/noo4+YNm2aKyHwCT0iaAJKhbWC29+8nY0HNxIaFMrCOxfyx5v/aLp8aWkpCQkJtGrViry8PEJDQ11l83pE0CYWmoCx7eg2+r/Rn7PnzxIXE8fmP22mQ3QHr+r4xz/+weHDh3n77bdrEwKf0FMjTUD426d/49YFt3L2/FkeSHyAfX/e57UQ7Nixg2effZYxY8bQs2dPv/avLvwRajWxMNJTccT7jHOVrmlaHPvxGH1e78MPRT8QERLBeyPe4/af3+51PSLCxIkTiY6OZubMWmND+IxfBcEwscgWkTwjlHlNW6NUHAJSFfdT04TJ2p3FkGVDEISw4DDyH8mnXVQ7n+pavHgxmzdvZv78+bRp41rLfCkE1MRCHDZGjyulVgDZNQtrzXLTwG63M+b9MQxeNhjBsRkTGhTKsZJjPtVXVFTEX//6V3r06MGYMWP82dVq6nqxXOz8wbAtmoJDsTaTGoZ5ojXLjZ4j546QsjCFg2cP0qpFK1qGtuTs+bNuNcSeePzxxzlz5gwvv/wyQUF1s6w1JQhKqWgROaeU6ozjDLBztWT15JyfJiKzjDpRSsW5yKNppCz6ahFjPxyL1W6lX5d+rB61GqtY2V24m/g28USGRnpd59atW5k/fz6TJk3i5ptvroNeOzClR1BKPQDsAwYAX4jIu7Xki8FxBkAuTovhGjH6E3Espt0ulrUeofFgtVu5a+ldrPphFcEqmBcGvsDE5ImXXq/VSnJyMgUFBezevZuoqCizRetMj5CL4wGfQQ3HfGekFud8+cl5P5+fDtDTzvtNgJ0FO+mzqA+FlkLaX9aezX/azDWx1/il7jlz5vDVV1+xfPlyb4TAJ8xOuGJxbHnG4kYQNM2LWVtmcdMrN1FoKWRkt5EcefSI34Rg3759/O1vfyM1NZW0tLoPk2tqRBCR9cB64+O/6647msaApcJCv8X92Hp0K2HBYbw15C2GXj/Ub/UfOnSIbt26UV5ezp49eygrKyMy0vv1hTeYGhGUUn2N185KqSG+NmYceh3ja3lN/fPZgc9o+0xbth7dSvzl8Rz9y1G/CsGaNWtITEykvLwccGyd7t7t2jHHn7gdEZRSQ3FsccYppYbhWIQI4HKxbJRxp1lOBcYbfqWxwAPiQ8QBTWCxVFrYXbibhV8uZM62OQBMunUSz/3mOf+1YbEwefJk5s6dS0JCAuHh4RQXF9O2bVvi433bdvUKTw4LQBegH3CLcY11kzcdw+EGyHSRHuf0PtVdu9oxp2FQWlEqHZ/tKCpDCRlIy+ktZdP+TX5tIzc3V6677joBZNKkSVJWVialpaWyfft2KS117ZjjAa8dczxOjURkP45RoR+OX3R35xp40iznG/fTxDg82xmtWW54zNk2h0PnDiEIQQTx8d0f8+vOXsfPconNZmP69On86le/oqSkhHXr1vHcc88RHh5OZGQkiYmJdb42qMLs9ulSEfkKHOsEL+ovruV+MjXskEBrlhsSdrud0e+OZunOpYDDROLq6Ku5pd0tfqk/Pz+fu+++m88//5wRI0Ywd+5cYmMviucQMMwKwiyllABncUyVkmvJ5zHsuza2a/jsP7OflNdSOPbjMS6PuJzVo1cTEhTis3bYGRHh9ddf55FHHiEoKIg333yTUaNG+RyPyG+YmT8B/Zze3+ImXwxunPeN92nOn2u79Bqhfpj7xVwJnhYsZCC/XfJbqbRV+q3uU6dOyeDBgwWQ3r17y4EDB/xWdw28DzzhdQGI9qUhby8tCIHlfOV56beon5CBhPwjRBbmLfRr/atXr5Z27dpJaGiozJo1S6xWq1/rr4HXz5un7dOxIrJAKfUK0BrH9qm7qZGmEZJ3PI9+i/tRXF5Mp1adyBmTw1XRV/ml7sLCQiZOnMiKFSu44YYbWLt2bZ0az/mMOykBWokXUyN/XnpECAxPfvpk9dbon7L+JDabzS/12u12WblypYSGhgog0dHRcvr0ab/UbQL/jggictZ4Xa+U+oXx/su6FExNYDhXfo4+i/rw5YkvCQ8JZ+Xwldxx7R1+qfuLL74gPT2dTZs2VS+CbTYbBw4cqNedIXeYNbF4APgD8Ael1FgPed0eKK6USjPS073vrsYfrNu3jnbPtOPLE19y45U3cvx/j/tFCPbt28eIESO49dZb+e6773j22Wfp3LkzLVu2DJyG2FfMDBtcODXq5yafJ81yKk7hH921qadG/qe0olSGLB0iZCAqQ8nj2Y/7pd6CggJ5+OGHJSQkRCIjI+Xvf/+7nD171tHmpWmIfcW/UyMn4pRS24z3Xdzku0BR5nyypkF/YF/VyZv85LugqWP2Fe0j/qV4rGJFodhw7wZ6d+59SXWWlpYye/ZsZs6cicVi4f777ycjI4P27dtX56nSEDd0zPojLAdmAfNx4XTvhuIan2OAfDEOEqlpiapNLOqGJd8uIX6OQwgAwoPDiWrhu6OL1WplwYIFXHvttTzxxBP069ePHTt2kJmZeYEQNCbM+iOcBSaYyOpJs7zdQzvaxMKP2O12hi4fStaeLIJUELERsZy3nvfZkV5E+PDDD3nsscfYtWsXPXr0YMWKFX4PtlUvmJk/4ZgOLQeWAZ3d5IvBs2Y5Hae1Qm2XXiNcGt8VfCdXzLpCyEDaPd1Ovi/83mOo9dooLS2V119/XXr27CmAdO3aVd59912x2+111PtLxus1gllBmGwIQxfgr7405O2lBcF3nvn8GQmaFiRkIGnL03zSDdjtdtm5c6dkZGRIixYtBJDg4GB5/vnnpaKiog567VfqbLG8XRzm2Cilso3XaKk9rIumHii3lpO6OJUth7cQFhTG4iGLGdHN/JmNNpuNrVu3kpWVxfvvv88PP/wAUB1LqEWLFqSkpPg1+G5DwawgPObkodZFKbUfx+jgfRBLTZ2w5dAWBi4ZyI8VP3Jt7LXk/CmHKy+70mO5srIy1q9fT1ZWFh988AGnTp0iNDSUvn378pe//IX+/fszYMAATp482fB1AZeCmWEDF7oDV/f8eempkXkmfzK52kzioVUP1Zqvak//8OHDsnjxYhkyZIhERkZWm0D84Q9/kKVLl0pxcbHLcgHWBVwKXj9vAT8opCq6nacodzrAl2eKLEX0eq0X3xV+R8vQlnww8gP6dunrMq/FYqFLly4UFhZit9sB6NChA4MGDeKuu+6iT58+hIWFBbL7dYn/Anz5ugZw57xv6A0ylVL5OGKfanzAUmnhlW2vMPXTqZy3nSe5QzKf3vspl4Vd5jK/3W7noYceoqCgAIDQ0FBeffVVRo8eXWexRBsdtQ0VwAyn979weu9u+9STiUUMHk7TrLr01Mg1P5b/KJdNv0zIQMhAntzwpNv8FotFhg0bJoBERUVJZGSkxMXFNaZpji/4ddco1/BDAIeJxRk8+yN4MrEASFJKxeKwNbrAxMIYTcYBdOzY0U3XmicHiw/yywW/pKSyBHBoiO+87s5a8xcUFDBo0CD++9//8vTTTzNhwgT27NlDfHx8wJziGw1mpAUnHwTcu2quwDCmAzJxY1gHrHPXph4RLiQzN7PahTLinxES8X8REjc7rlbl2M6dO6Vz584SEREh7777boB7W+/UjR5BRL406Y/g1sTC+MVfLo5gwRoTWO1Wfrvkt3yS/wkhQSEs+N0CRt440m2o9ezsbIYOHUpkZCSbNm0iOVk7FHrC7PkIDwDXGO+TRGRBLVnnAcMNZ4x1TuXHiWMatBzHNCsWx+ihccNXx7+i7+K+nCk/w9XRV5MzJoeOrRxTxsT2ri06FyxYwIMPPkh8fDyrVq3SU0yzmBk2MOmP4M+ruU+Npm2cVq0buPe9ez2aSdhsNklPTxdAbr/99mp/gGZK3UyNMO+PoLlESipK6PN6H7Yf3054cDjLhi1zuyAGh47gnnvuYeXKlTz44IO88MILhIToI7S9wey3VeWP0BrHGWiaOmB9/nruXHonlkoL3a7oxqb7NhEb6d7H98SJEwwaNIht27bx7LPPMmnSpPoPltUI8bc/gsZHJq6ayMu5L6NQpP9POjP7e9Y37tixg9/97necOnWK9957j0GDBgWgp02TgB84buRJRB84DsCJkhOkLExh35l9RIVF8fEfP6bH1T3clrFYLLz22mtMnTqVli1b8tlnn9G9uz7I6FIwLQhO26dfucnj9sBxJ1KpPUBws8BSaeH5rc/z5MYnqbRX0qtjLz65+xPCQ8LdlistLaVz584UFhYSFhbGtm3b6Nq1a4B63XTx9/apR82y4bifDST51OMmQMn5Eto+0xZLpQWAmakzSe/pObrNgQMHuPvuuyksLAQgJCSEkpKSOu1rc8GsxVW+iDwmIo8B+72ov9j5gyEYtTr/Nwfn/T2Fe+j0fKdqIYgIjiA1zn2A8MrKSmbNmsUNN9zAl19+SWxsLJGRkbRr167p+gcEGLOCEKeUilZKReMIw1IbZsPCpwLda0axEJF5IpIkIklXXHGFya41HmZvnc31c6+nqKyIyNBIIoIjaB/V3q0j/X/+8x+6d+/OlClTGDBgALt27eLw4cNs3ryZb7/9VtsM+QszygagFfAyDuf9m9zki8GD877xeSYOW6SY2upqSgq1ssoy6bWwl5CBhP4jVN765i2PjvRFRUUyfvx4UUrJVVddJVlZWQHudaPGa4WaJwEYa7y+gkOXsBzY5ktD3l5NRRC2Ht4q0TOihQzkmuevkeM/Hneb3263y1tvvSVXXnmlBAUFyaOPPirnzp0LUG+bDH4XhFbGq46G7QOPrXus2kxiwocTPObfu3ev9O/fXwBJTk6WvLy8APSySeJfQajPqzELwpmyM9JtTjchA4mcHimf7P3Ebf7z58/LP//5TwkPD5eoqCh58cUX6/ogjaaO18+b2e3TviLyqXGQYKKI1HrOcnPnoz0fMeydYZRby+nevjsb79vo0oXSYrGwe/duCgsLmTRpErt27WLYsGHMnj2bDh061EPPmzdmDxzvYvbA8eaK3W5nzAdjWPT1IhSKjN4ZPNnnSZd5LRYLCQkJHDt2DKvVSqdOnVi1ahV33OGf8wk03uPpoJCVSqk8HOYQ681U6MnEwtg+jTHqb/QmFpZKCxsObGDCRxM4cu4IrcNbk31Pdq3+Anv37uWxxx7j0KFDgMORfsmSJU0jfmgjxuPUSBwR7qqVaEqpX0gtZhaeTCwMG6MYEXlHKbW9Znpjw1JpodNznSgsc2h6+3buy5rRawgLuTgsSl5eHjNnzuSdd94hJCSEqKgorFYr7du355Zb/HN2scZ3zJ6YE62UekoptQwY7yZrMk7aZOPBr0Yc5hbZxqgxw0U7jUazbLVbGfDGgGohCAsK498D/n2BEIgI69evZ8CAAXTv3p21a9cyefJkDhw4wIkTJ8jJydFKsQaCWaO7cRjO+F7WX1zzhogUG/FTZ1JjRJBGEhZ+x8kd9FnUh9NlpwlWwYQFhV2gIbbZbGRlZfHUU0+Rm5tLu3bteOqpp5gwYQKtWrWqrqcxHKDRXDBrYrEfhxCk4dAy14Yn5/10pwh3cTVNLBoDMzbP4ObMmzlddppR3UZRPKWYnPtz+HbitwTbg5k/fz4JCQmkpaVRXFxMZmYm+/fvZ8qUKRcIgaZhYdYxZ6VSqguOX3B3o4In5/13gBjjoMFl0oiiWVgqLNy26Da+OPYFLYJb8PbQtxmcMBiLxULJDyU8t/A55syZw/Hjx+nevTvLly9nyJAhBAcH13fXNWYwo2zATXS7uroakkJtw/4N0nJ6SyEDiX8pXk6VnhIRkZKSEomNjRUcW8py2223ybp16xryARrNBa+fN7NTo+qjYg0L1GbDn9f8mdsW3UZpZSmP/upRdv2/XbSJbIPdbuf++++nqKgIgPDwcJ5++mlSU1O1z3AjxOxi+ZfGjtEZmsm5CAUlBfR6vRffn/6ey8IuY83oNaR0TAEc/gFjxoxh2bJltGrVisrKSu0b0NgxM2zQzOIardi5QsL+L0zIQHos6CGl538ylS4rK5M777xTAPnXv/4lJSUlje3sgOaA18+b3x9gHFutqTj5IzilxWD4KgDp7uqpD0Gw2WwyYsUIIQMJmhYks3JmXZB+7tw56dOnjyilZO7cuQHvn8Y09SsIeA4LP46fggRvpwE55nxz8htpM6uNkIFcMesK2XFyxwXpp06dkqSkJAkJCZElS5YEtG8ar6mzxbJZPGmW54mTbkEayPbp7K2zuenlmyi0FBIREsHeR/Zyw5U3VKcfPXqU3r17s2PHDrKyshg1alQ99lZTF9R1XMBiVzcNm6RhLu4H9HyECmsFv1nyGzYc2FB9L0gFsbdob7XR3N69e+nfvz+nT59m7dq19O7du877pQk8/hYEs877Lo3tJIAmFtuObqP/G/05e/4snVp1QkQ4XXb6glPpv/nmGwYMGIDVamXDhg06iFZTxpf5VG0XHpz3cSyit+PQOm93V1ddrhGeWP9EtQvl2PfHis1mu8iZ/vPPP5eYmBi56qqr5LvvvquzvmjqBK+f3YCfqmmWujhV81z5OX79+q/5+uTXRIREsHL4SgZeO/CifOvWreOuu+6iQ4cOZGdn06lTJ7/2Q1Pn+O9UzabGmh/WMHT5UMqsZdzc9mY+u+8zosMvVpKvXLmSkSNHcv311/Pxxx/Ttm3beuitJtA0+bNF7XY7Yz8Yyx1v3UG5tZwnej3BVxO+ukgILBYLGRkZDB8+nOTkZDZu3KiFoBnRpEeEI+eOkLIwhYNnD9KqRSuy78kmqcPFIVdLS0vp2LEjRUVFREREkJWVRUxMTOA7rKk3muyIsPjrxXR5vgsHzx6kb+e+FPy1wKUQFBUVMXjw4GrjOaUUhw8fDnR3NfWM3wXBcLdMNXwOXKWnKqU8n4LhA5ZKC18c/YI7ltzBvVn3IiLMuWMO6+9d79KPODs7m5tuuokNGzbQunVrHVi3GePXqZGZ8xFEJFsp5c7v2ScslRaue+E6jpQcAaBdy3bkjMnhmthrLspbXl7O448/zuzZs4mPj+f9998nISGB3bt368O4mykBNbHwxKU4709dP7VaCIJVMB+O/NClEHz99dckJSUxe/ZsHnroIbZv30737t2JjIwkMTFRC0Ezpa7XCMXeZBYfwsJbKiz8z6v/w/P/fR6AFsEt6NSqE9dfef0F+Ww2G7NmzSI5ObnaXOLFF1/UD74GqAcTC3+ScyiHgUsGUlJRwnWXX8fa0WspKi+66ET6gwcPcs899/DZZ58xZMgQMjMzadOmTV12TdPY8EUdXduFifMR+MnMItFdXZ5MLB5d+6iQgZCBPLL6EZd57Ha7vPHGGxIdHS1RUVHy2muvaX/i5kH9+iP486pNEE6VnpL4l+KFDKTl9Jayaf8ml/lOnz4tw4cPF0B69uwp+fn55r9GTWPH6+etUSnU3tv1HiNXjuS87Ty3/uxWPr3nUyLDLp7jZ2dnc99993Hy5EmmT5/OlClTdFgVjVsahULNbrczauUohiwfQqW9khn9ZrB17NaLhODgwYOMHj2a/v37ExUVxdatW5k6daoWAo1HGvyIsP/Mfnq91oujPx7l8ojL2XjvRrq17YaIsHfvXrZs2UJOTg6bN2/m+++/ByA6OprNmzfrBbHGNH4XBBNh4d2mO/Pytpd5eM3D2MTGwGsG8vdr/866t9bxZM6T5OTkUFBQAEDr1q3p1q0b+/fvp7KyEpvNxqFDh7QgaEzjV3+EmpplERnvTbozEVdHSPnYcoIkiK57unLog0NYLI6zibt06UJKSgopKSn07NmThIQEysvLufHGGzl58iRt27bVUaabN/Xuj5CMk1mFccB4nhfp1ZTbyqEY7K/ZibwmkrFjx1Y/+K6OVoqMjOTbb7/VZhIan6gX5/3a0p2d9wkGlgIW9ufl5RXl5eXxwgsvmGmzDVDobUcvsWygy9VHm42prztEpJs3BQKtWXabLk7O+0qpXDkuF9tNe0AplSvifblLKRvocvXRZmPrq7dl/L19Og9IMoztLggL7y5do6lv/DoiiCNg1zzjY57T/Xnu0jWa+qYhK9Tmec7i13L10abuawMp12DDuWg0gaQhjwgakyil4pxf67itmMZ49p0nGoSJha/aaJOHm/cXkSnetGn8o+OqLhGZ5WWbMXDhgepmNOrGJkKci/546mumUiofx0mlpsoZ6Wk4trATvfgbU4HxxqlAscADzrogk99PUc2/08TfOA7HurLmM1Dr/9hMf6rxxWTVnxeeQ8m7TPdUzinfCh/adBm+3kS5aj8MnEJaetHXdJx8N0y2GYML3w4T5VL5KRRnnBflnPOm+tBm1feT7kW5mVVt1fL/vOieN9+7iP/DwvuCJz/n2tIvxT/a1/D1vh6o7rGvxi9btrd9NUgyooOMc7rnqVx/p3ZTzZar+l6UUmkiUrO/nspmA48rpVZw4d/qzf8yxospoOl6G4Ig1KTYx3RP5byus7bw9e7KGUKTDYwwW84wNXElBB7LikixIbjZXvY1Bsg3yvV3M++vWa6KZE8drVnWeICn4NAhuQvpU7PNGUCi8SDHmmjXbL3VNARB8FUbfSn+0b6Gr/f1QHWz7aUC3Ws8lJ7aHFfLQ+ypze0uynjTV1/KpolItjj0Siucftk9WRwUA/OMETffi/+1+WfE3bwpEBeeQ8nXlu62nNOc9CL/aBNtugxfb6JcHC7OiDPTV6e5cCZOR2qZ/H4ScZrze/H9pPtYLq1m3734ftK8/V8a5cZVla3R5kX/Y0/PjqtL6xE0GhrG1EijqXe0IGg0aEHQaAAtCA0SpdQ6b80lDD1CZl31qamjBaFhYspXw7D7GQfVyqpL2WNv1jQIWyONb8iF/h2aS0ALQgPBybAsGyPIgTE9qjJSKxbH2RLpOIzPinCMAMPk4mghF5UL2B/SSNFTo4bDTOCdKu2p07184xpWJQTGg53kZjp0Qbk673kTQAtCwydXRPKMX/1kDCExzBTMltN4QGuWGwjG1Gg4kIvjF73KQnM8sMzIVoxjypOPY9oDMB/oh8MMoep9rHM5qSV2lOYntCBoNOipkUYDaEHQaAAtCBoNoAVBowG0IGg0gBYEjQbQgqDRAFoQNBpAC4JGA8D/B1XhgNOqKLp9AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1)\n", + "_, df_train = plot.calibration_regression(\n", + " mean_vi_train, sigma_vi_train, y_train, \"Train\", \"black\", ax\n", + ")\n", + "ax.set_title(f\"Train {errors.ace(df_train):.2f}\")\n", + "k = jnp.arange(0, 1.1, 0.1)\n", + "ax.plot(k,k,label='Ideal',color='Green')\n", + "savefig('Calibration_VI.pdf')" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1094,7 +1196,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.13" }, "vscode": { "interpreter": { diff --git a/notebooks/regression/olympic_data_comparison.ipynb b/notebooks/regression/olympic_data_comparison.ipynb index aa60be0..b46c943 100644 --- a/notebooks/regression/olympic_data_comparison.ipynb +++ b/notebooks/regression/olympic_data_comparison.ipynb @@ -71,7 +71,15 @@ "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/regdata/dataloaders/base.py:202: UserWarning:data not found. Downloading...\n" + ] + } + ], "source": [ "n_points = 300\n", "x_train, y_train = rd.Olympic(return_test=False).get_data()\n", @@ -1069,6 +1077,116 @@ "savefig('MCMC.pdf')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VI" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from utilities.vi_helper import vi_model,vi_predict\n", + "params = [[32,16,1],[nn.relu]*2]\n", + "\n", + "mlp_model_vi, vi_model, results = vi_model(params,x_train, y_train.flatten(),variable_noise=False)\n", + "\n", + "mean_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).mean(axis = 0)\n", + "sigma_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).std(axis = 0)\n", + "mean_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).mean(axis = 0)\n", + "sigma_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).std(axis = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/olympic/MLP_VI.pdf to figures/olympic/MLP_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/olympic/MLP_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAApmElEQVR4nO2deXRkR33vP3W33tRSax/tGs3mZbDNeOyxzTZ4AUwMgWDsJITEIcFmSQgJeSRsL+8BSZ45Dw45iU3skGfnERK8EBJ4IRCbJV5IvI2N99lHs2gZra2Werl9b9X747Y00mjrllrdLc39nKMjdfe9data93vr96v61a+EUgofn3MdrdwV8PGpBHwh+PjgC8HHB/CF4OMD+EIoKUKITwohbs39PCOEuFYIcZcQoqfcdTvX8YVQWvYppe4G7gdGlVIPA38ExJY7UQgRE0Lcms9FcgK7fZnPdwkhbpz13o259z+ZzzU2Gr4QSsvTZ7+hlBoHjix3olJqPCeiZckJbMFeRghxLRBTSu2bPib3Xl3uvAfzucZGwxdCCcnd9AuxWwjxQM50mr45b8y9juVeXyuEuGvW3w/lnuqLPvkXqcPDwKeEEA8AD+fevm66XODaghu2AfCFUAHkbk6UUl9SSh3JiWEc7+n8qVnH1J11/D5gRAixK99r5cr+I+AhYFpEMeBIrtzrpsV3LmGUuwI+M4zO+rsHeC/wAIuYOHhCWQk3KqW+BCCEmBbGMyssa8Pg9wglJve0vQnoyZki0yZJz6wncQ/eExtgNGcC7Zo+ZtbfPcBlwO6zrjFd3q5Z70072g/mzK5deA77kZzvsSt33kNLmHAbFuHHGvn4+D2Cjw/gC8HHB/CF4OMD+ELw8QF8Ifj4AL4QfHyACp5Qe3PHVvXgRz6FSqVwJ5MYLQ0ELtiCFq0qSvlyIoGojmJtbkXo+pzPlJ1FKYXZ2TLvs9ns3bsXgJ/+9Ke48QTuB26Fxx4j/sE/pOELH0MIUZS6Vgqz27vOWPYfUbFCGE0nARChEEYwiBydYOonT2JubiewtRMRCq6qfK06ioxPYPcKrO5WhHamcxSWiZpK4Q6NYWxqWLSMa665ZuZv5bhY//QtAAI/fxKkhCVEtB6Z3d6NRsVOqF3c2Kp+/AdfmPumkrjjEwBY523G7GpFmOaqriPHJ9Dqa7yytLmWojsxhdHSgF4TXbYc5+QARkcLAFN7rib8yA8Q1urq5lM0lu0R1pePIDT02hh6dRX2K4dJ/ug/yZ7oQ7nOiovUYtW4I3Gyx/tRUs79rCqE0z+MTGWWLUcNDp45b3QI5coljvapNNaXEKbRDfS6OkQwSHrfqyR/8hRO/+l5N3K+aDVR3OHxeWIQmoYWsnD6TqOy88V2/fXXc/3113sv+vvPnJdKoZyVi7NSmdPeDcb6FEIOYVoYDXUIoZF64gXSjzyNMzRKoeaeEAKtJoociWMf60O57qxrmKAU2f6heUJLpVKkUinvxSwhaHYKlbFX3rAKZU57NxjrWgjTiGAAo7Ee5bikfvYc6cf34Y6MFSSIaTGosTj2kT6Uc0YMWjgIGRv39MIiU0rBwMCZ49MpZGrjCWEjsyGEMI0IhTAa6pCpDKnHniX9+HO4I+MFlaHVVKMmE9iHjqPsM+aNFgnhjk8sXJ7rwtDpM/WwM2D7QlhPbCghTKOFw+gNdchkkuRj+0g9/izuaDz/86NRVCaDfeAYMm3Pej+COzyOc3ZZUiGGhs4cZ6dRyY1pQmxUKnYeoRhokQhaJIKcnCL56DMYTXVYOzaj19Xkda5KpbFfPYK5pQM9GvHMp6oQ7uAICLjhhhsAvBGi07N6BOki4xNr1q5yMd3ejcj6mkdYDUohJ6eQaRujpR5rezd67fKCULaNSqYwutowGmu996RETqbQG2Lo9TFkYgpe/3r0F34+c17y298n/Esbc4RlHbJ+Z5aLjhBo0Sq0KoUcS5D8j2c8QWzrXrKHEJYFuo5z9BQymcJsb0boOlo0jDsS90QgBMbIyJzz1MjYWrfIJw+c8QmMWPWyx21IH2FJcoIwGmo9QTzytOdDLDHKJHQdUVuNHB7D3n8MmcoghOAtv/k+rvuVm8BxEWOeENxNbd5JoyMLlrWe2bt370y80XpAZR3kUH4PpHNPCNPMCKIOOZUk+diz3jzEwPCceYQzhwu0mmpwJfZLh8ieHgGlQBNoyvUm0XQDt86LTVIjG08I6w13aBQ5ld+gxbljGi2GEDNOtUqlSD/xPCIUwNrWhdHaiAgE5h4eCoJl4pzoRybTiICFOtUHgBuMQNSLSxKjvmlUTuRUCjc+mXdwZsmEkEtV0jP9M51bp5IQoRB6KITK2mRePEDmxUMY7c2YXa1osehMUJ7QdfQaz69QqTTZ/3waE5DhKojlHOqR0cUu47PGKNfFGRhGCwfnDH8vRSl7hJuAh5VSD+YyQd9dqflzhGmh19WBcnEHhnBO9CPCIczuVoyWRrRI2DtOE6Dp6Bmv+3XDVej1dV4h8fEy1d7HHY2DlAgzAJUmhLMT2C4kglwSqlsB2qqWH9pcc4SOVu2NOKisjf3qUeyXD6NHI+idm3jXnjcggibi1HEA3OpaT0CAmJhAKbWhFufcdNNN5a7CsshUBnckjhYNF3ReyX2EXNrx9y70WU4sd4M3j1DKei2HMC30WgvAm3V++Si/Fm0HwH3iAe93bSOiwROClkpuuMU5H/nIR8pdhSVRUnomUdAq+AFUUiHkUgqu+7TjIhBADwRI2hlQitoJzx+Qjc1Q740aiXQS3I0lhGTSWzUYDhf2tC0V7mgcsg6iKlTwuaV0lq/Fy748ipfV+dJSXTsvpERk0ohMCpHO/c6kEelk7ndq7ufpFK/uf54ax6EDhQP8x9ET3HDjWwAQ6RTKlctPaa4j3v72twOVuWZZpjK4w+MFm0TTlNJHeJgCbn4tk8L6+ZOgaaBpqNxvhAZCeGP40vVuYNcFJb2/pQuOg7AzCNtGZDOQySCymTPv2WmEbYOdRpu+ye3lV6Gdzetm/f2IGeMnr77CDQ25HiGTQmZttFBg4ZN9isZqTKJpKnYeQU+ME70vrw1iioISAmUFUcEgKhBCBUOowPTfud+B4Jn3g2F+/1/+gdOpFNe17iA+nuKCK3ehNdQDoGXSqMk0VC+/3tlndazGJJqmYoUgAyEy23d6T3klZ57+SJmb0dVA0+f2Fpru/a3r3k1tWWAFUGYAFcj9tiyUFYTcbxUMIYMhMC2vjAJ45b9+DMCvfvDj2AMDRN54OdR68wiancbxQ7HXnNWaRNNUrBDcaIzJX7mt3NVYGqXAlbijYxhVVejNdRCLASCyGdzx/NdA+BSOcl2cvtNooZWbRNNUrBAqlulwbtvmly/cjVYfI3TlJSjXRa+p9nqaYAiRTuGe6ofdryl3jYvGLbfcUu4qzMEdHj8zcbZKfCHki3Rx45MgXfRNDQR6OvjAL7xpJhOeMzCEHouCAFVTg0inkIePlbfORaaShCAnk7hjE6s2iabxhbAc0vEEoMDY3Ia1uW0mxGJ4fBzlZKk3g6DraNURhKYhm5phcACOHEO57pJpI9cTw8PDADQ0LJ79rxQox5mJJSrWzL0vhMWQLm48AQrMrR1Y3W1zIhmVlLzvM58AofHQN+9DaAItGABNoFpa4YWfQ98p5GQyr0x564Ebb/T2Jy/nPIJSCmfQC3EXRvEeML4QzkZJ3PEEKIm5pROrpx0RnGuDKttGTqUQlomwTPRoGGlnIfePUa3e4hxtcABnfGLDCKESkPGE93CJRopari+EWcjEJCpjY3S1Ym3rQovMH5dWUymUklg7NnvJgl2JUqA31iJ0HaUUP3zxRW4ARl99ifr+YQJdbaVvzAZEpjM4g6ML/l9Wy7m7Qm0WKpPBGR5Fi0YIvekygpect+CXLRMJsAys83vQQgEvhCJgYm1uw8glAhBC8I1nnwMgPD6Ec1bmPJ+VoVwXp38ILWDOS9ZcDM5tIUgHd3QM5TiELt9J8KpLvJGfs1BK4Y5PIKJRrG1dnkOcyXpT+gs4wue/9704QL20cY+fynu5oM/iuENj3trwNcowfm6aRkohEwmk4xLY3o25pR1hLPwFKymR8QRGcz1Ge7N3bjKD2dnChz/60QXP+cztX4L/fBzRewTt5RdxxybQq4uzwUk5+fCHP1yW67oTk7jjCfTq4voFsznnhKDSGdzEJEZLA6Gd22aGQhc8VkrUeAKjrQmjtREAmUhitDejhYPcfPPNC54nLAN3+/kYvUfQDx/AOT2C1dW6Ju0pJYu1dy2R6QxO/zBaZGUbw6iMzeHGN3RsGXr0xFLHnTum0bQZpCShqy4hePlFS4vA9USgd27CbGtCCIFMTKE31c2MWJw4cYITJxb4fnUNteu1AJj9vTi9/RvCT1i0vWuEchycU6c9v2AFczEyY2Mf6gVoX+7Yjd8jKIU7kQDXxTqvB7OnDaEv3WzlSlQ8gd7VgtnsRZPKZBq9umpOMrD3v//9wPxxdWGZcMVVAJh9vaQGvc1G9KrKXNCSL4u1dy04M1+gVuQXyIxN9sBxkApg2c0qNrQQVCqNOzmF0d5M4PwteQ27KTlfBMrOgq6hN9fnNZMpdB2xZQsyWoOeiMPBQ6jEFKxzIZQSd2R8xfMFMm2TPdALkPd6kI1pGrlZnFw6ldDrdxHavTM/ESjlmUMdzWdEICUyk8Voacy7exahACIUxNlxIQCPfO2v+PwnPrnCxpx7uBOTXmj1Ch4cMpkmu/8oCBDh/P2KjSUEJXHH4rgTSQI7txJ6826Mhtr8TlUKOT6B3to0ZydNOZnC2FTvhU/kidB1tKoI6lJvQV63TPPYt79bWFvOUWQqg9M3hBYpPI5ITqWw9x8FXS9419UNIwQ5OYUzPIbRsYnINXuwejqX9QVmo+IJ9KY6jLbGmX+ATKXRqyNoKxj61KpCyJyf0OGmufmN1yBXuMfbuYKys2RPDqCFAgU7x058EvvVo4iANS8kJh/WvY8w4wc01BHc8xr0PDIfn41MJBCxasyOTTMiUFkHFOhNdYs+mT7xiU8sWqawTMTOnUgrQJ2d4cY9e1CTSciJ6qMf/Sh33XUXt912G3fccUfBdS4HS7V3tSjHIXtqEM3QEWZht6UzPE726Cm0aBhhrOyWXrf7I6isjRufRI+EsV6zFb0pP0d2XjlTSQhYWFs7Z6IZlVLIiSnMzpYVx7UopbD3H0W8+11Yr75A/J3vJ3zn/8ZsawLAMAxc10XXdZwNuANnISjXJXvqNNhZb7+6fM9TCuf0KM7xfrTqqgV7EZmYZPCWz16+ZejRp5Yqa/2ZRq6DOzqKSqUJXnKe5wc0N6xMBKkMaDpWT8eckF6VTKPX1Swrgv3797N///4FPxNCoEXCyJ0XA2AcPUi278yuOrfddhu6rnPbbRW+HHUWS7V3pUxnoCBjFyYCKXFODOAe70eria56zcf6MY1mwqMV1o4ezO5WbxOPlRZn2yjHwTpvM8I68zWorAOaht4QW7aM6Zt4sXF1EQmhLr8cHvx7zMETJF8+TOjSCxCaxh133JGXSaSUQmVshGmUfYHPcu0tFKUU7uAIaiqNVkAGCuW42Ef7kPEJtFh1URbnVL4QcnFBKutgdLcR2NZV8IjAvCJdB5VMY523ed44s0ymMTtbinLTaQELLr8cpWmYQ324Bw6TPdbnDcUaOmgCNG3Bf6RSCjWVwhkaBTsLhuH5MGsUdFZqpkXgTkyhF7DcUmZssodPoDL2ivzBxahoIXh7nmUwWhsJnLcZLbr6wDXlSs/+39o5b5xaJtPotcubRPkiLBMRi5HdfiHWqy9gPvojMldeipxMeg6hriPIpabRveRlQtO8jc2zWaTjInQNFPzp5/6Ev3vgPm744G9wx513FqV+5UIphXt6FDeeKGjCTE6lsA/2IoRWlHthNpUrhJyNHbxsZ167YOaDUgoVT2B0tWDUzn2aTG8wrjcULwu3MA1EKIB93Q1Yr75A9LF/wzl+AKJVXsY+TXCyr5+hkWHqGpvo6tkMmu4t8LEsaNyEfcFFsP18nvqn79GgdP7xb/52XQthRgRj8YIicp3hcZxjpxDh4KpM4sWoWCFokRDB111S1LTqMj6B0dI4M2s857NkGqOtacXDb4uhR0Jk3/pWph59hMhzj2P2Hprz+ebcD4MnvZ+zCAKZ7h28f9flfGvfU7zzXb+EyjoFDzFWAjPmUDyRtwiU6+KcGsIdGEbULDwyVAwq99tcxHZeKTKRQKutwWhrnP9ZOoNWFSp4Sv+zn/3s8geFg4jaajK/cRvJ3W/k1Ud+TF/vcTo62jj/vPPY//IrnDp1CoVCKIUQcM3eN0M6iTkySPDgCwSO7eem8RHe9gefwrjsIpzTI5htzQXVtRjk1d5FmN7FRk6m8jaHZMYme/QUaiqJqC2OU7wYeQlBCHG1UurHq71YbiOQI0BMKVWy9PBqKokIhbC6Wuct81NKobIuZltzwV/0tddeu+wxWiCAFg5hdrehNdZz59/fw6CTYqr3ID+9714uDli81rL4ky9+gW/+wz/Qs3UrH//Rd/nAe97L73/oQ6SffYnoPV/FHOojes9Xmaj/PHptNXqsek3W7i5FPu1dCGVnvXkCx8nbMXbikzhHTnrZQWqK5xQvRr7zCG8RQlwthFhxjXIbhDydy4p93UrLKZTF5gpmPp9KodfXIAKF253PPfcczz333JLHCMtACwURkSBa2KLlfe/kESbZ/cH3Edp1IaELtxHY1sX/+ruvcyKb5PGDL3PCSfHFB79J+JorCFz3OhK/8yncaAxz4AShb34d+9gp7CMnS77GIZ/2no2cSmH39oGS+QU+Oi7Z4/1kDxxDBC3EEmtGiklePYJS6o8BhBD3CyGeBB5USh0r8FqXMWuTECHELqXUvtkHzN46qqOhqcDi5+PNFWSxzuuZM1cw83nWAV1Hr12Zvj/+8Y8DS4+rC8NAWAZGSyN6SyNfefe1fOXrdy16/G233TYTemE01ML2zSjHJfFrH6HmrtsJP/s42e9uJVsdRW+MYW6ab+qtFfm0dxqllBdKPTyGFg7llYNITqU8U8i2izY/kC/5mkZPA4eBP1dKPSuEqBFCXKKUem4V1x4/+43ZW0e9dsv2VcV+LDVXMI1MZTDbm9d8okqvjyE0bd5I1UKcPdFmNNfjjk8gr7uaqSMHqPrhg0T/9VuMb92OVleNXhVZUbjyWqKyDs7gMHIqhRaNLHtDK9fFGRzBOTWEFg6glSGVfr6m0Z8rpW5WSj2be30t3q43hfAUEJt+oZQ6UuD5eTM9V2Bs6Vj0Jpl2kEUJ7Gy9ugpzAf8kH0TAwmhtwmhvxv31W0hvvwjNzqD91e387tveyRc++BFv4VCFICeTZHv7UGkbPQ8RuBNT2C8fxR0YRqupmrevdanI6z+jlPr22a9X4DzfDewWQuwCHirw3LyZmSvonD9XMOcY28FoXDyytJLQa6sx6mowO1v5ZlsPo8IglprgQ26Sfd/5V+zj/d4kXBlRWQenf4jsiQGEZS4bN6QyWc/X2X8UdIFWHV2TfEX5Usqto8bJmT3AviUOXRVyYvG5gpm6JNMrdpDLgdB1jNYmpJ3lnsd/ygkjxmezw1zuJPjklh3Yh46jVYUxZy0oKhXK9bKEu8NjCE2g1yw9P6BcF2dkHPfEoCeANfYFlBfZu6ytVbnzCCtATiTQamMY7Ys72spxQYiizFb/2Z/92arLyBctFMBsb+bN73k3//Htf+ZtWy7misM/58ojzxN/5HHsSAhhGnz8s5/mr//m7jVZ5zC7vSrr4CamkKNxUMrLTL3EE125Lu74JM7JAS9RV1V4zXwzpRQyPonTd5rMS4cArgGWtGAqdj3Ca7dsV4/f8495Hy8nJ725gq0dS37BbmIKo7Wp6ElkS4WbmMLp7cM+cgLzU39I8OVnyTZsYuqLXyGw+2IuvnIPGTeLrQlODp9GmKY3C23oq37yqqyDytjIeAJ3Ko0QufXZSwnAdnDG4rgDQ6isi1a18sUzyyEnp3AGh3GO9uMmkwhNByFI/J9/+uKWoUc/t9S5G6JHULnM1FbP0iJQGRsRChZtlOVnP/sZAFdddVVRyssHPRqBzhaUnSX9B5/G+MMPYw4PELjjq9h//N95/7vew99/50F+7eabvRBnpRBCoABhmmgBEyzTi4wVwgv8y0XBIgAFKOX5HI6LdBxIZZDJND978kkAXnfFniXXFCulUMk0zvAYcnjcW0gfCaFFinu7KaVQk1M4p0dwjw/gJqZAaGiREEa9N5YjE5N5lbXuhaBSGRACa0sXwlxCBEohM1nM7sai2aSf/vSngdLvF6BXV6G62zBth6mP/Deqb/8MoReeRPzp5/jY+36L33vvr6AFLLLH+9GiVd4SRsv0dvOxs6hkGim9LN4IEEpBTiwwrQdvC1+hCYShIwImf/KXXwHgR2/6zoL1khlv1aAcHkUlM17QYTRSVCdYSYmcmMTNrUxzkykvGjUSQq8vdCDzDOtaCHMW1wSWjtNXqQx6bXVB2SgqGb22mkDnJv7yleep12P8phol+PKzBD/zO2QbNuHseA3Z3XvIXvxatKrITA8gQgFvljsSQQTMGdNpJTersh1kOoOcSiLH4jMPJREMoBVxrYCybdzxBM7AMM6pQcj5eVpVeObJv1rWrRBUNotKZZacMJs51nVRSqHXFy/EutwIIdCb6vj6977DZgKMBzbxO5s3ETz0EubwAObwAKHHH0KaAbKdW8huu5Ds5VegbdvmhTKb8RlLSIAX2hwwEQEDYVq59RHTPae3e6hK2yilyOw/hkqlwc0N2Wq5m79IMUHKdZDxKdzxOG7fEO7YBEopNNNAj4RAL/7ipHUpBOU6qKkU5vbuvOJXZDKNsalhzZy0ciEMg1/4wK/zk3u+yf6WJn736CDXX3QNV29uI3BkP/bLzxJLJQgcfpnA4ZfhBw/gVtdid2/H3nI+7iW70bvauffB+/nuv/+Qts52Th0/ydvf9jZuu/WDzLGVhJgeigTXRYRC3qKhVaKUglQadzKJHJvAHRr1bnxAINBCAfTa2CxRrg3r7s6YvcIsnzThys4iLGtFuYnWA39x19dwb7+di5vaqZMa/S88ybUf+1vS43v58qc+S0Oolh43xdu2dcL+5wlMjBF6/glCzz+B+ue/w2np5OKhCar0IEdPjRPRgux76EfI972fr3/j//K9f/8hb73hej72ex+fufFXsjBGKQW2jUrbyGTa2xVzeAx3fAIc17vxhYYWDKDX1ngLl0rIuhLCdF5So7s1r7gd8NLAG52tazJp89WvfrXoZa4EPVbN1R94H9/623v40HvejbmlEz2eoH7vHh770Y/pb+ngZ8dGiFjtbDIydLpJrmlrQu89iNnXy15gbxbSQueQFmAw2sA9t9xKQjN4k1lN+oc/I3nRlXx+7zvRggHsw71oAW/jRO9JnTOychuwS9tGZRxkMgXpDDJlozIZbwRrutJCQwta6FUR0Mq/6+i6mUdQSiHHJjDamzBb84tMlckUWlVkTgrHjY6SEpXKIMcncEbiuIlJfvGat6BLCAhBlWZwcWMb9ug4hnJpcSbZItNcVRNBHxmcU9agMDigB3G27eSid9+E1ExwsihXgXS9+5/ZDxjlmVNabkjW8DJvCEMDXS/5Ux684dOJv3lw48wjyHgCY1M9Rkt+YcdKSpRU6PWxNavTww8/DKx8wcpaIDQtN2YfQm+uR06luOymX+T/fesBfvkd7+JjH/gtbvqlG9GkIiJ0Rswqqi68guE3vgnTTmL1HmLksR/RNNJPs3Jodibhlf9CHXiKI4EI/2W7xHZexhU3/qo397BBWBc9gpzwUjJa3flHcLqJJHpT7cwmf2vB3r17gfLuO5wPSinPNh+fwIlP8qX/8T/5wXf+hRvf+nZu/eBv53rbBNmjp3AnkwhNQ48EMPpPYh58icmnHqchMTonQlOGq8huvYBszw6yPTuQ9c1r7tCuhA3TI8jEJCJahdXVkrcIVNZBWIa/v3EOkZs/0EKN6A21fPrOv+CPv/h5VMbGTaVRIxOIWoHeVI/KZnH6hsgeO4VTVYd91Vv48yd+TiAYZDRzmjcqybsCQaqSkwSef5LA895ss4zWkO3eTnbzdpzNO3CbWipSGItR0ULw1hoHsTa3FRSgJVMZLxnWBuq6i4UwDYy6GlQs6uVXGhlHRsLgurijEzCexezchLW9E3c4TvbQcV53wcW88PJLPBqN8VJ9Hdf/9n8jO3Ia8+BLmEcPYB7dj5aIE3jhKQIveClGZSQ6I4ps9zbc5raKNqUqWAgK9MXXGi+Gt8VTpOQL29cbQtPQq6vQopFcRr0xb56luQ45GscdGkOrChF83Wt522u2cXVvHw/f/jlveat0kA3NZBqayVx5NSiFPtSPcfTAXGG8+AyBF58BQAZCOJ09OF1byXZtxenoAatyZvkrVghC0zC3dS641ngxZhzkPDcH8cmZTVVhzEjIE8TgCHp9LVpjHe7pEU8QwQDBi3ag1VajMlnkVArlTKJHq8A0QQjcplbcplYye/Z6YdkjgzOiMI4dQh8fwTr4EtbBlwBQmobb0kG2cytOl/cja8r3f6tYZ/nS11yknvjhwwWdIydTaA0xjCJlxluO6czQO3bsKMn1SoGSXrJld2gMzTK8rBInB5DxKQ7FhxG6zraWNpyTg9j7j6Cy7hlBLIEWH8PoPYTRewjz+CH0/hOIs1bVubF6nO6tZLu24XRtxW1qXbU5tWGc5XzxMlJo6LHSOcgbSQDTCE3DqKvxMvQNjqBsB3NbF3I8wbbjhjdpZhiYm9sxOppxTp7GfvUwaiKBXh1dVBCyphb7osuwL7rMeyOTxjhxFLP3EMbxQxjHD6OPj6A/N0LguSe8c4JhnK4tOXNqG057N5hrs6pwwwihHA7y9773PQDe8Y53lOyapUIELMyOTbijcdzTo+jRCP/WfxR3NM5bt3ubMwrLwuxuw2hvwjk+gL3/KCox6T2MtGVurUAQZ+v5OFvP915LiT54yusxjh3E6D2EHh/F2v8C1v4XAFC6jtPahdO9zfMzOreiqorz4NsQQpCpTFkc5C9/+cvAxhQCeP6DUR9DCwVx+k7z1bv/GnSNd9zzDpyjfcjMJKIqgjBMzJ4OjPZmssf6sPcfA014w9f5zibnfAa3pYPMFW/23hofyQnD6zX0gZOYJ45gnjhC6NEfAuA2bCLbvRWnyxOHrG9a0bDtuheCkhLlSt9BXkO0cBCzq9WLLXJd9GgE7fwe7BP9qOFxyCXnFZaFtb0bs72ZzIFessf60MJBtBVmq5OxeuxYPfbFewAQ6STG8SNneo2TR9GHB9CHB+Dpx7xzIlGy3dtmHHA7ungSh9msfyEkM+gNsQ2zgUalIkwDEbBQdhZ3YgotGsba3IYbjeD09kEwMLObpQiHCF5yHmZXC/aLB3GHRtFiVd46h1WggmGy23eS3b6TFIDjYPQf95zwYwcxew+hTSUIvLSPwEteohRlmEzE9ixb9roWgspmwdRL6iCf6wjLRG+qwx0cQYuGMRpr0cJBskdOIScn0arOhLvrtTUEX7cLp28I+8UDyMSUt8tNsaJNDQOno8ebk3j9W3LDtqcxew9iHDuE2XvI6y3yKao4NSoPMplZcQY5n5Vj1NWAAHfAE4MWCWGdtxn7WB/u+IS3uV/OTheahtnejNFch33wOPahXjTLKvqON97FxJmJvktfD4AaHIBv/mD5NhW/NqVBJlPeGuQCdmIsNt/4xjfKdu1yMLu9Rm0NQtNw+4egKowwdawtbTinLO+9muicFWzCNAlcsAWzvYnM8wdxhkYxYosPtxYLFc5PcOtSCMp11zzEOh86OjrKev1Sc3Z7vVEhgdt32hODpmF2NEPAxOntQ6uOIPS5t5hWHSV41SU4JwfIPH8QBOg11WUP0FuXNoVMpjGa68u+fdJ9993HfffdV9Y6lJKF2qtXV6G3NCInUzP5V82mOsytnahE0vPjzkJoGmZnK5Fr9qA31eEMj6KydknasBjrrkeopDXIX/va1wC4+eaby1yT0rBYe/WaKMqVuKdHvRxKQmDUViN2dJM90AtBuWCWaxEKEty9E7dtiMxzryBVqmy9Q8l6BCFETAixSwhxY273nILxFphkvIwU6yjW/VzAqKtBr69BTiZn3tOjEazzNnupIlOZBc8TQmC0NhF+85negQV6kbWmlKbRTcB4bu+0m4UQsUILUKkMWm3NsnmMfMqD3lCLXl2FnErNvOeNKPV4aSRnvX82071DcPeFXmqXiYlSVHmGkglBKXX37M1Bcmni5yCEuFUI8bQQ4unh0dG552/AJF0bjemkY5gmclYPoAUtrB3doGuoqeSS55vtmwi/+TK0aJXXO0inBDVfAyHkTJ+zf2KzPv8k8N6Fzs2JZbdSandD3dxUfnIq5TnIGyxJ10ZD6DpmW5PXA2TP3MQiYGJt7wZTX1IMgLfR/JUXE7hgC+7oBCq1eE9SLIp+Vy21bawQ4lpmbSiYd5kZGxEKoVVYKvcHHyzZDrkVQb7tFaaB0dpE9sQAmq7NTHgKy8Da2o196Ji3DHeJGCSh61jbutDrY6SfeRk5FvdmpdfINyyls3wtcDtwF/BAvudNZ7E2mitvm6eGhgYaGs6dnEmFtFcLBzE21SMn5z7NhWVgbekCXV/SZ5hGr6sh/MZL0VvqPVPJXRtHupRbRz0MXFrweck0el1lZrG+9957AbjlllvKWo9SUWh79Zqol0ZmYnJOiLwImJjbOsnu70Wl0ojQ0tEBImAR3HUhTn2M9PMH0MMhRKi4IfcVPaE2s81TmWeQF+Pee++duTnOBVbSXr2xFixrjvMMoAUszO1d3iBIZuGh1dkIITC72wm/4VIvs+RYvKB6LEdFC0Em0+jN9Wu+D7LP2iE0DbO1cZ7zDLnRpG1dqEwWZec3s6zX1hB+4y70xphnKim3KPWsXCFIhRYJVdxm2j6FI0wDo6URlUpzdrIILRLC3NaJSqbOpJ1frrxAgODunQR2dOMMjxclPKNyhSDAaKo8B9lnZWiREFpD7TznGbwZaGNrFzIxhXLz2y9a6DrWeT2E9rwGmUgik0sPyS5bv1WdvZbo2rrZB9knP/T6GCIUQKbn+wRGLIrR1YqKJ+b1GkthtDQRftOlXhTs+Mr9hoqdnRJrHKdeDL7//e+XuwolZbXtFUJgbmrAPnYK5brzfD+zqQ5sB2fwNHpN/hEEWnWU8Bt2kd73irfOoS5WcAr6yu0R1gHhcJhw+NzxYYrRXmGZGJsakFPpBT832hrRamPIRKKwcgMBgpfvxNzcjjM0VvB8gy+EVXDnnXdy5513lrsaJaNY7dWiEfTqCDI5XwxCCKzOFkQotGwoxrxzdYPAzm0Ed52PO5bIa1h2pk4FXclnDvfffz/3339/uatRMorVXiEEemOdF1bvzB/+FIaOtbkdoKCbebpss6uV0Otfi0ylcRfpec7GF4JPWRCm4ZlIC/QKkJt93tqJSmVQbuERqHp9LeE37UZ4bsiyDqcvBJ+yoS9hIoE35Gr0tCPjUzPLQAtBq4oQesMugCeWPbbg0n18isiMieQuPENs1NVgtDUhJwpznqfRvK1wTy573IpK9/EpEsI0MJrrFx1FAjBaGtBqqpGTk2tXj0rdH0EIkQD2l7seZ9EADJe7ErOotPrAGtVJgAgLPaSUQkEhdpJQEEwpd8tSB1XshBqwXym1u9yVmI0Q4ulKqlOl1Qcqt07LHeObRj4++ELw8QEqWwh3l7sCC1Bpdaq0+sA6rVPFOss+PqWkknsEH5+SUcmjRjMIIXYBPUuliilhXWJAz/SPUupLZarHrcARIOZ/L0uTz/2zXnqEa4G6ZY8qDatOXblacknSns5lBrmu1NdfhLJ/L0uw7P1T8ULI5UMqbOfxNSSf1JUl4DJg5rq5J15ZqZDvZR753j8VYRoJIW5c4O2H8bqzh8vxj16sTtP/4KVSV5aB8XJXYJpK+l6EELvyvX8qQgiL2W5CiGlF7wK2CCFipXrSrEXqyiLyFBCbfjH7SVxOKuB7mUe+909FCGExlFL7AIQQlWIHz05dOYpndxacva8I3A3clMvw8VAZrj+PCvle5lDI/ePPI/j4sA6cZR+fUuALwccHXwg+PoAvBB8fwBfCuiO3M+kz0zuTCiFuz/3Eyly1dY0/arQOEUL0AHcppa4TQtxYCbFG6x2/R1iH5CbQ9gkh7qKCwk/WM36PsE7JmULPKKWWXJTukx++ENYps2KhLlNK/VFZK7MB8E2jdUjORziS8w125V77rAJfCOuMXEzPA3iLcsCLNXrAF8Pq8E0jHx/8HsHHB/CF4OMD+ELw8QF8Ifj4AL4QfHwAXwg+PoAvBB8fwBeCjw8A/x+p2Ig1dhs20wAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "train_loss = errors.loss(mean_vi_train, sigma_vi_train, y_train)\n", + "\n", + "ax =plot.plot_prediction_regression_without_test(x_train,y_train,x_linspace_test,mean_vi,sigma_vi,y_min=-3,y_max=3,\n", + "title=f\"Train {errors.loss(mean_vi_train, sigma_vi_train,y_train):.2f}\")\n", + "ax.set_xlim(-4,4)\n", + "savefig(\"MLP_VI.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/olympic/Calibration_VI.pdf to figures/olympic/Calibration_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/olympic/Calibration_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAfyElEQVR4nO2deXhU1fn4PyebCQkhRhAEZAlSQVxjgNovIEsSFcUNBIG2KEJYatX++BaL7aOx/fpQsKXWCoFIY3FBBFIRkS1hh0IhIAVUkE2QsicEyEKSyby/P+ZOHEJm5s5klizn8zz3ycy9Z3lnct+555z3Pe+rRASNprETEmwBNJq6gFYEjQatCBoNoBVBowG0IgQEpdRkpVSacexUSiUrpeYopRICLEdCtb9xSqm4QMpQV9GKEBh2iUgmsBAoEJFc4GUgzl1F42ZNM9OJoWDTnLUDzFFKzXE4nQwsUkrlGAqaaKafhojSy6f+RykVJyKFxs24SERSHM/7uK9FIvJUTTIACSKyy+FcgogcMV4nGwraKNFPhADg4mZPUkotMoZO9uHKEON9nPE+2f4rbrzOUUolOvvld0OS0UaaIZddCYY0ZiUArQhBxX7zich0ETliKEMhsBiY4lAmvlr5XUC+J0MZESkUkUyjjepPjO61/Sz1Ha0IwafA4XUCtps0wThqotCbToyJelwN55O9aa+hoRUhQBg34VAgwX7zGX8THG7QBCDHeF1gDIES7WUcXidg+xVPqtaHvb1Eh3P2ifZCh74XOVSLAw777IPWU/RkWaNBPxE0GkArgkYDaEXQaACtCBoN4AdFcGXmN66nGWWG+LpvjcZbfK4IhsGmxjVwpdRkIM8ok+LrvjUabwn00Kg7Dgah6pZR42mRp5TK69atmwD60Iepo6CkQG55+xZRrytRryvBQ4I9Ryh0fGO4ACSJSFJUVFSQRNLUN5YdWEabGW04fOEwEaERAMWethFoRdiBg+ux3elLo/EGq9XKqE9HMWjBIMoqy3i97+t8/4vv4V+Yclt3JMzXwjma+e0uv0qpNMMfPxMYqpSCH1wJNBqPOX7xOL2yevH9pe+5PvJ61v58LT9q9iM6duwIZ8kE5nvSns8VwZgI31vtXKbxtxCbMgDsQqPxgr/v+jvjvxiPxWohNSGVz4d/TkRYBLt27eL8+fMA0Z626XNF0Gj8hcVqYdD8Qaw8vJJQFcrsh2czLmncVWWsVit4MUfQiqCpF+w5s4d+8/pRUFpAm6Zt2DJ6C+3j2l9VZt68eYSHh1NRUdHH0/aDvWqk0bhl6qap3D37bgpKCxh5x0iOv3T8GiUoLi5m3rx5DB48GMftqGbxx2Q5DTgCxInI4mrX4oA0bPODa65rNI4UlRfRf15/dpzcwXWh1/Hx4I95ousTNZZdsGABFy9eZOLEiV715dMnggnL8RRsER1ygWG+7FvTsFh5aCU3vnkjO07uoGvzrpz4fyecKgFARkYG3bp1o1evXl715+uhkUvLcTXiAh3XR1M/mLBsAg999BClllJiI2LJS8ujeZPmTsvv2LGDnTt3MmHCBIyleY/x9xyhsNr7qYB9+2F89cKOLhbnzp3zs2iausbZorP86G8/YvbO2VXnKqWS/ef3u6yXkZFBdHQ0P/vZz7zu29eK4NJybLcjGJOZIzVcr3KxaNGihY9F09RlFn61kJvfupmDBQe5r+19dIjtQHR4NC2jW9KleRen9S5cuMCCBQsYOXIksbGxXvfv68lyjZZju2XZGAolK6UKsD0dNI0cq9XKsOxhLP56MSEqhD+l/IlJP5lESUUJ+8/vp0vzLjQJb+K0/rx58ygtLWXChAm1kqPObt5PSkqSvLy8YIuh8SOHCw7T671enC46TYsmLdjwzAa6tuhqur6I0LVrV66//nq2bt3qeMnjiYK2I2iCwjvb3+HWd27ldNFpHrv1MU5OOumREgCsW7eOAwcO1PppANqyrAkw5ZZyHvjwAdYfW094SDj/eOwf/PSun3rVVkZGBvHx8QwdOrTWcmlF0ASMHf/dQcoHKVwsu0hCXAKbnt1E69jWXrV16tQplixZwosvvkhkZGStZdNDI01A+O3a39Jzbk8ull1kbOJYDr942GslAJg7dy4Wi4Vx48a5L2yCgLpYGNeTscX7TNAuFg2fk5dP0vcffTlYcJCosCg+HfYpD9zyQK3atFgsZGZmkpKSQufOnX0iZ0BdLAwliDPsCNqq3MBZsn8JbWe05WDBQSJCIzjywpFaKwHAF198wYkTJ3wySbYTUBcLQ0GmKKUWAdfE49eW5YaB1Wpl9GejeeKTJxBsy/PhIeGcLDrpk/YzMjJo06YNgwYN8kl7EGAXC8Og9jI2Y9s1sY+0Zbn+c+LSCRLeTuC93e/R7LpmtI5pbcpCbJbDhw+zatUqxo4dS1iY70b2plpSSsWKyCWlVAdsOcAuOSnqbnP+EBGZbrR5VeoiTf1n3u55jPl8DBarhQEdB7B8xHIsYjFlITbLnDlzCA0NZcyYMT6Q+AfMqtQwpdRhIBXYDvzTSTmXLhbAYiPC3RFsCqWVoAFgsVp4fMHjfHHwC0JVKDMHzmRid9u+gAgiSLzJNzkKr1y5QlZWFo899hht2rTxSZt2zCpCHrYkF1OptjHfEWeb8x027x/BpgRXXdfUX746+xV95/XlfMl5boq5iU3PbqJTfCe/9LV48WLy8/N9Okm2Y3aOEI9tyTMeF4qgaVxM3zKdO2ffyfmS8wy/fTgnfnXCb0oAMGvWLDp37kz//v193rapJ4KIrAHWGG/f9LkUmnpFSXkJA94fwLb/biMiNIL5T85n8G2D/drnf/7zH7Zu3cqMGTMICfH9Go/ZyXJ/EVlrTJYTRcTZHMFdO3GAq3SrmjrOxu828vDHD1NUXkSXG7qwafQml7vHfEVGRgaRkZGMGjXKL+27VASl1GBshrEEpdRT2NxbBeeTZXeW5WRgnDGZjgfGehNxQBNY7HsDsr7MYuaOmQC81PMl/vLgXwLS/6VLl/jwww95+umniY+/ZmOjT3CpCCKSrZSyW4HtaVCdzhEMy3KuiOwykmRXV4RdDlnnk7US1H1KKkro+k5Xvr/0PYIQHR7N8hHL6dPB49BBXvPhhx9SXFzsl0myHbeDLRE5iu2pMADbL7qrvAbuLMsuM71ry3LdY+aOmRy/dBxBCCGEVT9bFVAlEBEyMjJITEyke3f/5UU3u3y6QER2AxjzBLMUOjnfnWufFvZl1kyw7VDzoB+Nj7FarYz850gWfLUAsLlI3Bx7M/e0uiegcmzZsoV9+/bx7rvveh2hwgxmFWG6UkqAi0BHbDdyTbgN+64zvdd9jl44Sq/3enHy8kluiLqB5SOXExYS5jPrsCdkZGTQrFkzhg8f7td+zCrCNGMJFaWUq58Ed5Zl0Jne6zQZOzL45YpfUimVPNz5YZY8vYSwkODs3zp37hyLFy9m3LhxREd7HODaIzyxI9hxehO7sywbr/UehDpIuaWcgfMHsuboGsJCwsgalMWz9zwbVJmysrIoLy9n/Pjxfu/L3fLpGBGZq5SaDVyPbfnU1dBIUw/ZdWoXA94fQOGVQto3a8/m0ZtpG9s2qDJZrVbmzJnD/fffz2233eb3/tytGi2y/xWRYSIyFDxPy6Opu6SvSycpM4nCK4U8e/ezHHnhSNCVAOCzzz7j6NGjjB49OiD9mY5rpJS6G8C+euRvdFwj/3LpyiX6zuvLl6e/JDIskuyh2QzsPDDYYgGwceNGUlJSqKiooEOHDuzbt48mTTyapPsnrpFSaizwNPC0UsqlI7i7hOJKqSHG9cmeCqvxDTmHc2j151Z8efpL7rjxDk5NOlUnlODQoUMMHTqU+++/n/LyckSEs2fPsn+/69inPkFE3B7AgJpe11BuMjZfJIA5NVxPBtKM1wmu+rz33ntF41uKy4vlyQVPCumISlcyJXdKsEUSEZHTp0/LxIkTJSwsTKKjo2XKlCnSoUMHiY6OloSEBCkuLva0SVP3teNhVhHGArHGMcZFuUX2GxyYY1cKh+vTsM0xqhTC2aEVwbccyj8kYa+HVSnB+qPrgy2SXL58WdLT0yUmJkZCQ0NlwoQJcurUKRERKS4ulp07d3qjBCJeKIJZf9aFwHTgXWrYdO+Cwmrv47BFwc4FUuzeqHa0i4V/+GjvR3SZ2QWLWACIDI2k6XVNgyZPRUUFs2bNolOnTqSnp/Pggw/y9ddfM2vWLFq1agVAkyZNSExM9HRu4DVm7QgXATOLue4syzvd9KNdLHyI1Wpl8MLBLDmwhBAVQnxUPGWWMp9tpPcUESE7O5tXXnmFgwcP0qdPH5YuXUrPnj0DLkuNwrk7sNkOFgKfAB1clIvDNvRJxLZR334+zeH1ZPTQyO98ffZraTG9hZCOtPpTK/n2/LdSXF4sO0/ulOJyr4YbtWL9+vXSo0cPAaRbt27y+eefi9Vq9Vd3fpsj/NpQho7A/3rTkaeHVgTv+fO//iwhr4cI6ciQhUOksrIyaLJs375devXqJYC0bdtWsrKyxGKx+Ltbj+83s04kO8Xmjo1SKtf4GyvOw7pogsAVyxWS309my/dbiAiJ4P0n32fY7cHL2Zidnc2QIbZV9Pj4eHbv3s0NN9wQNHlcYVYRfuOwQ62jUuootqdD7eP3aXzCluNbeOijh7hcfpnO8Z3Z/Oxmboy5MWjybNq0iZEjRxISEoLVaqWsrIxjx47Ve0Wo8j61o5Qa4Ad5NF4wOWcyf/rXnxCE57s/z98G/i2o8uzevZtHHnmE9u3bU1ZWxvnz52nZsiVdugR+gm4Wb7xPnZ4zgz26nY5yV3sKSgro/V5vvj7/NdHh0SwdvpT+HX0f6sQTDh06xIMPPkhsbCw5OTk0b96c/fv306VLl4AthXqDU0Xwdg7gavO+YTeYo5Q6Qg2xTzXmKKkoYfaO2byy9hXKKsvo3ro7a0etJSYiJqhynTx5kpSUFCwWC+vXr6ddu3YAJCb6JtKdX3E2iwamOry+2+G1q+VTdy4WcVSzNjs79KpRzVy+clli3ogR0hHSkdfWvRZskUREJD8/X26//XaJiYmR7du3B1scn64a5Rn7EMAWzuUC7vcjXLUXWSmVKNdGqkhSSsVjc8XIdLxgPE3SgKpfE80PHCs8Ro+5PSiqKAJsFuJHb300yFJBcXExjzzyCN9++y3Lly/36yZ7f+FUEUQkG8gG2/ZMEfnS/tqD9gurtVmIYTlWSuXww242+3VtWXZC5s5MJn4xkUqpJCosCgRuanpTUCzEjpSXlzNkyBD+/e9/s2jRIgYMqJ9rKGYny1867Ef40kVRly4Wxi/+QtGR7kxjsVp4+KOHWX1kNWEhYcx9ZC7D7xju01Dr3lJZWcmoUaNYuXIlc+fO5cknnwyaLLXFbMjHsUAn43WSiMx1UtTd5v2F2IZZ8fyw+03jhN2ndtP//f5cuHKBm2NvZvPozbRrZkxAfRRq3VtEhBdeeIEFCxYwbdo0nnvuuaDKU2vMTCQwuR/Bl0djnyy/vv51UelKSEdGfToqqG4SNfHqq68KIJMnTw62KDXhNxeLBKXUDuN1Rz/oo8agqLyIvv/oy85TO4kMjeSTpz6pExNiR95++21+//vf89xzz/HHP/4x2OL4BjPaAjQDZuPG+9SXR2N8IuQezpUmbzQR0pHbZ94u+cX5wRbpGj744AMB5IknnpCKiopgi+MMj+83v9/Q3h6NTREmLJtQtXts8uo6OdyQZcuWSWhoqPTv319KS0uDLY4r/DY0Mo27hONGmUR0wnEAThedpldWLw5fOEzTiKas+ukq7rv5vmCLdRUlJSV8/PHHPP/889xzzz0sWbKEyMjIYIvlU0wrgplwLibCwttJxnmA4EZBSUUJf932V15b/xoV1gp6t+vN6p+tJjKs7txgIsK+ffvo168f+fn5hIeHs3jxYpo2Dd42T3/h6+VTt5ZlIwhwLpDklcQNgKKyIlr+uSUlFSUATEuexuT/qRvRbQoKCli7di05OTmsXr2a7777rupaWFgY+fn5tG/fPngC+gmzT4QjIvIueOx+Xej4xlCM3Op5ExyuN3gXiwPnD/CTrJ9UKUFUaBTJCcELEF5eXs7WrVurbvy8vDxEhNjYWPr168eLL77IjBkzyM/Pp1WrVnXalbo2eLN8msAPiQWrYzYsfCLQSSkVJw5WZmngLhZvbXuLSasnYRUrTcKbIFYJqJtESUkJ33zzDQCbN29m9erVbNiwgeLiYkJDQ+nZsyevvvoqqamp9OjRoyqzfVpaWr1wpa4VZmbU2JZPM7Atn97polwcbjbvG++nYYt7FOesrYa0alRaUSq9s3oL6Uj478Nl/p75Ad1If+bMGcnKypKYmBgx8lwIIJ07d5aJEyfKkiVLpLCw0O9yBBCPV41cxj6tFg3bnsWto4j43b2wocQ+/feJf5P6YSqXyi7R6fpObB69mVYxrfzaZ2lpKZs3byYnJ4ecnBx279591fWIiAiys7N55JFH/CpHEPE8tY4rLQGaGX8dXSzu8UbjPD0awhPhNzm/qXKTGP/5eL/1U1lZKbt375Y333xTUlJSJDIyUgAJDw+Xvn37yhtvvCEbN26Ujh071iaMYn3Ct0+EYFKfnwiFVwrpndWbfef20SS8CUuGLSGlk6scjJ5RUlLCpk2bOHbsGBs2bCA3N5ezZ88CcNttt5GamkpKSgp9+vQhJibmqnoNfqxvw+MnQkATjjcGlh1YxlOLn+KK5Qr33nQv659Z75MtlMXFxWzYsIHly5eTmZlJRUUFAC1atCAlJYXU1FSSk5Np06aN0zbsYRQ112I24XhHswnHGytWq5XRS0cz7z/zUCjS70/ntb6ved1eZWUlu3btqhrnb9myhYqKCiIiIqisrAQgMjKS5cuXk5TUaE0yPsN0wnExGbXCnYuFsXwaZ7Rf710sSipKWPfdOsYvG8+JSye4PvJ6cn+e69V+gWPHjrF69WpycnJYs2YNBQW2HO933XUXL730EikpKSQmJtKjRw/OnDlDy5YtA5JWqTHgdmgktgh3R+3vlVJ3ixM3C3cuFoYhLU5EFiuldla/Xt8oqSih/V/ac770PAD9O/RnxcgVRIRFuK9bUkJeXh6nT59m48aNrF69moMHDwLQunVrBg0aRGpqKgMGDKBly5ZX1d27d29jGesHDLNzhFjgFWx7EQqACU6KunSxMBTkiPHUmFpDP/XGsmyxWkj9ILVKCSJCIngz9U1TSrB//366d+9OUZFtE36TJk3o27cvv/jFL0hJSaFr164uk2vrsb7vMWtZTsNmAEvwsP3C6idEpNCInzqNak8EqSeW5X1n9tF3Xl/yS/MJVaFEhESYshDn5+czdepU3n777arJbmRkJGvWrOHHP/5xACTXOMNsopCj2JRgCDYrszPcbd6f7BDhLqF6opD6wNRNU7lrzl3kl+Yz4vYRFL5cyObnNrN34l6nG+lLS0uZNm0anTp1YsaMGQwbNox27doRHR1N69atufPOOwP8KTTXYNbggG1YNAAY66JMHC5cLLApUyI2hZrsqr+6ZlArLiuWHpk9hHTkuj9cJ//8+p9u61gsFvn73/8ubdq0EUAGDhwoe/bssbVXu9RIGtd4bFAzqwQdvGm8NkddUoR1R9dJ9BvRQjrS5Z0ucq74nMvyVqtVli5dKt26dRNAevToIevWrQuMsBoRL+43s0OjqlSxxsS50fDiihfpN68fxRXF/OrHv+KbX3xD8ybNnZbftm0b999/P48++ihlZWUsXLiQbdu20bdv38AJrfEYs5PlHkqpT4ALNJK8CGeLztL7H735Nv9bYiJiWDFyBb3a9XJa/sCBA/z2t78lOzubG2+8kZkzZzJ27FjCw8MDKLXGa8w8NmhkcY0WfbVIIv4QIaQj9829T4rLnI/jDx8+LIMHD5bQ0FCJjo6W9PR0uXTpUgCl1dSAx/ebx/kRxI2F2URY+AT7ISLTPVFaf2O1WhnxzxF88tUnhKgQpidP59f/82un5VetWsXAgQOxWq3Exsaye/duOnbUYZ/qI2bnCKYwLMt5YuRRrqHIUKDQUJBhdWn5dO/ZvbT8c0s++eoTWjRpwZ7xe5wqgYiQkZHBww8/bH9KUllZyYULFwIpssaH+FQRsFmWC+1vqu9NFpFMcbAtSB0JBvzWtre4M+NOzpecJyosikMvHKLbjd1qLFteXs748eOZOHEiKSkptG/fnujo6DqfGknjGp/HNapGYU0njSfHUzWcD6iLRbmlnAc/epB1362rOheiQjhUcKhGp7kzZ84wePBgtmzZwpQpU/jDH/5AWVmZ9vtpAPhaEcxu3q/R2U4C6GKx4787SPkghYtlF2nfrD0iQn5pvtOs9Dt37uTxxx8nPz+fBQsWMGyYLW2r9vtpIHgzw3Z24N6ynAzsxBYyfqertvy5avS7Nb+r2kI55rMxUllZ6XIz/UcffSSRkZHSrl072bVrl9/k0vgMz+9dbyoF4vCHIlwsvSh3ZdwlpCNR/xcly79d7rK8xWKRyZMnCyC9e/eWM2fO+FwmjV/wz/JpQ2DFwRUMXjiYUkspd7W8i43PbCQ20rmR/MKFC4wYMYKVK1cyYcIE3nrrLSIi3LtYa+onvl41qnNYrVbGLB3DwPkDuWK5wu96/47d43e7VIJvvvmGnj17smbNGubMmcOsWbO0EjRwGvQT4cSlE/TK6sWxi8dodl0zcn+eS1Jr1/t7ly1bxogRI4iKimLt2rX06uXcrULTcGiwT4T3//M+Hf/akWMXj9G/Q3/O/u9Zl0pQXFzM888/z6OPPkrnzp3Jy8vTStCICHh+BGP5NEVEXvZ13yUVJew7u4/09emsOLSCUBXKzIEzmdh9ost6R48e5Y477qC4uJiYmBhWrVpF8+bOPUw1DQ+fKoKZ/Ahii4Y9zpf9gk0Jbn37Vk4UnQCgVXQrNo/eTKf4Tk7rXL58mRkzZjB9+nRKSmzRqa1WK8ePH9eK0MgIqIuFO5RSaUqpPKVU3rlz5zzq+JU1r1QpQagK5fPhnztVgoqKCmbNmsUtt9xCeno6DzzwAG3btiU6OrpBhz7XOCcoLhbOEC8syyXlJSR/kMzWE1sBuC70Oto0bcNtN14b70dEWLx4Ma+88gqHDh2iT58+LF26lJ49ezamcIiaGgi4i4Uv2Xx8Mw999BBF5UXcesOtrBy5koIrBTVmpF+/fj0vv/wy27dvp1u3bixbtoyBAwdWhU3RrhKNHG+scM4OTORH4Ac3i0RXbbmzLP9q5a+EdIR05IXlLzgtt2fPHhk4cKAA0rZtW8nKyhKLxeLGMKmp5zR8F4tzxeekyztdhHQk+o1o2XB0Q43ljh07Js8884wopaRZs2Yybdo0KSkpMf9VauozDdvF4tNvPmV49nDKKsvo2aYna3++liYRPwyBSkpK2L59O5999hkZGRkATJo0iSlTphAfH++sWY2mfjwRKisrZfji4UI6EvJ6iEzdNPUq9b9w4YLMnz9fmjZtWpUWacSIEfLdd9/V7ndFU19peE+EoxeO0vu93vz38n+5IeoG1o9az63xt7Jp06aqkOnbt2/HarVW1YmKimLSpEkNMg2qxj/4PGOOCcuyy+t2kpKS5LmM5/jlil9SKZX0uakPj5c+zro161i3bh1FRUWEhITQo0cPUlNT6d27N2lpaZw9e5aWLVuyd+9evQzaePE4Y45PFaG6ZVlExnly3ZGom6PkypgrKFHEbYzjwjrbxvhOnTpVZYjp168fcXFxVXW0LUBj4J/UUR7gMiy8ietVXKm8AoUQsyiG/j/uT8rsFFJSUkhIcB6QW9sCNN4SbMvyVdcdN+8TCiyAy/mXj2ZnZxdkZ2eb7bM5cN4jKWtfN9D1gtFnfZJ1n4jc7kmFQFuWXV4XBxcLpVSenBKPk4MppfJEPK9Xm7qBrheMPuubrJ7W8bXTXSaQZDjb5dhPGr/0Tq9rNMHGp08EsQXsyjTeOqaMynR1XaMJNnV5h1qm+yI+rReMPrWsdaSez+0IGk19pC4/ETQmUUolOP71c19xdSl4s6+oEy4W3lqja7M/2tvw9d4mVDdjUTcWERJqkMedrHOUUkewZSo1Vc+4PgTbEnaiB58xGRhn7OOIx5ZTb5fJuvbvp6D65zTxGdOwzSur3wMu98Cb9WQIunMdMNn4RwDMMXvdXT2Hcou86DPN+EeBbe9EnMl6VfswcAhp6YGsk3HYu2Gyzzhq2Nthol4yDkkePajnWDbZiz7t389kD+pNs/fl5P95zTlPvncR8znU/Im7fc7Ortdmf7S34evd1dsF5NaQUN2trMYvW66nshokKaWSHZapzdRLceg32Ww9+/eilBoitjwYpmU1yk9RSi3i6s/qyf8yzoMhoOl264IiVKfQy+vu6nncprPw9a7qGUqTCwwzW89wNalJCdzWFZFCQ3FzPZQ1Djhi1EtxMe6vXs9Od3eCVq9r3MAvY7MhTaupgpM+pwKJxo1cm40l1dutoi4ogrfW6Nrsj/Y2fL23CdXN9pcM3FvtpnTXZ5qTm9hdnztrqOOJrN7UHSIiuWKzKy1y+GV353FQCGQaT9wjHvyvzd8jrsZNgThwH0re2XWX9RzGpNfsjzbRZ43h603UqzGhuhlZHcbCczDmJB58P4k4jPk9+H4me1lvSHXZPfh+hnj6vzTqpdnrVuvzmv+xu3unpkPbETQa6sbQSKMJOloRNBq0Img0gFaEOolSKsdTdwnDjjDHXzI1dLQi1E1M7dUw/H7SoMpYpYM3eUmd8DXSeIdcvb9DUwu0ItQRHBzLcjGCHBjDI7uTWqHYcktMxuZ8VoDtCfCUXBst5Jp6Afsg9RQ9NKo7TAMW262nDueOGMdTdiUwbuwkF8Ohq+r5XfIGgFaEuk+eiOwyfvW7YyiJ4aZgtp7GDdqyXEcwhkZDgTxsv+h2D81xwCdGsUJsQ54j2IY9AO8CA7C5IdhfxzvWEyexozQ/oBVBo0EPjTQaQCuCRgNoRdBoAK0IGg2gFUGjAbQiaDSAVgSNBtCKoNEAWhE0GgD+P3SsiRiznWaPAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1)\n", + "_, df_train = plot.calibration_regression(\n", + " mean_vi_train, sigma_vi_train, y_train, \"Train\", \"black\", ax\n", + ")\n", + "ax.set_title(f\"Train {errors.ace(df_train):.2f}\")\n", + "k = jnp.arange(0, 1.1, 0.1)\n", + "ax.plot(k,k,label='Ideal',color='Green')\n", + "savefig('Calibration_VI.pdf')" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1093,7 +1211,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.13" }, "vscode": { "interpreter": { diff --git a/notebooks/regression/sin_data_comparison.ipynb b/notebooks/regression/sin_data_comparison.ipynb index cb7b455..7bc1ada 100644 --- a/notebooks/regression/sin_data_comparison.ipynb +++ b/notebooks/regression/sin_data_comparison.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -33,12 +33,12 @@ "from utilities.fits import fit\n", "from utilities.gmm import gmm_mean_var\n", "from utilities.predict import predict\n", - "from utilities import plot,errors\n" + "from utilities import plot,errors" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -68,12 +68,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAL8AAACWCAYAAACPSVn4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAS50lEQVR4nO2dTWwbR5bH/yWGhE3ZASHKB19ELX3zaZfRLUdRQDDHXdlAwCADGwk18kXAAisjo7MmsG5CgJkRB5ARoBuDWNprgIUYYC9zWITi3nILR/QiCLCWZC7GkQ177LeHZkskVdVd1exu9kf9gIKa/VFdIv9V/arq1WtGRNBo0sjUpAug0UwKLX5NatHi16QWLX5NatHi16QWLX5Naoml+D/66CMCoJNOsolLLMV/fHw86SJoEkAsxa/RAIBpAvPzwNSU9dc01a5/L4hCaTRBY5pAvQ6cnVmfu13rMwDUanJ5hN7yM8YKYd9Tkzw2Ni6Eb3N2Zu2XJVTxM8aqAPYcjtcZYxXG2HKIxdLEkKdP1fbzCFX8RNQE0OMdY4ytA2gRURvAUpjl0sSPuTm1/Tyi1OG9hYuK0WOMVSZYFk3E2dwE8vnhffm8tV+WKIl/kMKkC6CJNrUa0GgApRLAmPW30ZDv7ALREv8hBkTfN3/O6fcHWoyx1rNnz8IumyaivHgBEFmjPWtrasOdk+jwVmyThjFWYIztAAARNQBU++ccjF5LRA0iWiCihRs3boRZbE0EMU3g3j3g5ORi38kJcP++fAVgcVzJtbCwQK1Wa9LF0EyQ+XmrtedRKgFHR0O7GO+8KJk9Go00IuED8sOdWvya2OFm1sgOd2rxa2KFbes78atfyeWlxa+JFRsbwJs3zud8+61cXlr8mlghY89rm1+TSGTseW3zawJnXH96L2xuAtms+LiSiwMRxS598MEHpJkshkGUzxNZ86tWyuet/WHce3p6+N4AUakkvD9XR7rljwGmaWJ2dhaMsfM0OzsLM4ymllse4NNP+f70n3wSzlNgdG7WbvFVfHsm3op7SWlq+Q3DoGw2y12Ync1myQijqR0qD1E2e7nVHU2MEa2uBnP/qSn+PUsl4WVcHU1cyF5SmsRfKpUcIxOUHH7xYMrjLvzBVCz6Zwq5VTzGhJdydaTNnghjmia6TvP4gOtxf8px0bFVvd3JibW21g8zyG2MX2UhCwDd8kcVwzAon8+7xqTJZDJD15RKJWKMUalU8sUk4nVsvSQ/HlBu93D4d7k6mriQvaQ0iN/N3BmtAIuLi5cqSz6fH7sCqJo5bhWAMcdRGUcyGXHexaLjpVr8cUJW+G5pnD6BYfgn/NHkZVjUY6tPJNCRtvkjiJ9DmE9VwhkMleEiDk4QqIYZMU1ruSKPYlFxiNNGVCuinJLc8huG4VurjzFafj/NHVFyGJ1RKo/EE0S3/FHHNE3cc/PXVSCXy2FTJZzBQDm6XfKtHCJURmcCGdQS1Yoop6S1/PYoDXxs8QFQ0aUXyGN1dZWArwh4J92CF4tWCsrmd+t7SDzcuDqauJC9pCSJX3ZI02tSLYt13Rsl08Uw5Gd+7cqi0tl1M8EkzCct/igSRItvJ8aY0lDnRVnkW/3B7A1D7gmQy6mJnzHn/GLR8gOoA6gAWBYcP4QVy7PqlE+SxM8YkxLy7du3zyewisWidAVQ6fBelEWu5RdlLTMxpmKRObX8khVpsuIHsA6g0t/eEZxTkckrSeKXbfm9XKNq+lxUKjmb30l0huE8KWVXHjfhityXAaIrV6SfIBMX/w6Acn/7EU/o/QpSBVB3yitJ4jcMg3K5nKN4B10YiOSfFnaSNX2uXbs2cN1XBPxdWAlkWm43c2UwL14R3Z4gCg+1SIl/x6mVB3DA2VcH0ALQmpubk/6v48D09HSgLb+s6TN83ccEvOCKTnakRmUEiJen2/UK8wQTF3/dyezpH7crx6FTXklq+YlGRecuXsMwKJPJ+N76D1/zV67gMhn5zuq1a/LiH23JZVwrYtPyW9/tuVmz3P9csCtCf7sKYBmCDrGdkiR+txldkXOaYRhSTwy3fOy8Lj9N3o7V2nr1C7KLKPPUUBgxmrz4/UpJEr+TCSMzVKkyQcYzf8TzDH/lCk52lMare0Q+T3T7tty5CmjxR5FxTBUvE2SjiCvOxwS8vCQ42TF62c6u1zQ9rfY189LEhewlJUX8lisBX6QyrgleJsiKxSKtrq5KzhX8L1d4MrZ20I5xip4bWvxRY2pqynOrT6Q+5KmevNv9hhFs66/iEUoCHWmvzglhmibevXsnPF6TcFCfU160qgp/LYDMbWs14De/Efvgj4sv/7qoVkQ5JaHldzJZRie1RATtFMcb61ddgWUYw8sX/Wj1PawC4+po4kL2kpIgfieTZVUh4E1Q7tCDFSCT+R8C3nleezuIquvzaPIYCkWLP0qIOpy5XM5TfkFVAL+DYqm4Po+mxUXPt+XqSNv8E+KXX37h7r9+/bqn/DY3N5EffTHtmBSLRam+hwq1GvDZZ+rXra4CzaavRdHinwQPHjzAq1evuMdOT0895Vmr1dBoNFAqlcYp2jn5fB7b29u+5DWK7MsjbAwD+P3vAyiI6JEQ5RR3s8fJL8eP8IMynqJOKZPJBBoDVGUI1MNKTB7a7IkKb9++FR7zsuB8lFqtht3dXU9PAcYYvv76a9/NnUFkhylzOSCghw8AbfZMhEwmE/g9arUajo6OQEQoFovS1xFRoMIHrFDibuP/xSKwu+sxHo8kUuJnjH3OGPvH4IqRLuoO0aA2VCI5SbK9vS3dGfarz+CEaAIsn7fseyLg+DhY4QOQs/kB/AOAzwE8AfAfAP5Z5rqgUpxtfsMwXP1qgrrv4BpgnmtFLpcLNd7/6ARYgLfm61p0YOgkYB7A+/3tfwKwCOAzmWuDSHEVv8yMrOzsrh9lGayExWIx9BddhAhXR4wsQTvCGPs3AGUAMwA6RPQFY+xfiOjfx3vueGNhYYFardYkbj0W8/PzUvH0ZX4TjRLcHsZ7khc3YYn+/wb29cYtUdrodj8E8J8A5mA5jf0WwJ+HzgnD5tZYSImfiP6bs+87/4uTXKwow38Ckd3xnAfwp/62VQG8xtbUeEMPdYbExgYGhG8zDeB3ACxXgt3d3cCHGTUXaPGHhDhM/hxKpRK2t7e18ENGiz8kxLOaT9HtdnH//v2JvVc3rYQqfsZYnTFWYYwtezkeZzY3rUmcYX6B1ekFXr9+jbW1tbCLlWpCEz9jbB1Ai4jaAJZUj8edWg1oNABrMOcdgCNY84YXoz0nJycTKVtaCbPlv4WL4dEeY6yicrz/VGgxxlrPnj0LtKBBYJpWp9ey/fnDnJpwmZTNX1A9TkQNIlogooUbN24EUqigsF/u1u1afivWMKcB4Kuh81Qc0DTjE6b4DzEg6r55o3I8tmxsWG8fHGYKwAMAHwMAstlsYItHNHxCEz8RNQBUGWNVAAcAwBgrMMZ2RMeTgniYcwrA71AqlfD48WM91BkyUr49USNuvj3XrwMvXoiOvoNh/FkLP1i4vj16nD9gTNNJ+BZra/8VTmE0Q2jxB4z70P0UTk7+NYyiaEbQ4g8YuaH7oMMOanho8UeAYvHSUJAmBLT4A8Zt6D6X+zu2t6+FUxjNEKkSv2kC8/PA1JT1Nww/su1tKwQHj1IJ2N19L/iF2ho+ovWNUU4qa3jtRdK8gEgeov16YvTN5B6DrWq8w9VRolv+QbcCHmdnwCefAA8eBF+Wly8vtk9OrHJpD+YJI6oVUU6yLb9KPHiFt3lLM9rijyYfIhNq5ODqSHYBeywRuxVc5tUr6ynwl7/4ExTVNIFf/xpwiEzoWr5er4fj42O8efNm/AKlgGw2i9nZWRQKBanzEyl+233Yi+fGH/8IfPjh+NHC1tachQ+4x6z8+eefMT8/jytXroAF9X6fhEBEePXqFY6OjqTFnzib383Od4PIarEfPBhvZMhtciuft1Z3uXH16lUtfAkYY7h69arSNYkTP999WI23b4E//OHC/77b9beDmslYq7qiOMS5tbWF/f197O/v49atW2g2m1hZWUGn03G9ttPpYGtry/W8druNW7duYX9/H81mE41GA+22swe7zP2VEXUGopxEHV6vr7yXTSod1OlpcT6yHesffvhB+n6DsThLpZLn0IN7e3tERPT8+XOqVqvn24eHh57yE2HnbbO8vEzPnz/nnvv8+XNaX1+XylfwnXF1lJiW3zSBe/eCvYdKB1pELud/i2+aJur1OrrdLogI3W4X9XrdUzSI5eXLsQMKhQI6nQ7u3LmDra0ttNtt9Ho9NBoN7O/vn7fa7XYbKysrAIBms4mlpSW02208fPjQ9b5LS0t48uQJN99Wq4V2u41m/71EvHO8kAjxmybw6adA0IMisi9VME1A8MotvH7t//j+xsYGzkZsvbOzM1/DnduVYn19HZXKxfLqcrmML7/8EgCG9lerVZyenqJSqaBYLCqLdDDfarWKcrmMarUqPMcLsRe/3eI7vM/ZF2Q7qIC7G7PfIfifCh5Jov1emZmZOd9uNBool8vngufZ5OVyWTrvg4MDVKtVx3ztbZl7yxB78W9sqLX4i4t2+BB5ikW1DqrbSI/PmhS+id3rG9p7vR6ePHmCTqdzbmo0m010Op1zoZXLZXQ6Hezv72NmZga9Xg/tdhutVut82z7/+++/x+DKO/vYYId3ZWUF5XKZmy9wYXqJ7u0JUWcgymmww6vycjP7FbeGYfn1yF6n6gPkV8dZtsPLi/ufz+eTHG9fiEqHd+JC9pIGxa/iwsDYxbfh5PAW5EiPSkWaxGhP3Ims+AHUAVQALAuOHwLYA1B1ymdQ/Cpv9BYJ2DCIpqbcr5fR0+qq+HpVb04V8WssIil+AOsAKv3tHcE5FZm8Rsf5DcO5tZVpcWXMJ7c8DEOcj5f3yWrxqxPVcX63cIVAP24PY0z8ukIBVt3hIzOjKtM3PDuzXB9EQ5Vra+Jy6DCc0cN3xzaBcJ+MfC7wriWirX4eBwAanHzrwOVRDCeXhnxebqRmc9OaK3AbMn371nJ1AIbzNE1ngYfw6l2NKqJHgt8JfXufBGZP/3i5v33olNeo2eNksqjY2NeuuZs+dspkLvKW7TOoMgmz59GjR7S3t0d7e3tULpfp4OCA6vU6/fjjj67X8s45PDykcrlMe3t7dHBwQDs7O66uEjL3EhFJm5/o3O6vot/hhfUE2BnYrgJYhqBDbKdR8YtGbVTsbC9+Qfk80eKiXH/By8KVSYjfq2+Pk/9NUH48PKJq84OItoioSUT7/c89IloZ2G4S0b59XJbNTSCbvbz/b3+TdyXw8l6IszPgu++c+xs2Qb9nzq/F+SLfnkKhcO5PY09OtdttNBoN9Hq9S/43TkzCj4dH7Gd4Acv2fv/9y/tfv5Z3JQiyQzo9Haz78mgIdL9dsAHg4cOHWFhYQKVSwenpKdrtNr755hssLCygUCgI/W9kCMOPh0cixA8Ap6f8/X67EnjhypVg8+d1+M/O/PUhmpmZQaFQQLlcxt27d7G8vIwvvvgCOzs7XP8bJybhx8MjMeIXDVXKurcE+V4IUcX0C1EF91rxeb49jx49Ojd3bL8a283Zdngb9L+xiYwfDw9RZyDKibeYheevo+JKYBiW78/g9aOfvSavURpkO7yiDn8ao0NEtsMbJIMvfGPM+qviiVmrAbu7w9fv7o7/RFBxhfYK702PYdw39ohqRZSTSsS2cRlnaeS4kdnUHNuslp4x629K/drS2fIHRa2m3vpfuwYYBnB8HN4i9VoNODqyZqiPjqK5OD5qaPFLcPeu/Lmrq9b8ghZf9NHil+Dbb+XP9SPa2yAvX74EkcQsWsohIrwcDIgqQSIjtvmN7JCh38OlN2/exE8//aTDFUqSzWZx8+ZN6fO1+CWYm3OPAJfNWrH4/cR2K9AEgzZ7JBD5DtmUSsDjx9rOjxta/BLUapa4B82aYtEa0SHSoytxRb+EWpMGuJF+Yyl+xtgzADwrfBbAccjF8UpcypqEch4T0UejO2MpfhGMsRYRLUy6HDLEpaxJLqe2+TWpRYtfk1qSJv6G+ymRIS5lTWw5E2Xza9IDY6xARL1x8khay69JAYyxKqywlqLjdcZYhTF2eTX+AIkRv9s/zBg7ZIzt9b+4qJVN6scKmih/h4MQURMX0f+GYIytA2gRURvAklM+iRC/5D/8ORHd6X9xoeFWNpUfK0ii/B0qIhMWE0BCxI+A44COiVvZpH+sgInyd+iVgtPBWHl1BhUHNEQKYx4PiwJvZ0S+QzcOMVD+/pOMS6zET0TcL5wx5vgP9ytNk4g6AGYQLm4/hvSPFTBR/g6H6Pc5KoyxChG1GWMFAI+IaIWIGoyxdcbYDIADx3ySMtTZt1nbAApEtD/4hfS3F9D/cUkxHGKQZeMdD7NssuWc9HcYBIkRv0ajSlI6vBqNMlr8mtSixa9JLVr8mtSixa9JLVr8mtSixa9JLbGa4dVcZmDyqQJrgmqJiB5OtFAxQU9yxRx7UQdjbI+I7ky6PHFCmz0xZ3Q1U/9JoJFAt/wxp++P0wFQhmX2tMZd3pcWtPg1qUWbPZrUosWvSS1a/JrUosWvSS1a/JrUosWvSS1a/JrUosWvSS3/D5d8IB+Ded7MAAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAL8AAACWCAYAAACPSVn4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASn0lEQVR4nO2dT2wbR5bGvyJDwqbkLGHKB19EDXPzaVemTrmtZSCY69IBAmYysJFlRr4I2IONrM4ar3QTAgxGHECGgW4MbOmcixhgL3NYmOHecltG9CIIsJZsLsaRDXnsN4dmS02quruq2f+7fkBBZFd3dZn+uvr1q9evGBFBocgiuag7oFBEhRK/IrMo8SsyixK/IrMo8SsyixK/IrMkUvyffPIJAVBFFdHCJZHiPzw8jLoLihSQSPErFACg68DCApDLGX91Xe74D4LolEIRNLoOtFrA8bHxfTAwvgNAsynWhhr5FYlkbe1M+CbHx8Z2UZT4FYnk2TO57TxCFT9jbJkxtuFQ3xrt0wizX4rkMT8vt51HqOInog6AGq+OMXYPQHe0z80w+6VIHuvrQKk0vq1UMraLEiezZwnA0PzCGFuMriuKuNNsAu02UK0CjBl/223xh10gXuKfZGj9MjKJuoyx7vPnzyPqkiJuvHoFEBnentVVOXdnnMT/FEDZ/EJEfWslEbWJqE5E9StXroTdN0XM0HXg9m3g6Ohs29ERcOeO+AUQ+gMvgJrVpGGMjbyzaAOoj+r2w+yXInmsrQFv357ffnIi7u5kSXyNsV6vU7fbjbobighhzLnu/fvxTbz94mT2KBRCuJk1ou5OJX5FojBtfSd+/WuxtpT4FYnCzta38u23Ym0p8SsShUj4gmiIgxK/IlGI2PPK5lcEzrTx9F5YXwcKBft6qRAHIkpcuX79OimiRdOISiUiY37VKKWSsT2Mc8/MjJ8bIKpWbc/P1ZEa+ROAruuYm5sDY+y0zM3NQQ9jqOX2B/jiC348/eefh3MXmJyeMkd8mdieyEdxLyVLI7+maVQoFLgvZhcKBdLCGGrH+kNUKJwfdScLY0QrK8GcP5fjn7NatT2Mq6PIheylZEn81WrVMTNB1eF/PJj+uAvfWioV/0whtwuPMdtDuTpSZk+M0XUdg8HAcR+3en/6cfZgK3u6oyPj3Vo/zCA3H7/MiywA1MgfVzRNo1Kp5JqTJp/Pjx1TrVaJMUbVatUXk4j3YOul+HGDcjuHwz+Xq6PIheylZEH8bubO5AVw48aNcxdLqVSa+gKQNXPcLgDGHL0yjuTz9m1XKo6HKvEnCVHhu5Vpngk0zT/hTxYvblGPoz6RjY6UzR9D/HRhPpNJZzDWh7M8OEEgm2ZE1+3DmCsVSRenid1VEeeS5pFf0zTfRn1MMfL7ae7YFQfvjFR/BO4gauSPO7qu47ZbvK4ExWIR6zLpDCz9GAzIt37YIeOdCcSpZXdVxLmkbeQ3vTTwccQHQBWXp0AeKysrBHxDwHvhEbxSMUpQNr/bs4fAzY2ro8iF7KWkSfyiLk2vRbYvxnFvpUwXTROf+TUvFpmHXTcTTMB8ir/4AdSsf+1KmsQfxIhvFsaYlKvzrC/io761eU0TuwMUi3LiZ8y5vUSM/ABaAJYBNDh1ZRhZG7azJH7GmJCQr127djqBValUhC8AmQfes76Ijfx2TYtMjMlYZE4jv+CFFK34AdwDsDj6vM2pL5v1biVN4hcd+b0cI2v6nF1UYja/k+g0zXlSyrx43IRrF74MEF24IHwHiVz8uxazZntS6CPxm3eGllNbaRK/pmlULBYdxWsNYSASv1uYRdT0mZ2dtRz3DQF/s70IREZuN3PF2havi253EImbWuzEb2vaANjnbGsB6ALozs/PC/+rk8DMzEygI7+o6TN+3GcEvOKKTtRTI+MB4rXpdrzEPEHk4ncze1oAymQjfmtJ08hPNCk6d/Fqmkb5fN730X/8mB+5gsvnxR9WZ2fFxT85kouEViRp5DfNmkVYHnhNE2dUv5hFs8dJtHbBaZqmCd0x3Nox2zp/N3k31WjrNS7I7KLIXUPCYxSt+P0saRK/kwkj4qqUmSDjmT/28ww/cgUn6qXxGh5RKhFduya2rwRK/HFkGlPFywTZJPYXzmcEvD4nOFEfvejDrtcyMyP3M/NK5EL2UtIifiOUgC9SkdAELxNklUqFVlZWBOcK/o8rPBFbO+jAOMnIDSX+uJHL5TyP+kTyLk/54t3u17RgR3+ZiFCy0ZGK6owIXdfxfiKPtpWmQID6vPRLq7Lw3wUQOW2zCfzud86pxKfBl3+63VUR55KGkd/JZJmc1LIj6KA4nq9f9g0sTRt/fdGPUd/DW2BcHUUuZC8lDeJ3MllWJBLeBBUObb0A8vn/JeC953dvrciGPk8Wj6lQlPjjhN0DZ7FY9NReUBeA30mxZEKfJ8uNG55Py9WRsvkj4pdffuFuv3Tpkqf21tfXUZpcmHZKKpWK0LOHDM0m8OWX8setrACdjq9dUeKPgrt37+LNmzfcuhcvXnhqs9lsot1uo1qtTtO1U0qlEra2tnxpaxLRxSNMNA34wx8C6IjdLSHOJelmj1Ncjh/pB0UiRZ1KPp8PNAeojAvUw5uYPJTZExfevXtnW+flhfNJms0mdnZ2PN0FGGN49OiR7+aOFVE3ZbEIBHTzAaDMnkjI5/OBn6PZbOLg4ABEhEqlInwcEQUqfMBIJe7m/69UgJ0dj/l4BBESP2PsXxhj/xhcN7JFyyEb1JpMJidBtra2hB+G/XpmcMJuAqxUMux7IuDwMFjhAxCz+QH8CsC/AngC4DGAfxY5LqiSZJtf0zTXuJqgzmt9B5gXWlEsFkPN9z85ARbgqfm6tqsY2wn40PL5nwDcAPClyLFBlKSKX2RGVnR214++WC/CSqUS+kIXIcLVESND0I4wxv4DwD8AuAygD+ABgCUi+m66+4436vU6dbvdKE49FQsLC0L59EX+TxRScJ8wPhA8+DGAPhH9PwAwxn7lV6+yxGDwMYD/BDAPI2js3wH8eWyfMGxuhYGQ+Inovye+/wjgx0B6lFKMLMN/ApH54LkA4E+jz8YF4DW3psIbytUZEmtrsAjfZAbA7wEYoQQ7OzuBuxkVZyjxh4R9mvx5VKtVbG1tKeGHjBJ/SNjPaj7DYDDAnTt3IltXN6uEKn7GWIsxtswYa3ipTzLr68Ykzji/wHjoBU5OTrC6uhp2tzJNaOJnjN0D0CWiDoCbsvVJp9kE2m3AcOa8B3AAY97wzNtzdHQUSd+ySpgj/xKAofmFMbYoUz+6K3QZY93nz58H2M1g0HXjodew/fluTkW4RGnzD2XqiahNRHUiql+5ciWwTgWBubjbYGDErRhuTg3AN2P7yQSgKaYnTPE/hZGSEABARH3J+sSytmasPjhODsBdAJ8BAAqFQmAvjyj4hCn+NoD6yJzZNzcyxlpO9WnA3s2ZA/B7VKtVPHz4ULk6Q0YotiduJC2259Il4NUru9r30LQ/K+EHCze2R/n5A0bXnYRvsLr6X+F0RjGGEn/AuLvuczg6+rcwuqKYQIk/YMRc90GnHVTwUOKPAZXKOVeQIgSU+APGzXVfLP4NW1uz4XRGMUamxK/rwMICkMsZf8OII9vaMlJw8KhWgZ2dD4J/UVvBx+79xjgXmXd4zZekeQmRPGT79cTkyuQek60qvMPVUapHfmtYAY/jY+Dzz4G7d4Pvy+vXZ5+Pjox+qQjmiLG7KuJcREd+mXzwEqt5CzM54k8WHzITKsTg6kj0BfZEYh9WcJ43b4y7wF/+4k9SVF0HfvtbwCEzoWv/hsMhDg8P8fbt2+k7lAEKhQLm5uZQLpeF9k+l+M3wYS+RG3/8I/Dxx9NnC1tddRY+4J6z8ueff8bCwgIuXLgAFtT6PimBiPDmzRscHBwIiz91Nr+bne8GkTFi3707nWfIbXKrVDLe7nLj4sWLoQp/c3MT7XYb7XYb169fR6fTwVdffYV+378g216vh48++gh7e3vY29vD5ubm1O0zxnDx4kW5g+zsoTgXJ5s/qCUwZT1DTm3l82Jt/fDDD+In9In9/X0iInr58iUtLy+ffv7+++9dj3358iVtb28Lncds26TRaDjuv7u7K9SuzW/G1VGqRn5d9z7iu3F8bJhSoszM2Nc9euR/ElZd17GwsIBcLoeFhQXPL8PX6/Vz28rlMmq1muux5XLZMQmvE0tLS9jb2+PW9ft97O/7H+WeGvHrOnD7drDnkHmAtqNYDEb4rVYLg8EARITBYIBWq+XpArCzl7vdLm7dujVmopgmy3A4BIBTE8n8fPPmTfR6Pdy/f1/ovHbt9vt99Pt9dCzrEk3u44VUiF/XgS++AIJ2ioguqqDrgM2SWzg58d+/v7a2huOJV8WOj499TXe+vLwMALh37x5qtRr6/T7K5TIajQYePHhwuo+5rJK5/+LiIiqVCnq9nmP7w+HQsd1arXbaJm8fLyRe/OaI77Cesy+IPqAC7mHMfqfgf2ZzS7Lb7pXLly+ffu73+9jd3T0dlXmIel0AYH9/H41Gw7Fd87vIuUVIvPjX1uRG/Bs3zPQh4lQqRtoRUXPFzdPjsyZtV2L3ukL7cDjEkydPxkyNTqeDfr8/ZorcvGlkmLl8+TJ6vR56vd7pPubnfr+Pp0+fwvrmnVm3t7eHTqeDzc1NbGxs2LZrPc5tHynsnoTjXKzeHpnFzcwlbjXN8N5E4emRmdkV9fbw8v6XSqU059u3RcbbE7mQxzoD1Kx/7YpV/DKuTcbOfg2ngLdpQxFmZvy5kGRcndaVV6rVaiaFTxRj8QNoAVgG0ODUlWFkbdiWEb/Mit52AtY0olzO/XgRPa2s2B8vG80ZhZ8/6cRS/ADuAVgcfd7m1JfNercyOcmlac6jrciIK2I+ubWhafbteFlPVolfnrhOcrmlKwSMvD3Lllw+whjXD5983v2BVeTZ8PjYCH2wc1Wurtr3Q6XhjB++B7bZZFjucLYNrV+IaAgjcRUYY/vmZ0u7LRhm0zkvBj8jmkGpJOapWV835grcXKbv3hmxQ8B4m7ruLPAQlt71hc3NzVMX5fb2NjY2NrC7u4v79+8LzfJO0uv1cOvWrTFvTqPR8NSW79jdEvwucDd7WgDKo8/7Tm1Nmj1OJouMjT0762768OJzRJ8ZZElabI9d/E1QcTw84mr2uKUrfAKgxhhbBrAr07CdyVKpiPvmRZJLWTHvAMvLwG9+437HSMo6c15je2Tib6KI4+ERmviJaEhGpuUeEe1Ztrct9T0i6pjbRFlfBwqF89v/+lfxUAIv60IcHwPffef8vGES9Dpzfr2cbzcr++LFC7Tb7dOJqU6ng16vh3a7jeFwyI2/cTpH2HE8PBI/wwsYo/uHH57ffnIiHkoQ5APpzIz/wWxWJlOgDwb+vyNs2vy1Wg27u7vo9Xp4/PgxarUayuXyufgbJ6KI4+GRCvEDwCie6hx+hxJ44cKFYNvnPfDLhmCLUK/Xsbi4iO3tbTQaDXz99dfY398fi68RibWJIo6HR2rEb2f3i4a3BLkuhN2F6Rd2F7jXC58X27OxsYEHDx6cxvDs7e2h2+1iaWlpLOBtMs4mNnE8POyehONceG9y8eJ1ZEIJNM2I/bEeP/nda/GapUHU22MXppHF7BBx9fYEinXBN8aMvzKRmM0msLMzfvzOzvR3BJlQaK/wVnoM47yJx+6qiHORydg2LZrmfcSfNjObXGCbMdIzZvzNaFxbNkf+oGg25Uf/2VlA04DDw2C9PFaaTeDgwJhvODgI77xJRolfgE8/Fd93ZcWYX1Diiz9K/AJ8+634vn5ke7Py+vVrEAnMomUcIsJra0JUAVKZsc1vRF2GfrtLr169ip9++kmlKxSkUCjg6tWrwvsr8QswP++eD6hQMHLx+0m5XJZ6CVwhhzJ7BLCLHTKpVoGHD5WdnzSU+AVoNg1xW82aSsXw6BAp70pSUYtQK7IAN9NvIsXPGHsOgGeFzwE4DLk7XklKX9PQz0Mi+mRyYyLFbwdjrEtE59/GiCFJ6Wua+6lsfkVmUeJXZJa0iV/q9ceISUpfU9vPVNn8CoUMaRv5bWGM1ax/FfKk7TdMjfgZY61RtrdzSbMYY2UA24yx7fB75tw3kfqwiPNvONGXZcbYhkO92O9pF+ifpAIf84BG0DfH+hj1M7Lf0Ka/u17+HdaSlpE/0DygU+LWN5G+h0Gcf0MZhH/PREV1BpUHNGSGU9aHxdD6JWa/oQxDu4pEiZ8smd6sMMaewrgtm/v1J+pbAJ6M/gPDxrFvAvVhEeffUAbh3zMtZk9geUBD6Bu3PgLi/BuOMepDzWrSePk9lZ9fkVnSMvIrFNIo8SsyixK/IrMo8SsyixK/IrMo8SsyixK/IrMkaoZXcZ5RtGUdwCKAPoAlIrofaacSgprkSjiMsTIRDRlju0R0K+r+JAll9iSfT0ej/wvgdOpfIYASfzqoAfgfxliDiNzXAlUAUGaPIsOokV+RWZT4FZlFiV+RWZT4FZlFiV+RWZT4FZlFiV+RWZT4FZnl77kUk+tqu+5PAAAAAElFTkSuQmCC", "text/plain": [ "
" ] @@ -118,13 +118,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 7, @@ -1089,6 +1082,127 @@ "savefig('MCMC.pdf')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VI" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from utilities.vi_helper import vi_model,vi_predict\n", + "\n", + "params = [[32,16,1],[nn.relu]*2]\n", + "mlp_model_vi, vi_model, results = vi_model(params,x_train, y_train.flatten())\n", + "\n", + "mean_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).mean(axis = 0)\n", + "sigma_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).std(axis = 0)\n", + "mean_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).mean(axis = 0)\n", + "sigma_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).std(axis = 0)\n", + "mean_vi_test = vi_predict(vi_model, results,mlp_model_vi,x_test).mean(axis = 0)\n", + "sigma_vi_test = vi_predict(vi_model, results,mlp_model_vi,x_test).std(axis = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/sindata/MLP_VI.pdf to figures/sindata/MLP_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/sindata/MLP_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAALsAAACeCAYAAABq8XvvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAuEElEQVR4nO2deXhb9ZnvPz8d7fIW24md2EnAWUmhSUyALhQoDWuxSjsJ9N5eprQdoKXpc7sBTzt0ep+ZTOdSbpneS4YW6BRaCtMS2k4tSloIhYZQOokTCGv2zYnj3fKifTn3j1eKZUeyJVmW5Vjf5/Ej65yjc16d89X7e3/v9lO6rlNEETMBhqkWoIgi8oUi2YuYMSiSvYgZgyLZi5gxmPFkV0rdrZS6Pfa3Sym1Vin1sFKqYYrkqVBKVSTZ3pD4Gvu/oGQveOi6PqP/gLWx1wrghYT/G9P4bAVwexbXvB1YC6xLsm8d8ELsbxfQGJcNeBhoyJHsZ1w78bzAfePsb0w8x+j3hfg34zU70DJ6g67rbuDweB/Udd2t6/ojmVxMKXU30KLr+lbgqiSH7NZ1/Spd168C7tF1fXds+z26rt+h63qiXFnJHtP8ya4dP8dWIOnooJRaC1TE5Eo8Zi1QOdZ1pxoznuwxciTDGqXU5pipEDch1sXeV8Ter1VKPZzw/wtKqUal1H1jXPIi4PQ1lVKNo+Q5HL9WjHSJ8qxVSt2ehuyVMdNmXewza2Ny3R6TvQFoiBE3I8Rk+pZSajOwNSbr2vj/hYwZT/ZUiBNN1/Xv67p+OEZ4N/AM8K2EYypHHb8b6BlN4jHgTrH9ogRZ3LquPxK7xvo0znkfot0Px45vBG4GDsfOtTX2f8YEjd2HexCz6j6lVGM255kKGKdagAJHb8L/DQhxNpNiiCcFcRM0/QvATsSuBoY1+ajj1456fzvw9BiaPBlaYsffESNoL6KRDyeMHg3Jrj8O1um6/v3Y5wHqlVKVyA9qkVKqIkM584aiZkc8IMBNJAztsdeGBM9IA0JWgN6YWdAYPybh/wZEK6+Jn1/X9Xtif1uBRxCTpDHhfHFCx1EBHEp4/3SCbJvHkx3RvN9KkHFdTJ6djPwBJx19Er57Y8K2uHzPxMyjRqBX1/Xm2PeqSnauQoKKzaSLKOKsR1GzFzFjUCR7ETMGRbIXMWNQJHsRMwZ5cz0mBjOQkPf383XtIoqA/Gr2mwC3ruvPADcnS3YqoojJxJS4HpVSu3RdvzDJ9tuRJClWLKy/8J1f/yT5CXxDUDkXKmtTX8Q7CDXngKM8FyKnhSuuuAKAl19+OW/XnAimXF5dh1MHIRwGsyX1ce5O6D4BtlJ57x2Aj6xXmV4u7xHUWCJU0pB3LKnqEYA1yxpS/wrNVujvhlk1oFJ8Z80oNyWPZP/Yxz6Wt2vlAlMur38I/B6wl6U+JhSAnpNgtk/4cnnV7LHI3GFIHiZPxJplDXrLUz9KfYBvEOqWgc2RfH80CiE/LFgBqjgPLzjoUTh5QLS7yZziGB1OHZZnbU14zoWu2WNEvw8JV1cCZ5gxmZ1Qg6He1GQ3GCASgWAALLYJXaqISYBnAAI+cIyh1T394HEPmy8TRN7IHsufmBjBE2G2wEAPVM4DTUt+jFIyTOaJ7Ndddx0AW7Zsycv1Joopkzcagd62sZ9LOAxdx+WYVKZqhpi+47tBk2HON5T6GJMZhvryJpLP58Pn8+XtehPFlMk72AfhEBhNqY/pPSkjszbGMRli+pIdZBI60DXGfhMEvBAJ5U+mIsZGJAR9p8A6xoTTOygOCGsKEzVLTG+ymywyWQkFk++PD3+B6aNtz3q4O2VENqQwPSNh6DgqHrccmS9xTG+yKwUomcikgmYce38R+UPQB+6usTV2b7sQ3pjCQzMBTP9KJbMF+juhvDq5JjBZhOzVdZPugrzhhhsm9fy5Rl7l1XXoaQOTKbXG9g2CuyNn3pfRmP5k10wySfV7k7shDQaZ/Qf9YJl4YGIsfPOb35zU8+caeZXXOyC2eCpXYyQM7UcnxXyJY3qbMXEoJT73lPsNY3ttiphcRCMS7h/L1dhzEqKTY77EcXaQ3WwTn3skknx/nlyQV1xxxel8k+mAvMnb3x2zw1O4EYfccowlt96X0Tg7yG4wSPjZO5B8v9EkZkwqr00Rk4egH/raU09KQ0HozG3wKBXODrIDaOaxfe4g0dQi8of4pFQzikJKtr/rOKDnNHiUCmcP2U1mscuD/uT7jWbJsygif/D0y2ibylbv75JjJtlxEMfZQ3alxp6omsziDYimsOuLyC0iYZmUpoqU+r2x/VnY6aEA/OzvX830Y9Pf9ZgIkw3c3VBRe+awqRSgS/rAJPlxb7rppkk572RhUuV1t4uZoiWhWCQM7UdktE0VSU2FoB86jwGMUbmTHGcX2TUNghEZOksqkuyPRVMniex33nnnpJx3sjBp8vo9onTsSe6zrkNXK4SDYCvJ7LxBP7TujXvdBjMV6+wxY+LQTBJRTQaTRdxcenRSLu31evF6vZNy7snApMgbjcikM1VwqL8LBnszN1+CPiF6NJw6pXscnF2aHUZOVM3Wkfvi0dSAL+cZdQDXX389MH1qUCdFXnenuBOTaXXfIHTF7PRM3IwBH5zYK89OM4oZlAXOPs2ulERM+7uT7zcYZKJaRO7h90BfR/K0jVBQ7HSzJTM73e+B1vekzDKZ/Z8Bzj6yg2j0ge7kGsBkEY9NsaFrbnHafLGcmXAXiUD7YbnnmaQD+IZEo6NnbbokonDJHo1kb1vHq5iG3Gfu04xS8hVK4Y8vIjv0nZLqI9Oolhi6Li7GgC8zf7p3AE7uk/8z9dikQOGSPRKWyUy2MFtkSE2mwRVFUyaX8PSL9yXZPKi/S0bZTOZIQ31wcr/8nyOiQ0FPUHWZzNjLz5xopgPNBMFBGQpHT5bMVvEIlM/OaT7GrbfemrNz5QM5kTcUFFei1X7mvfT0DweO0r3PAz1i2xsMyVMMJoACJjuilTuOQP2y7AovNJNo99Fk14wQGJRIXDY/pBSYcWSPRsVOV+rMyWPAJ3a62ZqedtZ1eVbdrXKuSUgKK1wzBsBoFM2crTljsoBvIEW+jMp5jnt3dzfd3Sm8QAWICcvr7kjeqiQcglOHwGBKL8ErbtdPItGh0DU7yJfvapUWaeYM+78kuiFn14/cZ7bAYI+U8+UI69atA6aPn31C8noGJCVgdDQ67nmJRtKbkEYj0HFMnoVmkvnUJKGwNTsIYXXEjsvGO2O2ycgQHuWGjOe4p8qSLCI1QgHoOibFFonmZVSXvBW/Nz2iR8LSAm+wR57H5KazTwOyg5gzfo8UAWSK+CQnWTakUqkLPopIjrgmVoaRdrquQ88J8PSl53kJ+uH4exJVHatZUg4xPcgOcmO727IrwDDb5IcSHTUyxL0yxQBTetCjYlKG/Gfa6X2nxIa3lo5vc/uGJCoaDmRM9OaWRdz2yPU89df3Z5zNl1eyx5YVH2up87E+LNKeOpy61jQVNE2GzNH9YzSjDMnFAFN66OuQApjRmruvUyqS0iF6f5dERfXswv+uliU4LEGady+vyPSzeZ2g6rq+VSl1R9YnMBiFmN2tstBAJjBZRPuUVIx8IMoAQ/1QOfHmp1/60pcmfI58IiN5B3pkdLSPIrS7U56HtWRsv7gelbiJu2NCHpemNQdw7VyEs3GvG1Zm9Nm8r7yhlNqs63rSxQgSsaZhnt5y3/9MvjMcgnlLkuesjwXfoHwu0e8ejUhgZMF5xT7uqeAdkBHV5hjpM+/vkgmptWRsX3o4KK5I31Bu7PNIGALePXz2n1dl8rGCerpKqduVUi1KqZaugTHyrA0GcW+FApldQDOdOck1aJIjnYN+kK2trbS2tk74PPlCWvL6vdK8yGrPjui+ITj+Lvg8eZuIpkJB+dlHLDPTMC/1kGPQxJV46hDMX56+RjZZJCfG5xmZhmrQsisoGIVbbrkFmD5+9nHlDQViUVDzSPva3SmR07GIruuSE9N5TEwWY26o1tyyiE3PrcYb0Bq2fzazz+Z9ggo0KKUaJ3yyuDuyuy0TAeShuTtGbjdbpWd4lkUBZyXCISG6wTCclqvr0HsqlgszBtGjEenE23FUjslRMldzyyL+5fHzueTwr2l3OzIeJvI+QSWXq29oJpl02kqhJM2FwsxWSf0N+IbdZ8oA6GLTl8zKmXjTFvGC6Gh0+B5FYyH9/ph7MdVkNOiXETfgzZ3Z4g+w61c9lG/dyiuBlzAS4RHjlzL+BRWUGZMxFDH7/SAseF96SV1KiSuyrwNqzxnebrJIWsFMJ3s0IqZHKDBs6kVi2zxusJWl9qQM9cYi3frEiR6OwJ79sP0NaHmXCwMhThrqeNjyFfYsuIobavf1w5qMTjm9yQ4J9vsBqF+RXkWL2SYPJlg7/AOJ95VJVrs6UxCPjvo9w5X/8aSugE9Ml2REj0TE/djfJc8j26qiaBT2H4dXXoe/vgWDXiixw2WNfP3Q13jefwXzqjxsuGYnT22ry9gEn/5kB7HfA37J16g5d3wfbjxBrK8DahYmbDdI4UDl3KzE+MY3vpHV56YKI+SNhKXnon9omOhBP7QdlB9BqrYXvqFhz1i22vx4u2jwV9+ALjeYTbBmBXxkFaxcAkYjV7RoDLa007TmAM7VB3Ge/2Zrpn72s4PsIDd6oEc8KhU14x9vsUsC0qyaYU1usYkHoWJOVpOqpqamjD8zlTgtb3xpl4BvmNS+QWiLTVAtdpq3VeP68xyaLu/EeVm3BIl6TsmE1aAyJ3pXH7y6R0h+vF2us3IJ3HwNXLwCrCPL+5xrDgESQSUawXn+mxl/37OH7CCels5WIa99nAmrUrKWau8pqD1XthkMMpT6hrJaGXvfPqmZXLZsWfIDQsHYArb2Ee7SZtezbHroIbq6upg9ezYbvvIVnJ+4MePrZyVvOMSycvNIG72/B7qO0fzXBWzafC6dfSa63RZWLx9g0y8XsOmpOgj62HBVP86LtfSjoV4/vPYmbNsN7x2VbcsWwuc/AR+6AMrGbpoUTxVw7VpaJLvkzygZehesGD//3WITs8VfM9yT0GSWdX+yIPsdd0gmRFK/tR6VqquBHiG6UjKRU+B66mfsfedt2rt7qa3uwPUfP8e59sqxF8TNAe64/TYI+nn5l48J0aNRyXFxd4DFgWv7XPYec9DWZUEBf9pRxdI6Nz1uKz1Ds3i39Xp2HHmLDneJmBcx7TsCkQjsOSAE3/kuhMJQNxs+fTVcugrmVI4r572/vJTNry1n5UJxGTdduDer73t2kR3E/IjnSS84b+xKmbjfveckzFss7+Mr8AV9mReLjAW/V/zTQ32gDDRv34nr1RZqZpVz7Ngxhjweyuw2IqEQTauXw7G3oW4plFVNTuWO3ytmi1Lyow8FaX7Gg+tP9TR91IHz8m5qqvy091iIREWHaCrMO0fL0GPhGbfXwuMvXcC6D+zD1bJkJNmPtsGfd4uZ0j8EpXb42MVweSM01KX9nZpbFvHglgupdPjYc6yGff/3J7F0gcy/8tlHdhjOZjx5QOpXx7K/TVYxLRILsw1GiahW1eVOpt5TMh/Y9Q489wJVO9/mb3wBguEIHy9zEFVgN2u0GxU/euDH3PfEM3R7/Kz/5CfZ+IN/lR9IczMul4umpiacTmf2sgy5hyOb8US4zqO4XlqFo0Th2jYHFGzeOpcyRwifX2HVwvjDJiymKFFdR0dRWeKnssTHuyeqAPjDn6u5dvC39G95m/KuY0QNGoY1y4Xgq5dlFUV1tSxh6dwe9p+q4ivX7cr+O3O2kh1kwuT3iN93bkPqlAKlwGgR11n9eaLCzFaxWytqJtyFCohV4B+HnW/C3f8bn1GjLBSmBDAohcnjY+6sMsptFryt7azV4c7OPn5T6uCxX/yCjfd8DarrcblcOBwOXC5XdmTXo2Ki9bXFqowUhAI0b/aw6bcfptMtkdL+ISM/f24uFi2CSQvz4SXtbLj+dXYcnMtjL13AnHIvS2p72frWQoYG7azseIk7LT/hwoe2gR6mp+x8nl9xN69WXc8Pv/LXCd26moohBnz1fOW6XWz89PYJnevsJTsI4Yf6JPJXPT/10Bn3scdrUg2xiOqQOzc1qkO90N+N+wePMmQyckE4zECsg3ZtRQmfu2wVHf0emhqX8dZ7R7jsxV08Ggwx2+9ny+xKmp/6OZt++wc6+4eYU1PLhg0bMpch3i/d45aIczAgUc5oFNd/LaCtx4aOTjhswD1oIhIBf1TjA4tP8Pw/PAOIR2Tjp7eDrvNP9y9kbeBBPhXeTDU9nPLP5YmyO1i4bhFD1fN5vmUJTWsOTOi2NbcsYvNryymz+dn82nIuXnwq+bwgTZzdZAchfF+H2O5j+c8tdrHdHRUy3FpsMlErrUy7f8m99957xrbm3/0O1y/+nc+W27m0vZsfVpTgG9JR0SjlpQ5+9I/fxvXSqzjmzcF1dJBHv/v38IWDtH3rAb53oJXlQ0P83f/6Pj2DXhx2GxetXp25Vg/4ZHIciYClRCLF3Se5svECfr7lJG8e2oHJeAMmTWdhdS9dPRXYTDqL57rZcP3rw+dxD0rA5+VdfKe1gwAWfq/dwH+Y/gct9o9gt0X50NE2Hr32DxMiZRybtjTiC2ic6JnNlecfO3NekCHSIrtS6m+AQ7quv5H1laYSmlG0mmaUxkhJj9EgGJUU4Nn1sd4yPvD2p51CsHbt2jO2uf7zNziUjv3ZrXjMRp4A5s2tZfHSJWz40hdxfvzjUNuA6/e/p+nGdbBoFRjNzLv/bl794nf52/ZeuoF7FHh9fmpsGrfd8hma1t+E0/mJsQXSdRjohd4Tscb/Bok0+4do/q+3eOR3f6S9x4bN8m9EIk1cvPwY295bwGXnnaC2wsujX/wDBILwl73iTXljv3hsliyA2z6J5YMX8KkSO8aWMP1b+ukcsLPzUA1X/9N6Nly3OyeEL7OHqCztpbbCO+GRIl3NvhtYq5T6NlLr/7Cu63+a0JXzibjXpeOokLokhbvL6hjW5la72O69p8Rnn4Z2f+ONNwBYtWrV6W1Nl17Ea08+yeqOHv6juhyrw0HdkuU8/8ILp49xrrsJ57rYKhi6LpPRbX/lu1YrXzJqfD0coUyHny1dyI6W3bS5Bzh2/BjOj1+f2tsUCUH3STHjzDYx0XpPyX2wleF6ZSfhSASDoY9IWOe61QfY0zqXpXN7OdRWxrcX/Qw2bYEdb4M/CJVl4LxMJpt1c0ZcyrnmEM41h7jtx9fyxz0LOdpVzqYtjRMie3PLIgDmzRrM2Q8nXbL36Lr+KPCoUmo1UKmU+jtd138yYQnyBaXEK9N2GOq05H50pUQDdh2HumViAnnT1+5f/epXgQQ/e9CHc+US3vcTPxEFj5lNKPM4SyAqBZW1uHa9R8hg4m6zkQHgm+EIq3rdfNdiZKB/AHf/IM2PP4LzM7eemYfvGxJvSzQKKPFKhYMyKTUYaP7zX9j59jv0uAexmDS+ev0ldLjDfG3Zr1l88GUuHXwO6+Z+sFvhQ++HS1fDinPH/cE3rTnAc6+fi1GLsutwDavu+ixzyrwZkbW5ZRGuliUc6yplRX0PnoA5J0SH9Mn+baVUOVAJHAb+JSdXzzfiE8+2mEsy2XIzZqv42Qe7xeQx22LavSzzFAJ3J3SepHbHm2yxWug2W/jghy+l6ZOfGvtzykDTTZ/m2IkTmIwavwgFKPX5uKO9l8fsFq7WFPPPXYDruT/gvOJDUDUPymZLDou7Q1ycyjC8yoXZKmkAehQGunH9fgsnOrqJRKNYdQOXvO7mA51OZntbwWSExuXwkSZYtUzyVDLA++q7efdENf6wxtut1ViNEToH7LhiE9bxiLtpSyN7T1bS57Hy7olqbv3oWxldfyykS/ZfAYd1Xe8HUEqdmzMJ8g2DJqQ4uV+CNskIb3GIjW8rk85hXr8Ud2TimQmHoP0Ib//bE5wfjfJjk5HZNbU8+tOfplVZ5XR+AmdTkwSiWvfKaPOb55jzk9/xp9oqrtx/iBd3vU3/QD9PP/RDcSlGI2IGxXvsGAzy/fQo9HfR/McXcO14h8UWE1/wBXgRKAmGueHoo7w362JOXN7E6pvniEbPEM0ti7jriY+ydG4vK2KENxqiBMIG3m6djVmLcKyrdFzSdw7Y6egvQSmdcnuADneG6y6NgbTIruv666PeHwGO5EyKfCOR8POWiNZOhKZB2CAEm7dY7Pe+U2L6pJvwNNgDvR3UvPQa7zisHHTYuP+rX8usqFsZaH7tDVxPP0XTmvPYMRDAXe7ghx09/Dyqcw3wmz9tl+ZEs+ZKLW1fR6z1nE1mV/2dkrA1OEhv8za+2dXHkp5+DEg2eLvFhPp/d3Hw0Goh4rvja99kcLUsoczmZ9t78/nKdbvYcN1uNm1p5NV99TgsXvadquKSxW2S2zLKq3LvLy897b8HKLX5GfRZMJsiE56UJuLsdz2mQiLh5y460ya32KUpan+XZEHqxDw188c/d6ws7cnvPsBnfAH+paqM+7/+RZw3fTpjMV3PPoujugbXznfZ9tqb2CvL+TuDYlPvINuBq6M6zS/+GTQTrldbaLrigzg/comYM+2tsPs92n7/KrMPneDWqM4xzcCv62bz78EQ+wa9RKJRmg+dwNVyU1IiAjJa6NHYa7w0WMXa1cn8o6lxH9veq+Oy5cfo6LPjXLUP56p93Pv0ZTz0QiNGQ4QKuw+PT6OmvJ/bfnQ1NeVDdPSXsuWNBgZ8ZtrdDirsfmzmCB9cepTaCm/O7HWYyWSHGOGjkjhWc86Zbsm4OWNxiHYf6BHtPnokiOF73/ue/ONx09zs4pztO2k1KLaYjDzw+S9nnTbscrlouum/0d8/wB///ApLVi/nlrf28e/9HrYD3338aX7fO4imaZxoPYGz9wS7frGFxa0dlOtgNWr8ZX4tDw55qb9oBVv2HGDp/Bqibx9m9TnzcLW8R1Pjfly7FtPU+J4UwxjUMLk1o9j9RosE4EyW4dpSgwGUhnO+AWYfxLV9Lk2X9UPD+wHYeE8fm3dHsZvC7Gmbz77Hm7nt/otoHzTz5Pb3YbcGCQSMBCMaBoNOIGRkfvVgTlyNo5H3vjHpYsy+MbmGrkuEsXKuTPYSTY1QEPQIzD9P3kdCMrlN5fKLRuHQbv7p81/iO9tb+JZJQ7vpBjb+7JkJpx7c9oXP49CDeHq76O/p4fBf3+DZSJQKXafZqGHQda4GyiJRBoHfKnjGqBFcdg7z51RSU1HCjv3H6Rr0MLvUwcWL6+noH6KpcTnOD6yUybjVLvlCJrN4poymCffTufehBh5rrmdOZZCPX9rJs6/M5kBrCZohitevYTVFWVw/SFefBc0QZfncbp7/1pPDmaHxH1UcWfaNmdmaPQ6lhLy9p6Tiqfbc4dIyk1nC6u2HYe7iWOHxSZiz4AwS/OUvfwHvAB8qjdB04hQeBfarLuY73/8/OcmxafrEjbj+87c0XX89d/3z/dgXzOWjbZ1sslu5rG+AsFLsmVXGL0Ihfj7gJQhoEZ0b7RbQo1x87lw6+oZYcU4dnojO9WuvBJOZDzWunNQGURvvPExHjxWHLcLjzfV4AhrRqI4/qGHSdCorQsypivCPX94bKxDphUWrpeWJr1/SNkJBOdkEEkCLZI9DIVrM0wetPpm4nq5gsovvuuOo/BA8bui3nVER9e1vfxv8Q7x89xc4/1gbr9bXcDxsSM/OTwNOp1NSBfwedrzxJpt//0dWrrmAX5rgWJcbi9nIn94+hDcYRlMKHZ1rGpezp62H6y5pxHWgi6Ybrse17TWarriEb/9Ucl5efnhlLJg1eT2ja6r8PNZcz4BHw2KOEgqbWFDrJxg0cM0Hek5XQDkviy+OYJSOESXlMHuBZLF6ByVINtgHUNitNKYFjCZJkjr+DtQuHm7RYSsR/3v7Ebn5PadkNChNiMbGWun9+R/+lY/oOvdHwtx+4yfPXEFuorA62Hj/v7Lxa/vF+zLUR/NL27nrp7/BbjETjOhEo1GWzp/H7iOnsJiMvHvsJBv+5jqcl7wf5wdXAfDAk7+R8wV8Cb3vkxBeM8pIZ8i+R2NHj5VZZWE0LYrNonPlRT3s2V/OZ649wcY7D49/ApMFyi3i/h1yw9E3b85UhiLZk8FoFOK27Zca1ao6sRvjhO+MEb7zOKBDaZVoxlAAQkFWvnuQl2wWDmDA+bkvTo6MjjKobQAUlMzC+ckaqKxl09PPglJs+Nv/juulV3jt9TfRgYVLluL87G0jyWqN+bAXrRqejOpRSRiLhMQ2DgXkxxD0j2wXrpSYeOksIwM0Xd7JsVNW5s2GDTcfT9DgWcBggM/+89uZfqxI9lQwaFKj2tchmqT2XAnQ2ErEhj95QOz2zuNS9WO2QiiAp62DCl3n/qiO3VFC8/MvTqzQYiyUVsocovck2Bw4b/wUzk98clg5GwwcO9kGStF0xaVC1rgXZfQcQqnYDyG+yECS0UjXxXYOBcSs87qH16XSNPHWpEgpGGmiTA2KZB8LcTs+EpIoZmmlaHmLXR5620Eon3NaG/b09OHocbPXbuUts5FZmplNmzZNHtkByqvE1AoHhxdbiGlv5y2LcN7yedmuR8WlGAkOL68TjbX7i3fXHc/zopRElM2WWH3sPLkPAa/Y0vFFljVNPDqTOAfIBkWyp4O4NhyMTY5KZ8nk1OqQHJqBbkCnyWrmJuDHJo05VZUE89UCWzNm5e354b/9SH4AVfNk8ucfkh+GMgih0zmnySx/JRVi/vg9cj/iCz+YbTlZij0XKJI9E8RrKAd7hfQmC5RX0bxrKVteqWZjay9+u5WP3XoDV668HNdfdhd0L5lVF140/KasWrR/0C81uQM9Yqtrmpho6fxwNU00vqNMNP5gryTDBcMx0k8t3fJ6daXU7UjWZIWu68/k89pJoevSUzAUPvMvHJaHFH+NJoTL9dhnoxHw+Am9uIBv9L3C654+uHktTVd/FD72tzg/V9grcWzduhVIKDpR0hAJi13Ms4B3+IetdDFN0s0NMpmhslai0kN90NsWay04dZo+bxFUpdTdwFZd13crpR7WdX3M5WbGjaCGI5KN6PXHcs794Im99/gS9iVuCwwf6w8KqXMEt7matQ4TAUeEuQsWgqOCDRs2TK69PkFcccUVQBr95CNh8UK5O4WwppjdngkiEUmO62kTRZHJEu+j4R2Aj6zP+MP51OwXAae1uVKqUdf13SmPHvKBa5v0HHEPQf+gvA54hLCB0PhXtFkkXdVhk9dZpdKgx2EDq1lyt+N/RiOYY68mbeQ+kxE0w7DH4rTnAjlviZ0Kk4GSf/wJJzv60HsH0d2e7LsAFBo0o0zOS2YNpw97+iWzMt67fdxzaJJQVzJLMkj7u0EzZ/6jmQCm0ohyj94QM3Nuh1gT9yeeE6KVl0BFKVRXSIOdOHkd1pFktie8t6V2g+UMkbCYNOgy9JttVFcbmVtfL+6+ArbXs4JS4nq1LpIJbc8JebU60r/XRpPEKEqqpBGtb/B0BdVkI59k3wlUxN/oun5G2GzEMjPz5+hs/LKQtpBcWPGkMWIP3l5K857z2fTbJbx55DfU1TOivvSshFIyCbUtG15Fz6CduTbqWLA5oH65pCL3nhrOqJxE5HOZmUeANbElZsZng8koWrpQiK4j1UfRiERV65dIF4L5K9j0u2XsfKeUQW8fh44cpLm5eaqlzQ8MMdOkfpmMbJ7+zJbqMRgk07RuGbLyydCkLsCcN82u67qbmNZGuhVMH4RjD7CiBmbVSmDGZKX5jRW4Xiyjs9+BZjISjZ7HokX2aWOrP/zww7k5kcki9QDxhlSallmfTJtDFoLrOiFxi/FW4MsSRT/7WIibLFaHpAuYrBJ4cZRDVR2uF8toHyqjq9eI1QqNjbcwMOCaNrZ6ytba2UCpWAsSh6ze4R0UMy/dkVkzysIQVrvU3VpsaefdpIuCWge1oBCNiLuser5onTjRSypjKadBmtZZ2X/QRGMjrFgBs2Yp1q93TgutDuByuXC5XLk9qckC8xYJ8b0DwykM6UApMYvqlkpQKujPqWhFzZ4M4bAMxfMXD3cf8A0OE90/BNXzca53gAVcLmhqggce+AHbtwNMD83+gx/8AJiEFUMMGlTXC/F72kRbZxI9tZeKgmk7KIEtiz03YuXkLFOE5pZF3Pbja093j8oJwiHx/S5YkUD0Iak7nbNA+raXVcpfEakR19I150jaQTiNuEgizFaZ+JosolxygGmn2eMdo2oqhnj8pQswGHR2Hqph0xZZR3hCrdLCISH4vMXDmijoF5dYzTmSJahpUDnvtC3qcoHDIa9FJEFJhWj69iOAnn4QCsQnP29xbA7Qn3q1vnRPl/UnpwiuliW0u208+cp5hCIKUHj8JnoHrfQM2Xj3RDU7Do6z9EkyhEMS3as9d9gTEInIBHXectnm9ULdEtCMNDfDpk3Q2Qlz5sCGDfDAA5P2tac37KVix5+KPYtMCK8Z5Zl0toqnJllTqzQxLcyYRHOlpmKIF98+h1DUQFQ3gK7T7zVzyl1CJGrAH9Z4cMuFtLul5VpaCIck629uwzDRdR0CQ2KjW2wSJq+YDVYHzc1w112wdy+EQrBwIUyTOenUweqQ/jyhYOYmjcEANQvE9esblIKVLDAtNHt8lbRNWxo51l2O3RLCFzBitYbxh0wYVISorqhwBFDAubPd7D5Sw/vqu2luWYRzzaHT5s8Z2j5O9JqFI9NYA16ZkJZVinZXhtMF1i4XLF0Ku3dDXZ1MTgGeeOKJ/N2UHCDv8sYJ33ZwuLNyulAKquvEjGw7mNXlC5bsbo+V2358LTUVQ+w8VMOh9gr8ISM2cxiTFuHDy06y4brdI5Y++XjjIZ7dtYi9J6sIRjTcHsvpDlenlxVM7HgVDkkhxmiiRyLyWi05LgS8ouFjD6emBrZtg899DjZuHP7Y/Pm56SKQL0yJvFaH1M6eOiTBpEyCR0pJxNXvzerSBUv2fq8VhyXI5teWY9KiBMNGdMAfMvKBJW08/53NQMLSJ8BtP76WcEQjGDaiVJRdR2oJRTWaWxbRtObAac0OCNEd5XLjRxcmBLwwZ6FMTMMSLaWkAoB774UHHxQ7ffNmuPjiYRPmV7/6FQA335xx4fuUYMrktZeKguk4KjZ4pklgVfOyumzBdgQzG1fr9ZUvUmaX5jhtfXb6PTbqKof44edeTDrxbG5ZxKYtjRxor6CzvwS7JcScMi8fWtYmq0jEEQ6Jhkm2kl7c+1K3VDSJd1A8MbF+7suWiZ1+4gRccw3U1sKjj8pH084PLxBMubzuTvHD20sz87IEfLDwfQWdz54RjIYoPUN23je/m/2nqli1sHvc5vTxFSBgmPidA3aOdZVy7y8vFQ/N6r04P9AqZB5NdF2X4uq5i+Tmh4Ox1bKHPQDr14tGv/FGKC8ftteLyALls6VTwWDviHs8WShYskdRXLPyMHuO1bB0bi+BkIYnYE672WXi0iftbhsPbrmQy5Yfx7V7Oc51KeohA15ZZDe+2nW8FV7MzGluho4OuP/+ovclJ1BKTJKQX7R1JinCWaBgXY/1swYot4dY/8G91FbIUiWPfjHzVdia1hxg/6kqZpd5ePHthbz4XgNXf/WDNG8btbBAJNbIf1ZsRb1wSMyZmMaJuxvb24sBpJzCoMn8SNdlJJ1EFKxmj09QO9wlI+3tDOFcc4gdB2p58A8XYjTC8S47nX3DK0tII81OnBcdkchovIAg6Jd8dWU4PSk1GODIETFhisghjGYZQU8ekBZ7k1S1VLAT1MU1DfpHz/9tZlHQUTi9GFVnCRaHiRffmEc0CpGIwmzSWbrQQ/0cP/uP27n/y+/g/EyV+HGjUTFpFqwAzXh6Unr0qASQzGbYt+/M63V3S8er6uoMlqOZQhScvP1d0iHZkbz//WlkOUEtWDOmwuE/7S7MNtHL1bKE9j4r77TNJoCNr3/mCJet7sNs1plTFaCz18z+4w6W1g/g2rlouMVDwMe9D55D3QIjq1fDypUQCMDs2UL69euTX6+6urpwiJMGCk7eslh3syz96OOhYDX7moZ5+upzt+CwBPEEzFmZMs07zuGup65iaUOQ2qoQj37nXUCa4z+0eQFKwcfWdFPuCFNzbgkd3SZqqkN0tMOW7bPoHzAwNAR2u0RML78cPJ5hV+NoPP744wDceuutWX7r/KIg5Q2H4MS+4XZ8yXC2aXaQyWUmHpgRiEZwrjnM/V89QG1ViKbLO0/v2njnYWbPCrGgJsCeA6U8+oNuOrpNtHdpPPhYOW/ss+PuN+DxxJZPNUjCl8cztqvx8ccfP02g6YCClNdoiqVSe3Nej1qwE1QY6TfPCLoudnf9UpyL3TivdJ9xyPq1p9i8tZaVSwa4+gvLOXDURGe3xrw5Id7Zb5Za4Eo5jd0Ot946MjWgiEmEvUw6kg30SGlfjlDQmj1rRGLNOu1JVrGOYeOdh9n35B8przKz97CZk+1GNIPOyXYTmqawWsFqhccfl2hpkeh5xqxaiYWEcueOPPvIHg5LvkXl3LGPi4Rofm0+O98ppaPTgNWiE4koqqt1Zs0yUFUFDz1UDB5NGeL+96A/YVWQiaGgzZiMEY2IR2VukuSu0Qj6ce1oIBxW1MyOYLPB+mt62bFvNhikGKNI9CmG1SGlfQPdOTFnzh6y64iBXbd4/EqYSBg0I00fD3GsM8w8FBtu7cV5lQfqarJeke25557L7oNThGkh76waWbAtHMyswikJzh6yR0JyYxyp7fTTCPqhug7ntX6c18baNXgHobx+QjWOdntuquDzhWkhr0GTWoK2g9JHZgLP5+yw2SNhyU6sqhv/2GhkuKFPHHEXl22cyN04eOihh3jooYcmdI58YtrIayuBiuqRC5hlgelP9jhR5y5Kr+ol4JMRIDHrMRQQd1e6jfZT4Omnn+bpp5+e0DnyiWklb0WtPN9M61cTMP3JHglDVX16jXTis/rSqpHbwyEJVRdRuNCMUiYZ9GV9irySXSm1Vil1X85OGA7LEDdrTnrHB3xiviS2RtajYtZYHTkTq4hJgr0MHBWSpJcF8kp2Xde3Ag05OhkYFNScO76bMX68HhVXViKCAbmBk9A1togcQylJw86y4en0NWMiYZmlm63jHwtil9tKzqyGiYSlw0AR0wMms5TzZYGcZz0qpdYl2bw11p8dpdRmXdeTJskmLjNj0gyNi+bMStrkz2gwGD3BkOdIl/tYjsTOFaqBqV3GOTtMR7mtuq6fn8kH8p7iOxbZRx3Xouv6mnzIlCtMR5lhesqdjcx5n6ACDbGlZoooIq/IawQ1NkG9MJ/XLKKIOAp5gvrI+IcUHKajzDA95c5Y5oItyyuiiFyjkDV7UiilGhJfi8gtzub7W3BkV0rdHou0nuHCVEpVAA8rpXK0pmH2GEvOdPZPFabL/U3EeJH3tO+1rusF8wfcDTTG/n84yf6K+P4Cl3PM/QUsd0Hc3xSyb87mOyX+FZpmvwhwx9+kcFGuif2Kb8+bVGdiPDnT+R5TgelyfzNB2vd6Soo3UkVZk2xzJ77RE1bJVkq9QOF4EdwT3D9VcCe+KeD7mwncqXZMCdl1XX8m2Xal1E5kKI0fd3jU/tuBp2MPZSoxppxp7J8qTJf7mwnSvteFZsY8ggyjjcAL8Y0JQ+rTSAR2LbB5CuSLYzw5k+4vAEyX+zsCySLv2dzrop+9iBmDQtPsRRQxaSiSvYgZgyLZi5gxKJK9iBmDItmLmDEokr2IGYMi2YuYMTh7ej2ehYhlIa4BGoHDwEW6rt8zpUJNYxSDSgUMpVSFruvudIvUixgbRTOmsHFTTLv3wumweRFZokj2wkcDcEgptU6XgvUiskTRjClixqCo2YuYMSiSvYgZgyLZi5gxKJK9iBmDItmLmDEokr2IGYMi2YuYMSiSvYgZg/8Pr4AIOfSXHhoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "train_vi_loss = errors.loss(mean_vi_train, sigma_vi_train, y_train)\n", + "test_vi_loss = errors.loss(mean_vi_test, sigma_vi_test, y_test)\n", + "\n", + "ax= plot.plot_prediction_reg(\n", + " x_train,\n", + " y_train,\n", + " x_test,\n", + " y_test,\n", + " x_linspace_test,\n", + " mean_vi,\n", + " sigma_vi,\n", + " f\"Train {train_vi_loss:.2f} Test {test_vi_loss:.2f}\",y_min=-1,y_max=2,\n", + ")\n", + "savefig(\"MLP_VI.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/sindata/Calibration_VI.pdf to figures/sindata/Calibration_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/sindata/Calibration_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjHUlEQVR4nO2deXRURdq4n0pCAomEgGFXljAICggGUBBxjhBAmPHz+BFEx2VAIMKMO4rrNwaQ1WVwFCFxBUFlUdRB1igKIvwkCYwswiAgGEFMgIRsBJJ+f3/c26ETernd6U46oZ5z6uT2vVW36qbv27W8SykRQaO52Amp6QZoNMGAFgSNBi0IGg2gBUGjAbQgOEUpNUkplWSmDKVUglIqRSkVV9Ntc0awtNdeX7D+n9wRVtMNCFIyRSRNKRUDjDCP0wGPX7BZ5nYRSbWQNwEYJCJPurkeAyAiy12dq2J7Ex3uU/laEnAQiKmcx6wrCcg025MGpCilDgKzPNUbdIiITpUSxhcPxhe8vvJ5P9e1zMX5eCDRPM5wda4q7cUQlBQX1yYB8ebxBXkwXvYE+zOYdcfX9Hfna9JDIyeISK6LS72UUsvMoYh9GJBofo4xPycopVIcjtcrpeKVUl79SopIJpBm/irPcHXOQ3ubmMOlRLMtCWZbksz2xgFxZi9Tmd5A+X2VUvFumhsDtMP4/ySY7atVaEHwAhFJM//OFpGDpjDkAsuBpx3yNKmUPxM44eFlclZfLsaQY6S7c26YhTG0OQiMwOhRRgIHRSTXbN9Bezs9kFvp8wwg3nymJsBpEUk17zXCwv2CCi0I3nPS4TgO40uPw/V4PNeXSuy9jogcxPjVjnF2zsKt0kUkU0TuxxDYGcAgxwmti8ntNsy5CIBZJw6fc4FUU8gPAgkW2xOUaEFwgX3Si8PQwfzr+ALGAevN45PmsCOe8y+u/TgOY6jRq1Id9vvFO5yzDyuWAzFKqURgifniOTvnsr3Ak8DTDu1KNNuwjYoC7aynSsUY6sQ7PGN5+8xnut1sywxgqUPdy5zcL6hR5mRHo7mo0T2CRoMWBI0G0IKg0QBaEDQaIACCYCpUXCqPTGVOgrnaoNEEBX4XBFOh4nRNXSk1CWNdOw0Y5O+6NRpfqe6hkVu1vdlbpCul0rt06SKATjpZSieLTkqTqc1FTVaiJivBS2p6jpDr+MFU0fcSkV4NGjSooSZpahsr962k9SutOWX7Hb4Dznp/j+o2w3arttdovMFmszH6s9Es/GEhCgVLmtLg58YUU2+ut/cKyGQZ12YDTtX2Go23HMk7QrtX27Hwh4U0rt+Y7t9OgR+zmTnzBWTtrge8vV/Qmlj06tVL0tPTa7oZmiDk7cy3Gf/FeEptpQyOG8zjl33M4AFX06JFY379dRshISHK23tqDzVNraHUVsotH9zCmgNrCFWhzP/TfMZecz9t284FDjFv3huEhPg2yNGCoKkV/HD8B25acBMni0/SumFrNt+3mbYxbZkzp5Bff53KlVfeyK23DvH5/jW9aqTReGTGphn0mN+Dk8UnuavbXRx55AhtY9qSkwPPPPMacJzU1Oko5fWIqBy/9wjeOHxXvq7ROFJwtoABCwaw7eg2IkIj+HD4h9x25W3l1ydOPEVx8Sz++Mc/ccMN/apUl197BAua46cxIy5gzdVQc5Hy6d5PiZ0dy7aj27gy9kqyHsuqIATbtsHChS8Bubz66rQq1+fvoZFXDt+1Mf6NJvCM+3wcty25jZKyEhSKjaM3EhsZW37dZoOkpN+AOSQm3kn37t2rXGeg5wi5lT5XdviugKOJRXZ2doCbpgk2fi/4nSteu4K3tr9Vfk4Q1h5YWyHfe+/Bjh3TCAkpYfr0yX6p29+C4JXDt5Pr5SYWTZs29XPTNMHM0t1LuXzO5ew/uZ9rW19LiPlqhoWEMaTD+dWgU6fg8cd/RqkUxowZQ8eOHf3TAH8GScIQgiQcAlGZ55McAkolYTiRuw0G1bNnT9HUfcrKyiRxaaKQjIRMDpGXNr8kIiLZhdmy6IdFkl2YXSH/Aw+IwF8lPDxCfvnlF1e39f7d9aVQdSQtCHWfn078JC1eaiEkI01nN5U9v+9xm3/HDhGldotSITJx4kR3WXWkO03t4PXvX6fT6534reA3bu10K0cnHuXKple6zC8CDzwA9er9H1FRUTz11FN+bY/WLGuqlbOlZxmyaAhfH/6aeiH1eO/W97i7+90eyy1eDN9+uw34hGeeSSY2NtZjGW/QRneaamPbr9sY9P4g8kryiIuJY9PoTbSKbuWx3OnT0KkTFBUNIjx8BwcPHqRhw4buinitYtZDI0218OxXz3LdW9eRV5LHuPhxHHj4gCUhAJgyBX777StOn07jmWee8SQEPuH3HsGdiYV5PQEj3GCcs+t2dI9QN/jp5E8MXDiQI3lHaBDWgBUjVzDkD9aN4/bsgauvFpo06UtExK/s37+f+vXreypWsz2CJxML+yYXYugRtFa5jrNk1xI6vtaRI3lHANg1YZdXQiACDz4I9et/Tnb2/+P555+3IgQ+Ua0mFqaAPK2UWoYR2rwCWrNcN7DZbNz32X3c8fEdFc5v+XWLV/dZvhy++qqM6Ohn6dixI6NGjfJjKysS6FWjXMcPpm3Rkxi9wSwq9RpibLeUCsbQKMBt0wSArNNZ3PDODRzOO0x0RDQFJQXYsF2gIfZEdrbRG1x22YdkZe3mo48+IiwscK+rpTsrpaJF5LRSqh1wUkROu8jqyTk/UURmm/fEIda/pg6wYMcCxv57LKW2Uga2H8iqv6zi9NnTrD2wliEdhlQwnHPHkSPQpQsUFJwlNPQfXH11D0aMCOzeI1aHRiOVUgOA8YCzbYbsuI2pDyw3tzGKxxAoLQR1gFJbKX/+4M+M+mwUIsLcYXNJuzeN8LBwYiNjuavbXZaEYM8eSEqCK66AggKAtygrO8SYMdN8dsG0jBX1M3ANhuVoI2CALypsb5M2sagd7Dq+S2JnxwrJSMuXWspPJ37yqnxZmcjq1SKDBxsGP/Xri4weLdKiRY7ApRIRcb0UFNi8bVbATCyaYCx5NgF6BkQiNbWO2Ztnc/X8q8kpyuHOrneS9WgWHZp0sFS2qAjmzzeGQEOHws6d8MILxrDo9deLOHPmCuAEjRsfQqniwD4IaKM7jfcUlhRKnzf7CMlI+NRwWb57ueWyv/wi8tRTIo0bGz1Az54iixaJlJSczzN37tzyUI5RUVGSkZHhbRO9ft+sDo0GmH/bAf/rS0Vm+Rgs7lWsBSE4+ebQN3LJ9EuEZKTza50vMJOuTGGhSEaGyDffiNx5p0hYmEhIiMjw4SKbNonYKo16vv/+e4mKipLw8HCJjIyUuLg4KSws9LaZXr+bbleNlFLDMZY445RSIzA0dgJ84qaMO81yAnC/GW2gCTBODOWaJojJKcph7YG1bDy8kdSMVAAeue4R/nnzP92WKyqCuDhjKdRmg+hoeOghw4q0ffsL8+/bt49hw4bRtGlT0tLSyMvLo3PnzkRGRgbisSrgVhBE5GOllF0LbN+F0eUcwdQsp4lIprnpdmVByBSRQWbeBC0EwU9OUQ7NX2yODRsAUfWiWPWXVdzY7kaPZZ96Co4fN47Dw2HlSujf33neo0ePMmTIEJRSrFu3jg4drM01/IXHybKIHMLoFQZi/KK729fAk2b5oHk+UZxscq01y8HHCxtfKBcCgFeHvmpJCObNg9deg6goiIyEyy6Dni5+Qk+dOsWQIUM4ceIEq1ev9p/7pTdYGT8BPRyO27nJtwzDmA4gxX7sJN8sT3XqOULNUlZWJncsu0NIpjyFTQnzOCcQEXn/fRGlRG65RSQ315gjuBrmFxYWSr9+/SQ8PFzS0tL81Xz/zhEcmK2UEiAPaI/xy+8Mj2HfHTbD1gQph04d4oZ3b+Bo/lEubXApn4z8hF9O/2JJO/zppzBqFNx0EyxdCvXrQ7yLoD6lpaWMHDmS7777jiVLljBw4EC/P4tlrEgLMNDh+Bo3+WJw47xvHic6fnaVdI9QM7zx/RsSOjlUSEb+tPhPcq7snOWy69eLhIeLXHedSH6++7w2m01GjRolgLzxxhtVbPUFBGb5tEIBiPalIm+TFoTqpeRciQxcMLB8CPRO5jteld+8WSQyUqRbN5ETJzznnzRpkgCSnJzsY4vd4l9BAMaaf+cDS4ClwDZfKvI2aUGoPjKOZkjMzBghGWn7z7byS57LMClO2b5dpFEjkY4dRY4d85z/5ZdfFkD+9re/ia2yIsE/+F0QGpl/LQ2N/Jm0IFQPz3/1vKhkJSQjoz8dLWVlZV6V//FHkaZNRS6/XOTwYc/5Fy5cKICMGDFCSktLfWy1R/wrCBUyQg8cVo8CnbQgBJa84jy5Zv41QjJS/4X68sV/v/D6HocOiVx2mUizZiL79nnO/8UXX0hYWJgMHDhQzpw5432jreP1+2bVH2Ec0ME87iUib7nJ68lnORFD1xAvpm+CpnpZf2A9t350K8WlxXRr1o2NozcSUz/Gq3scOwYJCYa59DffGKbT7tiyZQuJiYl0796dFStWEBER4fsDBAIr0kLFodFAN/kmYYZyBFKcXE/AIfyjuzp1j+B/sguzZcCCAUIyopKVPJ32tE/3yckR6dpVJCpKZMsWz/l3794tjRs3lo4dO8rx48d9qtNLAtMjYNgabTOPnViJlNMbB7MKpVS8VDSjGAQcsO+8iemWqQk8u37fRbd53co/r75rtVeO9Hby8w2z6f37YdUq6NPHff59+/YxYMAAIiIiWLt2Lc2aNfO6zurAqj/CUmA28CZOnO7dkFvpcwxGFOw0YJC5g0452sQiMCzeuZge83tUOJdTnOP1fYqL4ZZbIDPTUJYNGOA+/86dO+nWrRvHjx8nNDSU5s2be11ndWFJEEQkT0TGi8hIEfnZTVZPmuUMD/XosPB+xGazcdtHt3H3J3cjiMtQ61Y4exYSE2HjRli4EP7nf9znX7RoEX369OHcuXMA5ObmsnfvXp+eo1qwMn7CGA4txdAltHOTLwbPmuVJOMwVXCU9R6gae37fI01nNxWSkRYvtZD/5vzXZah1T5w+LTJokLHGmJLiPm9+fr7ce++9Aki/fv2kTZs2EhUV5atfga8EZvkUeMIUhvbA475U5G3SguA7L3/3soRMDhGSkcSliV7rBhzJyjI0xiDSpIlr4zkRkYyMDOnYsaOEhITI888/L+fOnZPCwkLJyMioTiEQCaAgDHA47mH+DaiphRYE7yk+Vyz93u5nuFBOCZePdn5Upftt2ybSqpXxloCxSuTMa9Jms8mcOXMkPDxcWrduLV9//XWV6vUDXr9vVleNnnLwUGuvlDpk9g6+7/Cs8Subj2xm6OKh5J/Np2OTjnw7+luaXeLbCo0IpKYa3mTNmkGrVpCXB82bQ+fOFfPm5OQwevRoVq5cyS233MK7777LpZde6ocnqmasSAtOdAfOzvkz6R7BOk+se6LcTOKBLx6o0r0KCkTuucfoAW6+2dAZ2P2OK49uNmzYIK1atZLw8HD517/+FSi7IV8IzNDIn4nzjjtaoVZFThSekKtev0pIRqKmRcmXB7+s0v327jUUZUqJTJlixBxyxrlz5+S5554TpZR06tRJtm/fXqV6A4D/BAEf5wAYq0YJOKwaOVyLwYiA59J7zZ60ILgmuzBbHl3zqERMjRCSkd6pvSW/xIMDgAeWLhW55BKR2FiRdetc5zt8+LD069dPABk1apTke3I8qBn8KggzHI57OBy3c1PGk4lFDB5207QnLQjOOZ5/vHwYRDI+m0nYKSkRefhh403o21fkyJEL89hXfj744AOJiYmRhg0byuLFi6tUb4DxWhDcTZbTlVLzzeM4pdQpzMkyrl01PZlYgBEbtYnZI1QwsTAN9pIA2rRp46ZpFyeHcw8TnxqPcD5QeJdmXXy+X1YW3H47bNkCDz8Ms2cb0SYcKSoqokuXLmRlZVFaWkp8fDxLly6t9igTAceKtODgg4B7V01Lzvvm9fXu6tQ9QkVS0lPKXSjtPYJVZ3pnrFtnDIMuucQYFrli2bJlYvqrS1hYmGzdutXHJ6hW/Dc0uiCjBX8EPA+NkjAj3WlBsMa5snMyeOHg8hf/rYy3fNYQixgT4MmTjQlxly7GBNkZNptN3nzzTWnQoIGEhIRIREREdWuHq0JgBAEYB8w001g3+WJwY2JhXo9Hm1hYYvvR7dJ4ZmMhGbn8lcvlcK4FFzA3HD5szAPAWCItKHCeLzc3V26//XYBZODAgfLTTz/VhHa4KgRMECz5I/gzXeyCMPnryeVDoL+u+GuVzCRsNpEFC0RCQ41vPDbWtRBs2bJF2rVrJ6GhoTJ9+vQq1VuDBLRHiDaTyx7Bn+liFYT8knzpmdLTcKGcWl8+2/tZle534IChGAMj+K4rU4mysjKZOXOmhIWFSdu2beW7776rUr01TMAEoRHnI1m086Uib9PFKAhpB9IkclqkkIx0ndtVThRaiIvigpISkWnTjI03GjYUefFFkfbtDSGIi6uoJT527JgMGjSo3Kn+1KlTVX+YmiUwglAT6WIThAkrJ5SvCE1aN6lK99q0SeSqq4xvd/hww4JUxLmpxJo1a6RZs2bSoEEDSU1NDSYziapQ84KAG82yQ554d9flIhKEY/nHpMOrHYRkpOH0hvLdEd+HJDk5ImPGGN9q27YiK1e6zltSUiJPPPGEANK1a1fZvXu3z/UGIYETBH8sn1bKd1GvGmUXZsvfv/i71JtST0hG+r/TX4rPFft0L/tkODbWmBBPmuR6MiwicuDAAendu7cAMn78eCkqKvLxKYIWrwXB3+FcPGqWTcf9NKCXlbrrIr8X/E6Ll1uUa4in3jSV5258zqd77dsHEybAhg3Qty+kpEC3bs7zFhUVMWfOHGbOnEloaCjLly9n+PDhvj5GncKqP8JBEXkTQCnlTcjiXMcPpmCkVd43weF6nTex2Jezj2vfuraCmUT7xu4CgzjnzBmYMQNmzjT2H0hJgbFjwdkurKWlpXz++efcc889FBUVERERwY4dO+hc2bngIsZqFIs4pVS0UioaIwyLK6yGhU8AelaOYiF13Hl/ztY5XPXGVZwuOY1CAb450q9cCZ06wZQphkP93r3G/sSOQiAiZGZm8uijj9K6dWuGDx9OUVERAKGhoeXHGhMr4yeM5dN5GMunV7vJF4MH533z8ywMW6QYV/eqS3OE4nPF0v+d/kIyUm9KPfnghw+8NpM4e1bko49Eevc2ZnZKibRocaGzTFZWlsyaNUu6dOkigISHh8vw4cNlyZIl0r59+5pwpK8J/DtZpmI07KXoaNhes/WXrRI9I1pIRjq82kGO5VsIF+3AyZMis2YZMUbBCLYbHi4VFGP5+fmyYMECSUhIKDeQu/7662XevHlywiFGew050tcEfheERuZfHQ3bB55a/1S5mcT4f4/3quzevSITJpyPIDFggMjnnxsbcLRrVygREd9L8+afyx133C2RkZECSPv27eX555+X/fv3B+iJag3+FYSaTLVZEE4Vn5Kuc7sKyUjktEhZ95Mbly8HbDbDPHrYMOObiYgQue8+kf/853yePXv2SHR0IwHjl79Ro0aSlJQkmzZtqivKMH8QGEHATxuOe5NqqyD8e++/pf4L9YVkpGdKT0sulEVFIm++aZhFg0jz5oaptD1ers1mk02bNsnw4cMlJCREMHelj4iIqO02QYHCv4IADDfnB2vNyfJ8YJ4vFXmbapsglJWVyV9X/LXcTCJ5Q7Lb/IWFImvWiDz5pKEIA5Hu3UXee0/EvnVASUmJLFq0SHr16iWANG7cWB577LGaih5Xm/B/j4DhmmnZ9BoPJhb2a66u21NtEYTswmyZs3WOtH65tZCMNJ7ZWDKOOomC5UBOjrHVkj1w1p//LLJhgzE0EhHJzs6WadOmSatWrQSQTp06ybx586TAVBdfRJNeX/G/IFxQwI2ZBZ491MqXVYEMd/XUBkHILsyWkOSQckf6P777Ryk5V+K2zKZNIm3anBeCBg3Om0Tv2rVLxo0bJ/Xr1xdABg8eLKtWraqtPgE1idfvtSWFmqlMm6mUWgLc7yZrbxy0yZU1yGKYW6SZGuQZTuqpNWHhS22lJCxMqLAr/bie4wgPC3eav6DAiBx3442gFLRsaexK36KFjcOHVzNkyBC6du3K+++/zz333MOuXbtYu3YtQ4cOJcSZuljjX6xIC/A45hAJ9zvmWHLex9BOL3NXZzD3CDt/2ymXzrrU8q70X35p+AKAyIMPGkugv/xyQkaPflo6duwkgLRs2VKmTZsm2dm+OeNrKhCYoRHGpHkgxoTZ5aoRnodGkxwEJYNaqFmevnF6eaTpvyz/ixzPP+5SQ5yXJ5KUZPyX//AHkY0bjfNffvml1KtXr1zz+/bbb0tJifshlcYrAjdHcOgRxrnJE4N75/04+zVgkrv6gk0QCksK5drUa4VkJGJqhHyy5xO3+VevNrTBISEiEycaq0QFBQXy8MMPi1KqXAMcGRkpGc5CTGuqQsB6hHa+3LwqKZgEYcOhDRI1LUpIRjq/3tmtfdDJkyKjRhn/2SuvPL/Z3oYNGyQuLk4ASUpKknbt2ukl0MARMEF43OE4oPsi2FOwCMJDqx4qnwc8uuZRt3k/+0ykZUvDOeaZZ0SKi0VOnz4tEyZMEEDi4uJkw4YNIqKXQANMwATBvm3UfGCtLxV5m2paEI7nH5crXrtCSEYumX6JbDq8yWm+wkKRtDSRESOM/+bVV4ukpxvX1q5dK23atBGllDzyyCPlegBNwAmYIFxUcY2W7V4m4VPDhWSk71t9pbDE+a92YaFhDmHXCTz7rBE94tSpUzJmzJhyZdjmzZur+QkuegIjCF7d0HNY+KCdLJeVlcnIZSOFZCRkcojM/na22/wvvXReCOyKsZUrV0rr1q0lJCREnnzySSku9s0PWVMlalYQLCyfJgXr8un3Wd9LzMwYIRlpOrup7Dq+y23+9esNv4CICMNUum3bE3LHHXeXR4X4/vvvq6nlGid4/e76W2XpSbOcKg7umyKSSxAwa/Msrn3rWnLP5KJQ/DDhB7fh1r/7Dm691XCX3LmziHHjXqKo6EqWL/+If/zjH6Snp9O7t6vI+ZpgxKrzvq/kOjuplJoEjHByvlqd98+WnuXmxTez4ecN5ecE4ctDX3JXt7ucltmxA4YNg9at4f33j3HNNR0pLCwkPDycjRs30rdv34C3W+N//N0jWHXeX175vJm/2pz3t/26jWYvNWPDzxtoE93G0q70e/fC4MEQHQ2zZ29l2LBeFBYWGuXCwoiIiAhomzWBw9+CkIqxI048xl5pQPkvvV0I7I77y/xct2X+76v/47q3riOvJI+x14zl0MOHOP7EcRb97yKOTTxGbGTsBWV+/hkGDQKwcfvtsxgxoj9hYWG0atWKqKgoWrRoocOj1GZ8mVhURwrEZDmvOE+6z+suJCMNXmggq/67ylK5o0dFOnQQiY4+Jn36VAyWqxVjQYnX71ug5whBw+r9qxm+dDjFpcV0b96djaM2El0/2mO5EyeM4VBW1loiI+9lx47TpKSkMG7cOJQyYhPFxzuNV6apRdR5Q3ebzcbYz8cy7INhnCk9w3P9n2PH+B2WhCA/H4YMOcuPP06ipORmWrVqSnp6OklJSeVCoKkb1OkeIet0Fje8cwOH8w7TKKIRafem0auVtZCrxcWQkHCQjIw7gG2MHz+eV155hQYNGgS20Zoaoc4KwsL/LGTM52MotZUyoN0AVt+12qX3WGXOnoXrr/+QHTvuJzIylIULdbDcuo7fBcFcITqIoTW+YJnUXDkaJCJP+rvunKIcvtj/BR/s/IB1B9YRqkKZO2wuf+v9N8v3OH26kB49HuLQoXfo0OF6vvzyA9q2bevvpmqCDL8KgqkoSxORTKVUCk70BWJEw3bn9+wTOUU5NH+xebkPcfOo5my+bzMdmljbGLuoqIjPPvucCRMmk5e3j4EDn2XNmmTCwupsp6lxwN/fssf9EdxRFc3yg6sfrOBI/+KgFy0LwYEDB+jd+1pOnToNNOWOO9bw4YeDvapfU7upERMLV4hIKoZSDjOolUeKzhaR8H4CW7K2lJ8LCwljaMehTvOfOXOG7du3s3Xr1vJ05MgR8+ow4Eb+/vcLFWqauo2/BcGjiYU/+fbItwxdPJSCswV0urQTK0auIPO3TIZ0GEJsZCwiwuHDh9myZUv5S799+3bOnTsHwCWXtCUkpC/wKNAX2AdMQakHA9lsTRDib0FIBW4319grmFiYv/b2yXKct8Omyjy29jH+ufWfADx07UO8OvRV8vLy2LRlE2+seIPMzEy2bt3K8ePHAahXL5KoqN7AY0Af4DrCw1vSpw/06gUpKWfIza1Hy5Y/cM01kb42S1NLUSKWRiDVTq9evSQ9Pf2C8zlFOfR/tz97c/YSVS+KVX9ZRZeGXXj99deZOnUqZWVlADRseAUifSgo6AP0ISSkGz16hNGnD+XpD38wgm0BFBUZRnWdOxtbMWlqNV5rO2vVksiKH1dw58d3UlJWwnWtr2PutXN5c+abLFiwkDNnioEbgUeAGwkPj+bGG+vRt6/x0vfs6f4Fj4wEbSlx8VIrBMFms3H3irv5cNeHhBDC6KajOfzhMXqN64VSEYjcQ0jIQ9hsvwOLgZf59NN13HBDvZpuuqaWEPSCcOjUIfq/259f838lqiiKyI+b8e6Bd4EWwFS6dbuf++5ryq23wk03FfHbb41p0aIz8fF6fKOxTk1olt1ed2Tetnk8+MWDlEkZ6ut6FG4qpFBiaNx4IWPHjmTUqHCuuup8/t27I9m7N16P8zVe49fJcmXNsojc7811R8KaR0jZ+LOGJuJjCP3tNgYNepSJE2/gppsUoaF+a7am7uH1ZLlanfctXC+nLPQs/Kho9u9RvDLxAHl5n7B6dX8SErQQaPxPTWuWK1x3NLEgFNgk/J67OPuxx9478thjluuMBXK8aaQfylZ3uZqosza1dZeIdPWmQHVrlt1edzSxUEqlyzGx5jzggFIqXcT7clUpW93laqLO2tZWb8tUq/O+q+saTU3j1x5BjIBdqebHTIfzqe6uazQ1TTD7LKd6zuLXcjVRp25rkJQLWlsjjaY6CeYeQWMRpVSc498A1xWjlIoJdD3VTVCYWPiqja6Kf7S7suYXHWdPIjLbyzpjALxpq5kn3qyvcns8tTVFKXUQI4qgpXLm9USMJex4L54xAbjfNLVvgrGnXqbFsvb/z8nKz2nhGZMw5pWV3wG3PvCWLRl8iQrmz4S1nTgvuO6pnEO+C7axtVCn0/D1Fso53VDdi7ZOwtx40Yu2xtive1kuAYdNHr0o55g3wYc67f+fSV6Um2Wvy8X36XSrYqv/dxH/h4X3BV+10Za11N7WKa7D1/u6obrHtpq/bGnettWkl1IqwWGZ2kq5QQ71JlgtZ/+/KKUSRaRyez2VTQOeVkoto+KzevNdxngxBLR832AQhMrk+njdUzmv7+kqfL27cqbQpAEjrZYzvfWcCYHHsiKSawpumpdtjQEOmuUGuRn3Vy5nx8oGEBXKmi/wkxg6pFnOCriocwYQb77ITSzUa/W+5QSDIPiqja6Kf7Sv4evdllNKTVJKxZnn4xxeLqv1JQA9K72UnupMcvESe6ozw0kZb9rqS9lEEUkTQ6+0zOGX3ZPFQS6Qava4B734rq2/I+7GTdWR8LxJuavrbss5jEkzqDSGtlCnvdx6Ko71PZVzuqG6lbY6jIVTcNhSy+L/Jx6HMb8X/59JPpZLrNx2L/4/id5+l2a5JHvZSnVe8B17enecJa1H0GgIjqGRRlPjaEHQaNCCoNEAWhCCEqXUem/NJUw9Qkqg2lTX0YIQnFjy1TDtfpKgXFlVlTX2i5qgsDXS+IZU9O/QVAEtCEGCg2FZGmZ4fXN4ZDdSyxVjb4lJGMZnJzF6gBFyYbSQC8pV24PUUvTQKHiYBSy3a08dzh000wi7EJgvdi83w6EK5QLe8jqAFoTgJ11EMs1f/d6YQmKaKVgtp/GA1iwHCebQ6HYgHeMX3W6heT+wxMyWizHkOYgx7AF4ExiIYYZgP27iWE6qEH7/YkELgkaDHhppNIAWBI0G0IKg0QBaEDQaQAuCRgNoQdBoAC0IGg2gBUGjAbQgaDQA/H+j2vZR16ZzRQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1)\n", + "_, df_test = plot.calibration_regression(\n", + " mean_vi_test, sigma_vi_test, y_test, \"Test\", \"blue\", ax\n", + ")\n", + "_, df_train = plot.calibration_regression(\n", + " mean_vi_train, sigma_vi_train, y_train, \"Train\", \"black\", ax\n", + ")\n", + "ax.set_title(f\"Train {errors.ace(df_train):.2f} Test {errors.ace(df_test):.2f}\")\n", + "savefig(\"Calibration_VI.pdf\")" + ] + }, { "cell_type": "code", "execution_count": 40, @@ -1146,7 +1260,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.13" }, "vscode": { "interpreter": { diff --git a/notebooks/regression/step_data_comparison.ipynb b/notebooks/regression/step_data_comparison.ipynb index bb56b20..156a3f0 100644 --- a/notebooks/regression/step_data_comparison.ipynb +++ b/notebooks/regression/step_data_comparison.ipynb @@ -1069,6 +1069,116 @@ "savefig('MCMC.pdf')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VI" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from utilities.vi_helper import vi_model,vi_predict\n", + "params = [[16,16,1],[nn.relu]*2]\n", + "\n", + "mlp_model_vi, vi_model, results = vi_model(params,x_train, y_train.flatten())\n", + "\n", + "mean_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).mean(axis = 0)\n", + "sigma_vi = vi_predict(vi_model, results,mlp_model_vi,x_linspace_test).std(axis = 0)\n", + "mean_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).mean(axis = 0)\n", + "sigma_vi_train = vi_predict(vi_model, results,mlp_model_vi,x_train).std(axis = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/step/MLP_VI.pdf to figures/step/MLP_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/step/MLP_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAntklEQVR4nO2deZBlV33fP79z7r1v69f7Mj379EgjCQYhRiNj402FBEYCnLgsEHEMwXFABuLYKTvgxE5VFhwKEiokKQlLdrxRTlkgx2VjMDbCFktkm5HEJhDaRrNpumd6e/2633aXc/LHfd3TM9PL62Xee919P1Vd85a7nHfnfu/5nXN+i1hrSUjY6ahWNyAhoR1IhJCQQCKEhAQgEUJCApAIoamIyAdF5L31vydF5E4ReVBERprYhm4R6W7W+bYKTqsbsMN4ylr7aP1GfFv99RPAqkKo7/N2a+1DDWz7XuAk0G2tfeSKr+8E7hMRgF7gPdbap0TkTqAbYIl9tj1Jj9BcnrjyA2ttgfimXRFrbaFBEXwQeMJa+yjwhiU2ecpa+wZr7RuAD9VFcIxLovm3q51jO5IIoYnUb/qlOC4in6mbTiMAInJP/X13/f2dIvLgotdfFJFjIvLRK451G7BwnvpNvrgNJ+ePXxcL1tqngEfrPclHNvo7tyKJENqARTfkx6y1J+tiKAALT+j6Nr1XbP8UMHnlzX4FhWU+v+2KNhSAR4F71/s7tjKJENqHqUWvR4C31f9dbvxQWObzE9RtfbjUAyymPh5Y/P6DIjJS33ZkJw6mEyE0mflBL/ENd2f9szu5/AYcAb5Yfz1VN4GOzW+z6PUI8ZP9+KJTPERsah1bdIz5AfQ83cCLi94/AnSLyD3AwyuYcNsWSXyNEhKSHiEhAUiEkJAAJEJISAASISQkAIkQEhKARAgJCUAbO9296U1vsl/4whc2dAwbhAQvnUMyKawf4h7ag6jN0f7tt98OwGOPPbYpx9uKLHcNrDEEp0cRLJLyVj1OcGaMaHwS1dW58Fl4YZL0jx7D29W/GU2V1TZoWyFMTExs2rFEa0xYxZarSEd2U455xx13bMpxtjLLXYNosgBBiHRkVj1GOFEgGptAejpX3fZa0rYLasePH7dPPHGVs+aamO8RVEcWGwRYFN6B4U1qYcJSmFKF4MwoqjNH3dV7xW39Z04i+Ryi9WXfNbtH2DFjBHFdbKWKqdRa3ZRtiw1CwtFxVC6zughqPv7zp5Fs+ioRtIIdIwQAcR1Mobgpx7rrrru46667NuVYW5XF18BaS3hhAkQQZ+Ub24YRwYtnEVGIt/oYohm07RjhWqAyKaJiCd3XjXjuho5VqVQ2qVVbl8XXIJqYxpSq6PzKYzBrDP6p89iaj8p3XOsmNsyO6hEARAlRca7VzdhWRMU5oskZ1CqDY2stwdkx7MxsW4kAdqIQMimi6SI2ilrdlO2BMYTnx1G59KrjgnB0nOjCFNLZXiKAnSgEpRDAzJZa3ZStjzGYmo/KpFYd8AYXp4jOXUR151cVTCvYUWOEeSTtxV15Z8e6F9je8pa3bHKrtha25vOm1/0oohXirnwbhRMFwtPnUV35TVvQ3Gx2phA2YYHtV3/1Vze5VVsHW/MJzo7xK7/w/lUnHcLpYryW09mB6PYUAexA02gelfYIx6dp1wXFdsVUa/hnRhFHNyaCF87Gi2ttsFawEjtWCOK62JqPXecC2+23377ga7NTiFeNx1Ceg3gud9zzU9xxz08tuW04USB44QyqM4vo9jc8dqwQAFTKJZoqtLoZW4KoUCQ4O4ZKe4i7ck8QXJxaZA61vwhghwtBUh52roypJm4Xy2GNIbwwSTg2ierIrLhqbK0lGB0nOnUe1ZlftzlkqzUqX/8W/vOn19vsNdM0udZTlczn6Rmx1n6sWedeEdfFFGZRu1KtbknbYf2AYHQCaj4qn11x2tMaQ3juQiyY7vXPDoWjF6l98/uE58axYfPWeprZI7wdKNTza97bLkmkVCZFNDOH9YNWN6VtsNYSFefwT70MURj3BCuJwA/xXzxHND6N6ulclwhstUblye9S+YfvIOkUqitHM5cbmtYjXJnAdqkkUvUkVO8F2L9/f3MaBogWokIRZ7Cv4X3e/va3X8MWtQ7rB4QXJzFzlXi1eAXz5p63/CS25uM/cxKLQXXl134+a4lGx6l96/vYyOL099JUBdRpejxCPVvzI0ulIlzMZscjrLidtZhSFe/QnlUXh7YrNoqIpotEkzOIq1HplU1FawzhxSnCsxdiwazDi9SWK1S/+zzhy+Po7jziXjqGf/oc+XfcRcftP7Dm4y5Be0Wo1VMbNjX3vvUDrDErdtciEjvjzczi9Pc0dNxyuQxANrs5EW+twtZ8otkS0XQRgVXNIABTrhKcGWVucgrV0UFujSKwUUR4doza08+DSMt6gcU0c7B8J/BR4mS3vcCtzTivf+plVFee1OF9K24naY9oaiZ+MjmrX5a7774b2Hoxy9ZabM3HVKqYwhz4PigV+wutYtvbMCK8MEk0Oo6kPN72X34DgC/8j1XLNiwQTRbwv/Mc4cwcTncenI25w28WzRwjPEqTbv7LSHmYiWnM8AAqm152M1EKESEqlnB6u5rYwGuDtRaiCCKDDUJMEGBLFWylhjUWUYJ4LpLPrX4sY4imioTnxsAYZB0+WqZUofbsS4RnxlC5TNwLtBHb3iAWJViE4Mwo3g0HV+z2JZMimiyguzra3iVgOaJCkWiqiA1DIDaOrY0nBMRxkGwa1aAZYo0hmpkjOncR41dRuVxDveVlx6j5BC+di9cElMLp7wZpv+WrbS8EAJVNY+ZKRBMFnIHlxwCiFGIt0czcluwVouIc4egEqiODSq8/BNKGEVFhlmh0AlOtonIZdNfarof1fYIzowTPnsIYi9PdCap9Hy47QggAKt9BeHYM1ZlDrZBrR7LpLdkrWGvjefyOzLraba3FlqtEkzNEE9NgDZLNoHvWKICaT3B2lOC5U9jIoDs7UG0yDliJHSME0Rq0IjgzhnfdvmVNpEZ7hXe/+93XqKXrw1Z9bBihMo2tkMeD5gBbrWGKc0TTRQhCcDSSyzbkMv1P3/TWhdemVCE4fZ7g5DmwsQDaZSDcCDtGCACSyxJNzRBNzsS26nLbzfcKncvbxO0mBFOpIupycVtjIIiwUYQNoji3U6WGKZWx5Uo8doDYpTqdQnJrmwr+2TfeRTQ5Q+Ufvk14cRIRQXd1gNp6t9XWa/EGUV05wjOjqHx2WRNpPpwzmpnD6etecpv5THz9/ZuSgGrDmNky4rnxQtf4NGZ8GlutLSwlWerrJVqD6yD5joYHzYuxUYSZmSMcm2DsmWchiOjv6cXp6WrLQXCj7DghiHYaM5EWzyAt0Svcc889QHusI9gowtZqqI4swZlRwouTqI4cqntz0ijamk9UKMZeqC9fgDACUfyLP/8UiPDn7/ngppynlew4IUBjJpIohShFVGh8tblVWD8EG6/4hhemUN2dGwqQt5UqUamCmZohGhsnKsxiRVCORueyMB9j0IZB+OtlRwoB6ibS6fMrm0iZFNFUEd2Vb2sfJFONxwfRVBFx9aoisNZCEGCDMB4w12qYuQpRoYiZKmJ8P7aoRKEyaXRvz7a66Zeiff93rzGiHXD0iiZS7IME0dQMzlDjnqnNxs5VwNVEkwUkk8YagynMYopzmHI1DjzyA6j5mJqPrfnxfoDES24gCkm5qIyH6lh9tXm7sWOFAHUTaXqGcHwad3DpJX/JpIkKs7EPUgO5/puNNSa+0UWwQYA4muqJpwmnCpcGx46OXSKUihfastlt/4RfKztaCACqs4Po7Bg6n1tyDl7qSW2jyQLO7sGFz9/3vvc1s5nLYmsBH/6P/5FH/vhh9usM73jd7bz5th/C6Vta2J//y8/x1FNP0tc/wOTEOMeO3crdd7152e2W+x7g5157+9obHEWIX0NqVcSvIrUq1GoLr+c/T124iL6uFzbHDXtVtn19hNJXTiChQTLLO9zZShVcB+/IgWWdyaJiCXf/8IqOe83iAx/4AA888AAiwmtvPMqFZ59nUDz6lcttXhdTxgcRdvX3Ux2/wEBPLzPTk/T29FKYnEBh4/UD4hDF3p4eZqam4vcCN15/hBdeeA6MXdhWAy6W3o48ldkZXCzd2RxBeY49g4OM7N0LYYAEARIGEPiLbvjapZs8bDwS0P+Zd+H90R9sxiVbtftLhFDHzMyih/txFz31LzuWH2AR3APDiAhnz54FYN++ld27rwWO4xDVc7fulRQI3Khz7FEprtceP+hPcktUoc+2X35XK4L1UpBKY70UNpXGemlsKlX/N34dVkOcn7mHzL9+/2actr0Cc9oZyecIX74Y+8YsEdEmnouZLWHmyuh8jne+851Aa9YR7rvvPh544AFcURw9cgPPv/gid/3ET/DOG1+D+uRv0hfG2b4NUELhi2AApR2CKCKVyZLv6mKmWKRUqaAdBz8MMYBF6v/GN22+s4vCzAwREIhCpzPMVasEInzP1igZwx0Hb+KGozeD62IdDxwH67qXbuz5Gz6VBtdraHzinz5H/jXHr91FvIJECHVEKVQuQ3DyZbybRhD3asc1lUkTXZxquXl0//33c//992PmyoTnx4nKZfzvncT57/+VXGmGqKefubf9POH+w3z+r/7yMlvfJRbITP1Y8/NDi3/RleODXcu04zd+O05E8vPv+WW2ekKcRAiLEM/D1Obwz47G8ctXPLnE0dhqjagw26IWXo6ZmUVcjZmehQsXyX7zcQBm/8l9RHsPAXD3XW9edrC7HOvZZ6uzdZ1DrhEq34GdLBBNFJb8PnbIm4mjXVqIjSKiUgUrCjtbQn/pC0gY4B85uiCChMZJeoQlkM6OeNU5m0blLq8CI0ohWmGDcMPlpzaCrdTA2tibNAxJf/tJAGq3/GD8vV/DlKpYaxAstr50hsTtx9GI1nHmui2SlvFaklyBJRCtIZ2Kxws3HrpqvKAyKX7pnT+HXsYztRmYmVmU5xKOTWBHL+CeO4nVmuCGm4kKM4jr4N1wMB7PCGDqznlBgKnU4tXlcgVTrmCDgMUTK+J5sdvJKjlO3/8jb7y2P7KJJEJYBkmnMLNz+GfO443svWq88Na778aGsa9/syPZbBgSlaqoXJpouoj+f19GrMUfuRGDoNIpMj9yrOGVcBuFcWBPpRYPwCdnMNMz2OLswgyPyqSRdOoyV+s33XTLtfh5LWHbC8Eau/ok8jKofAfR1AxhRxb3Cl+j506fwpQr3NTV0XTvVFOqxOZOuYr1Q7yvx4Nk/xXHiEoVsj/8mjW5g4h2kJwDuSy6vwf34B4gdr82c2WimSLRhan62MgAgsqkeWF2GkS4fmC5eaWtw7YXgv+N70OtRvpHjq0rRbnqzMeBPJk0uvOSM9r7f+3fgLX89e/9Eaoju2pmuM0kmioiKY9wYho7OYV36jmsCLWRm3C6cqhFJpu1dt0u2ZLy0CkvNgFH9mOjEFMsEU3PEJ0f51d+7/ewwJ+965fisVQSmNPGKIV/6jwoTfr4K1ddYb4S0SoOeHnxLHLTyOXZIUTiyjujE/GKcxPqg5lqDWo1yOeILk6jTvwdEoUE+w8TaSdOZGYtpjjLvN1vLAj1uMy6t6k4DrgOuG7DQhHtoHu64oD+kf2oh7sgCFF9XUSjE1gMOpNBMiuXmW1Htr8QBHRfD2auRPnLJ0jfdhTdtzZTRlwXgojg5Bm8I4cuqxEgnoudq8Su2k0wkUypAlpjy1UIAtzHvwKA/4rXAIIz0IcpzOLsGcQZ7ENcXU/2ZeKSumGE9QPMXCWOXS7OYbEggmTSqxYBWYwogZRL5vhRrO8TTUwTnDxHODmFaB0H8LdxCpfFbH8h1FGdndhqjfJXnyJ19DrckX1reoJLNh0Pnk+dxxvZc/l3uXScTS9z9XTrZmNm5pC0R3hxClMu4z77NADVkZvQAz0QGVRPJ87ugYUnvYjE06WOhhSQy6B74jBOG0XYqk80V8ZMTBMVioijkExjmSzmEc/D2T2Es3sIU5zFPzNGeOo8WIPqyCJee9ef2DFCgHgmyPE0tadfJJqeIX3zjWsaVKp8B1FhhvDly/cREcikCUfHcQ/svmbRbLbmQxBAyo1dPZ48gfKrhLv2EGXzpPcPY3yf1Aqx2FciWiO5TCzgoT5MuUo0VSQan8JEESq79kzXqjNP+mgee+QA4egE/vOnMRNTqFx2zaZps9hRQgBAOTj9PUQXpyl/5UnSx1+xpiRWqquTaHScD77rn8fZ2+qI62DDiGB0HHfv0DUZL5hqDZTClioQhDhf+zIA/k23ACou1+TF1S6j2fKlpQE7PwsqoBXiOctO+apsGpVN4wz3xekeR8cxhSKSSSGpy5/qH3znz6/YXvE83AO7cfYOEV2cxP/+S4QTU+hc+40jGhKCiLzeWvs3Gz1ZvRDISaC7XjmnNYigu7uwlQrlrz5J+uj1OIeuXitYeleBrjw/Jodwr7u8mInKpOLY34npNRUdaRRTLCGeQzg6iQ0N7tPfAKB26Cb0YA+YCNXXj6n5sRgzqVgBURSXYQojTLkSjw/KNURAUu6S4wLRGqe3C93TiSmWCc9fwEzPILnMQg/x+uOvbajdojXO8CB6qD8eRzxzsu0E0ehj640i8noRWXd+kHqBkCfqWbHfsN7jbCaSyeB0d1L99nNUn3gaW2vMh1K04umLL/PUF/+WqFi6/Ltc7IsUFec2ta02ijCVKmhNNFVAnnkaPTdD1NVD0D2As3dX/OR3XZzhAVQus5DhWxwHlU6hOrI4g314I3vxDu1BD/VhUUTFUrzCvIT/VJy0K4d34yHcI4fAWKKZGWwU8q3nn+Vbzz/b8G8QpXAG+0j/2HEyP/Rq0IpoYqrh634taahHsNb+GoCIfFpEvk5c8ebUGs91G4uKhIjIMWvtU4s3aEnpKO3g9PcSXZii/NgJ0sePNuQ68aEHPoE1lr88fD1y46GFQbKIoDoyhKMT8Q24SS7bthZHdplSBcII9diXAAhufHUci5zLQEcWlUmtWiEI4tku7bnornw9b9FsXCzEdZYNWdVdOVR+JE4Rf3aMD37iY6A1f/U/G6+PMH8sZ7AP3d9DNDZO7bsnMZNT6M78qm4d14qGegQReUJEHgY+Yq39b8C0iNyywXMXrvzAWvuQtfa4tfb4wMDABg+/BkTQPV2Idih/9Un875/ERuHquylBMin8Z0/FN+j851qj0h7BubFNK107n9LRFObAWrxvxtF7tcOvQPf1xOGUnR2o3rXnNJKUhzPUh3doz6XiisHSv1/qqd1Tr7wunhQII6zvr+s3iVI4u4fIvv42Uq86gimX47rXLYisa9Q0+oi19l5r7Tfq7+8krnqzFk4A3fNvVquh1gokk8bp78Z/7hTVr36jvii1yj6eh6TcWAxz5Uufuw4q5RKeu7CQPmUjmGIJ3LhAujl1CmdiFJPKUBvah97dD65GZdNxAq51IikPZ/cg7v5hbBhh5spLmktAPOBOp5BsGkKDmZ1bdttVz6sd3EN7yd7xg7iH9xFOzWBmy031dG9ICNbaP7ny/ToGzw8Bx0XkGPDFNe7bPESj+3oxvk/psRP4L5yOF6JW2iWVQtJePCuyKGhHXBdxNMEGxWCDMPYQ9QMIQvTfPgpAcOQoKI3OZVCd+fp8/cZNC5XL4B7cjerOY2dLy/YOEAcreTeNoHo6sdPFVa/VSkgqReoVh8m9/rW4u/uRVPPMpKY5h1hrC3XT56mWzhg1iMrlcLo7qX3vRapfewozs3LvIJ6HdGQJnj9NcHFq0ecuohX+mVFMZX1mkqlUEYRotoSNDO5TXwegdv1RdGdHLLhsemGRbDMQrXEG+3D2Dcelp1Zou7ga98Bu9MHdmOJc3a17/aiOHOkfeDXeSPMSI+y8dYS1oB2cvl5MqUTpyydIHTmIe90+xHH5D+/5wFWbi+OgOvNEp85D1cfZOxjP3HguSiA4M4qzewDdQN2yxZiZOcRzMBMF7PgE7rmXsEpT2z2CMzyA1TruDa5BLLXKZXAP7CYYHcfMVZBcemEM8p8/9O8WthMR3MFeVCZF8MIZCE08fbtFSITQACqXQ2XS+M+fJjg7RupV1/PaV968dJpIraCnk+jiJKZSxTu4Z2GuXikVjxkGetB93Q0Naq0fYMpVxNGYag157EuINfgjN2FTKVQui+rMxce7Rk5/4jq4e4eIxqeJpotxVR6leN1tt121rc7nkBtHCF44iy2V11xzoVVsXb/ZZqM0urcH0ZrK33+Hrzz8f3n8H/5uyU1FJE7JXq1R+96LC+MG0RrVmSOaLBCcbWxGKSoUES3xukRkcJ/4ewCCG18Fbjw9qzuy6AamTDeCKIUz1Iezqy8eyEYRj584weMnTly1rcqk8G44GFc0nW2PRAerkfQIa0RSKZyBFB/+5INgLZ/9UC+p6w8s6UMjuSwEAeFzpzGDPbi7hxDPQedzca2xU+fRnXEdA8mkruoh5v1+pCODeeEcpjCNdypewKruO4IeHgAB1dvVtPhp3d0JWhOeH+fff+Q3QQlfeuRPr/7tnoN33X78l85jZoqors0bv1wLEiGsl7ordnhmjPD0KO7hfXEKmCsEIa4LPZ2Y6SK16VmcvUPovq6FoBdTrRGdGY2nW7vzcYCPSPz5RCF225itxIVAHv0SKvAJ9h4iynXh9XTGq8VNjp3W+RyyT8eJAczyRoU4Gu/wHvxTQjQ1g+rKb6huw7UkEcIG0T1dYEKCk2cJXjyLc2AY7+BuVGd+YRsRQfId2CgkOH2eaGwSvXcQ3dUR3/jpVOxCMT1DFMUBNCKCysRZ4cKXRzEGvK9/DQD/5uNx/EA2jerOt2RQqrJpVMqL08wH4bIet6IU3sHdBEpiYbepGBIhbAbKQfd0g40Iz44RnnoZ3deDe3gfeqB7IURUtIPu7oxrEL94ltB10AM98RRoJoXKXu6AZo0hPHcRW6kRvXCS1Evfj0Myr78Z1duFQnD6elpXBlepWAxVH1XPGr4UohTu/mEAoolCbF61GYkQNhPR6O6uOFSyXKHy9W8jjsbZN4y7ezB+emuNeB7a8+IkXRcmiUbjwoTSkY39hDwXwhAzORMPqLXC+Ys/RaIQ//pXErhpUkO9SCaF7mtxYXSlcHYPEJ67gMpnl525mheDNRYzNbNp9d02i0QI6+Q33/yO5b8UQWWzqGwWTEh4dozg1Lk4fHF4AHfXQFx4JJNG8h1AvZyT72MmprGRASVIykNlUlQf+zrdT8ZmUfV1dyDWxu7Q3fk4xUqL+MQnPgHEYwaG+wnHJmMxrFDD2ju4G9/aOC/TIvOx1SRCWCev2t2gd6xy0PMzJibEjE1SPTeGtcSu0T2d8Vghl0XSXhxUrwTrh0STM/jPvEjH5/4YXSoS7BvBP3A9ShTKddHdnS2t7XbLLbcsvNbdndhaEC/+dSwfY7AghpMvY2bnUPUHQavZ3kJ473vo/tM/w4oCLxWnLdcOOG6cttxxQccpzHHq7x0Hu/DaxTrOwnfWrW/vuHzjwjki7XDrwSOXjlXfF8eFpUwE5Vz+FAwDTGGW6MJU3ePyUlFkCaroi2N0fvtxUt/7BtZxKf2jnyUqV/FuGsFi0S2u6/boo7HP05133gmAHujBVH1MtbZiehvRGu/QHvwXzrTNotv2FsLFcfTM9DU59O2rfG+QeJVZ6rVp5s0FEezCZyz5vfg1JLjkpGddj9l730M0vA+mptHdnah8Dp3N8IEPfIAHH3yQ++67j/vvv3/Tf+dKfPjDHwYuCUGUwh3uxz99HhtGyw6eoT61OrIP/9lT2Eqt5e4Y27tizsVxCv/rDzFjE+hsGsIQCeuljcIACcP6vwEEARKFl16HwcL289teeh3w7ZeeJwX0i8bF4liLg8UFHOyGl+x9hEnRnOno5Qu+Ykq7/PDNr6GScvg/j3+ZC9bH3HCA73w3zmKhtSYMV4+h2Exuv/124OpiKaZUITw7iuRzq06VmqqP/8zJ2A1lUZKA8MIk6R89hrerfzOausMr5vT0YLr7iGoCawjQb4Sf/eivMTs7y0+nBy9LET8wOMT4hTEULNQpEy6lnZRFfyz6fvF3PkJtvpcIAQWeFb73ne/weFDERZiMfM7VRQBxFZ12QeUyqP4eoskZdH5ls0elPbwjB/CfORknFmhRZu7E12id9Pb2cuDAAY7fWi9vJMLx47fxvvvex/HbfgC0pndoF6EIgSh8pekaGqYmiqrSdAwNUxFFRWm6BneDSqFVGk9n2Nc/TI/26FEe1/UN0aNcPNGcCIr4JqJiQmYk4ujRo2itef/73990s2g1dG8Xkkk15E+lchmcw/swxdjNvBVs7x6hCSxVXWalijM28OO6BSaq9woSB9p35lAdWXQuA44TzwZJ3CuIo7j/d/83hc99jp9+0938wq9/kMxrbmrLus/ziFK4u/rxT73cUMZwp6cTu3+Y8PQoahPjKholEcI6+fg/fldjG1oTZ4io+VhA57K4h/ag+7pQHbnY2a4+w2SthTCEMIoTlioB10Ech1/85V/mX/7iv8LMzqG78/GiW4t58MEHV/xePBdnVz/h+fHLEigvhzPYi63WMBPTTa9IlAhhnayWCt1WKnH6FUAP9uHu24Xu6bzKKc8GAaZUhrpJINl0nI1Ca0wUwVyFaLYUiyUy6F196O6utvDXueGGG1bdRnd2YOfKmHJ11YweIoK7bxd+1cdMFjaplY2RCGGdfOGZbwJXFMswIVGxhI0inJ5OUjeO4Az0XGXCWGOw5TI2iOJ088MDcTyBo+OewMSicFwHSXmYUoWoVEGlUlhlUSssWDWTz372swC89a1vXXE7PdBL1KCJJErhHdpLNDUT945NIhHCOnnga38NxEKwgY+ZLYEonIO78fbvWtJ9wFqLLZUgsuiBXnR/N5L2FuqhIQrVk0OlYj8kM1siKpZQuTRuLoO1Ni4S0iYhkB//+MeB1YUgroMz1NewiSSeQ+qGg00twJIIYSMYSzgxhcqkSL3qSJyBepmEuXF1mxp6sA891BeXqq342KqP6u2KffyvsPt1V55otkT48kVUJoUNAnRXvnXephtA5XOojjlMpbZkArGr8NymTgYkQlgHNqj74CtF+tZX4Az3Lzv/baMIU5xDdXbgXbcPlU1jylVMGOEM9qDyuRVvbJ3PIfuHCUfH6+GiLfY2XSfz2e38U+exxjSlqMpaSISwFqKQaGYWcepZI1Iu7t7lB822XMUGAe7BPej+7vgQs6U4M8RQf8MOcyqbxh3ZC9a23Q20FsRzcQZ7CC9MrjmTx7Vm617VZmJNXDdsZg7vxhGyd/4QKu0tO3Nj50s3uRrvldfhDPSAMZhiCd3XjbNnaM1eoyKypUUwj+rKI6kU1t9Y7qPNJukRVsHMlTCVKu6hvaSOXArS/51f/09Lbm8jg52ZRQ31xqnZtY4TZFX9OF65zZ6EG+FTn/rUmvcREZyhPoLT51Gu0xbTwJAIYVls4GMKc+j+LtKvfdVV4YV7B682iWwQYObKOAd3x9OmIvFnfoi7b9emZcZuF/btW18mOpVJxXUXZkvXJCnZekiEcCXWEBVmES2kjr8ynglawiR55G/i6dN7Xh9Xn7fVGtYP8G44tDBFaIMQUwtw9w03NlOyxXj44YcBuPfee9e8r+7rIirOtaRg+1IkQliEKceVJt3r9pG6/uCK03e/82dx+tZ7Xv/GuMKlNXg3Hlp46tswwlR93P3bUwQAn/zkJ4H1CUEcB2ewl3Bsoi3MxaYJQUS6gZH5P2vtx5p17lUxYZzKMJ8j++PH11RTzZYqoBTuDYdQdeFYYzDlamwObVMRbAaqswOZKmKDYE1lba9JW5p4rrcDhXom7Hvrwmg5plQinC7i3XSYzI/dujYRRCYWwZH9l0RgLWaugjPcf81LzW515gfOprLx+hEbpWk9grX2oSveF67cpqmlo0xEND2D7syTee3NqK61ZVSwUewduVgEALZURfd0otd4vJ2KyqbRnbnGV5yvEZsuBBG5Z4mPH52/8etFBd+21L51sTwEcajmZrdt4TyVCmaujHfTSJzmfY1RUXHxO4vKpC8TgSlXIe3Fhb8TGkb39xC99DLW2pZNp266EFYqAiIid7KooGDTsYZwuojKpMj82PF1uSvYIMBWfR7+7d9F5dKLPg9BBHd46Vmm7cgjj2zOf6V4LrqvCzNdRFpkTjZzsHwn8FFgirj+2q3NOjcQZ6WeLuIe2kPqlYfXNTizUYgtlXFvOMTgopkOawymUsPdP9zSPEPNpr9/UwLrAdDdeUxhtmXTqc0cIzxKs2/+OmauDNaSvu0o7t6hdR3D1l0k3MP70Pkcf/DwHwPwz+59B6ZUQQ/2brsFs9X4/d//fQDe/e53b/hY4sR5YMOxyVUD/q8F2/7xZas+qiNL9sdvXXdWNWstpjCLc2AYp25O/eFn4sWkd/7kT6FymU2tX7ZV2EwhQOyqLVMzKxYvvFZseyF4rz6CEjaUWtAWZ9FDvTiDV1TUtRYbRbiDu9rGZ2YrI0rhDPQSvHyxgUxEm8u2H9WpTApx1r9YY0olpCOHu+/qm91GBmeor2nVanYCksvEEXhNDt7f9kLYCLZaQ5SDd2jv1TNBxiBaoTrbI4ntdkFEcAZ6m/5wSYSwDDYMsTUf9/BexLvcgrTWYo2N6xsnJtGmozIpnKH+pl7bbT9GWA8LM0RHDiw5E2TLVf7ikT9panB5O/L5z3/+mh272TNwiRCWwMzO4uwdxOleIhNFEIJS5PcO75iFs+XIZlufzn2z2Nn/k0tg5uZQXV04wwNLf1+p4Qz18cnf+i0eeOCBJreuvXjggQe2zTVIhLAIW6sPjg/sXtI+NZUaKh/XOfv0pz/Npz/96Ra0sn3YTtcgEUIdG0XYSi0eHLtXL/Fba7FhhDPQu8TeCVudRAh1zMwcev/wsjEEtlyNC4UnawbbkkQIEBe16++O064sgY0iENmRbhQ7hR0vBOv7iHbw9g0vO29tylX0YG9bBJknXBvatoaaiMwCz270OClUyhHRxtolS7FoEadio3IEjZRq6QcmNtqmTaTd2gPt2aa0tfboShu08zrCs9ba461uxGJE5Il2alO7tQfat02rbbPjTaOEBEiEkJAAtLcQHlp9k6bTbm1qt/bAFm1T2w6WExKaSTv3CAkJTaOdZ40WEJFjxGkiW5cK5lJbummD1JX1ZGgnge7kuqxMI/fPVukR7iROAdMOtDx1ZT1J2hP1zCBvaPb5l6Hl12UFVr1/2l4I9XxIj7a6HfNYax+y1p5c9L7QgmbcBiyct/7Eayltcl2uotH7py1Mo+XSRBJ3Z4+24j96I6krW0Ch1Q2Yp52ui4gca/T+aQshLGe7ici8oo8Bh0Wku1lPmrZOXQkngO75N4ufxK2kDa7LVTR6/7SFEJbDWvsUgIi0ix3c+tSVMQ8Bb687CX6xBee/ija5LpexlvsnWUdISGALDJYTEppBIoSEBBIhJCQAiRASEoBECFsOETkmIk/W5+sRkY/W/7pb3LQtTTJrtAURkRHgQWvtG0TknnbwNdrqJD3CFqS+gPaUiDxIG7mfbGWSHmGLUjeFnrTWHm51W7YDiRC2KIt8oW6z1n6opY3ZBiSm0RakPkY4WR8bHKu/T9gAiRC2GHWfns8QB+VA7Gv0mUQMGyMxjRISSHqEhAQgEUJCApAIISEBSISQkAAkQkhIABIhJCQAiRASEoBECAkJAPx/9fEyICaNjh0AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "train_loss = errors.loss(mean_vi_train, sigma_vi_train, y_train)\n", + "\n", + "ax =plot.plot_prediction_regression_without_test(x_train,y_train,x_linspace_test,mean_vi,sigma_vi,y_min=-3,y_max=3,\n", + "title=f\"Train {errors.loss(mean_vi_train, sigma_vi_train,y_train):.2f}\")\n", + "ax.set_xlim(-4,4)\n", + "savefig(\"MLP_VI.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " /home/shobro/anaconda3/lib/python3.7/site-packages/probml_utils/plotting.py:70: UserWarning:renaming figures/step/Calibration_VI.pdf to figures/step/Calibration_VI_latexified.pdf because LATEXIFY is True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "saving image to figures/step/Calibration_VI_latexified.pdf\n", + "Figure size: [2.5 2. ]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMIAAACeCAYAAABgrdW9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAetUlEQVR4nO2deXgUVbq435ONJCiEhCXsEAybcsUQEBEUQkAgM3plkQFmEBAQGe9cda6o84jCvfJjUBAXHAkyMGTAhcVBJTBIQJY48kAgDJthC7sQhJCQ0AmdTp/fH13dNiHpru70luS8z1NPV1fVqfOlU1+d5VuOkFKiUNR1gvwtgEIRCChFUChQiqBQAEoRFApAKYJPEELMEEJM1bZ9QohkIUSqECLO37IpLIT4W4A6wn4pZYYQIgoYpe1nAU4VQSvzlJRyiY7r4qyblPLtSq6ZCuQCUVLKtVUdq4uoFsE3ZFU8IKUswPIAOkRKWeBMCTSeAgq0h3m0phg2hBAzgCwpZQYwqKpjdRWlCD5Ae+grI1EIsUbrOsUBCCFGat+jtO/JQohUu/0tQogEIcS8CnUskVLm2n2vWGdPwHZMCJFQxbE6iVIEP6K9iZFSvi2lzNWUoQBYC7xmd010hev3A9cqe3C1t/woHdUX6DxWJ1CK4H/y7fbjsDzE1r5+ZRRUdSMhRDIWJaqMvUCU9YvWelR2rE6iFMFHWAe9QJz2wFof3Di7/nwcsEXbz9e6QAnWa+z247B0axLt7p8MzANSgTV2x6dqu0uwdMUS7Oqo7FidRChfI4VCtQgKBaAUQaEAlCIoFIBSBIUC8IIiaEafeQ7OT9WuGenpuhUKd/G4ImhGn0rnwJVJXxGo+Lpr5NCkr7UWWUKIrHvvvVcCalObri3fkC9j/1+sFLOFFLOFxEX8PUYosP+i+cskSikTIyIi/CSSoqax4dgGWrzbgryyPLgEGF2/h68VQZn0FR7DbDbz9D+e5tef/5pbplsE7Qhi08hNkMU4V+/llcEyFjeABLtjjsz8CoXLnCs8R7v325F2MI2w8jBYDJ9N+4whyUOQm+Wnrt7P44E52kC4R4VjS7TPAizKALDf03Ur6gZ/3f9XpqVPw2Q2ESfjyJ2Ty5z/m8NTTz3l9j39PUZQKHRjMpsYunIok7+ZjJSSCY0nkDs7lwnjJ/Daa69V694qVFNRIziYd5ABKwaQX5JPy7tb8u597/K7X/+OAQMGkJqaihCiWvdXLYIi4Jm7ay7dF3cnvySfcd3GsTVlK8+NeY727duzbt06wsLCql2Hx1sER8Hgmt/9VCzjgzodLK5wTrGxmKQVSez9aS/1guvx2YjPeKTpIzz00EMIIUhPT6dRo0YeqcujLYIOy/FraBkdgNGerFtRu/jnyX/S9J2m7P1pL10ad+HCSxdI6ZDC8OHDOXv2LOvXr6dDhw4eq8/TXSNXgsGjVF4fRWU8t+E5hq4aSomphAZhDciamkVMRAxTpkxh586dLF++nL59+3q0Tm8PlgsqfJ8LTBVC5KMFpNujdaumArRp08bLoikCjSvFV+i7vC8n8k/YjpXLcnKu5rDxrxtJS0tj9uzZjB071uN1e7pFcGg5ttoRtCwMuZWct7lYNGnSxMOiKQKZ1UdW0/q91pzIP8FDrR6iXYN21A+tT7P6zTiYcZCZM2fy29/+lpkzZ3qlfk+3CEuAp7SpLJvlWAgxVUq5ROsKJWstwlwP162ogZjNZkavG83ao2sJEkHMHzSfP/b5I4YyAzlXcyg4UcCwwcPo168fS5curfY0aVUEbPB+YmKizMq6I0GcohZxKv8UfZf35XLxZZpENmHHhB10adLFdj43N5cHH3yQqKgodu/eTUxMjN5bu6wtyo6g8AuL9iyi06JOXC6+zBOdnuCnP/50mxJcvHiRgQMHUl5eTnp6uitK4BbKsqzwKUaTkcdWPsb2s9sJDQrlk5RP6B3Zm40bNnL8+HFOnDhBTk4OmZmZmM1mmjdvTqtWrbwul1IEhVcxGAwcPnyYyMhItv64lVePvkoppUSURND468ZMnjUZs9lsuz4mJoaWLVsSFBSE2Wzmxo0b5OTkkJDg3bSsShEUHsFsNnPx4kVOnDhx25t9y5YtlJWVQRLQz3Jt6MFQup7rSvz98XQc1ZH4+Hg6drR8NmrUCIPBQLdu3cjLy6NZs2Z07tzZ6/L71MVCO5+MJd9nnHKx8C0Gg4GcnBw6d+5MZGSky+U6depEcXGx7UG3f+hPnjxJSUmJrUx4eDitWrXCFGqCKUBTCCWUtGFpjH5ztMPZn8jISA4dOuSWrO7iUUXQXCwypJT7tVTmFX2NkrEoSIY1/6fCO5jNZi5cuGB7WI8ePcqyZcsoLS0lLCyMHj16EBTkfK7EbDazb98+jEYjUkrsZxlDQ0OJi4sjPj6eQYMGER8fb3u7t2zZkjUH1/Cb9b+xzOGUQ87zOcQ11edMEBkZ6fXukD2ebhF6YvfwCyESNOMZYAna0ZZOGk0ldgRlWdZHVW9o+8+TJ09SWlpqK1OvXj3Kysowm80YjUZu3brF3Xff7bSuoqIijEYjZrOZkJAQXnjhBQYOHEh8fDxt27YlJOTOR8hsNjP5m8ksP7DcNpEZGR5JQXmBp34Cj+NTFwvNoPYKlnQv86jgmKdFsi0Bix3By7LVGK5fv257yI8cOcKHH35ISUlJpW/oDh06EB8fz2OPPXZb3zsqKor777/f1u/esWOHri5Hxf767NmzHZa7cOMCfZf15WzhWRrWa0j90PoU3iqkWf1mdG7s/b6+u+hSBCFEAynlDSFEOyBfSnmjikudBeePtK7tJYRACBFXVwP4K/bXi4uLOXnyZKVv92vXrtnKBQUF2RQgJCSEF198kaSkJDp27EibNm0qfUNbcaff7Up/fcWBFUz+ZjIms4mB7QeycexGTNJEztUcOjfuTGSo9/v67qLLsiyEmAKcAgYDe6SUX1ZxXRSWNQCysBsMV3CxSMAymHY4WK7NluUzZ87Qq1cvrl+/TkhICFFRUVy+fPm2a1q2bGl7m1s/4+PjiY2NJTEx0faGPnTokE8Gk44wmU385+f/SfqJdIJFMB8M/YDpPaf7UySXLct6FeEBLA/4n4EeUsptrsvmGrVJEcxmM9nZ2aSnp5Oens7evXttXZqgoCBSUlLo3bu37YG/5557qF+/fpX3c3f2xxscuXKE/iv6c9VwleZ3NWfXxF10iPZcnICbuO6QZG1mHW3AQOBloD3wsp4y1d169OghazI3btyQ69atk5MmTZKxsbESkEII2atXL/n666/Lli1bysjISBkXFydv3rzpb3HdYl7mPBk0O0gyCzlm7RhZXl7ub5GsuPy8Kae7amL/dr5w4YLtrb9z507Kyspo2LAhjz32GCkpKQwZMoSmTZveUc7fb3VXMRgNDEwbyO6LuwkLDuPT4Z8yousIf4tlj8stgt7BcpKUcps2WE6QVYwRdNwnCmxxCTUeg8FAx44dycvLA8BkMgHQtWtXXnjhBVJSUujTpw+hoaF3lPX1PLmn2HlmJymfpVBsLKZzTGd2TdpF48jG/har2jhUBCHECCxTnHFCiFFYNE0CVSqCE8tyMvCsZlWMBqZIOztDTSMtLY2LFy8CEBwczCuvvMK0adNo166dfwXzMNbYgGXZy/ho70cAvPDgCywcstDPknkOp10jIUR7LPP+1mVQe0gpl1Zx7W2WZSnlsxXO26ZLhRDJUls3uDICvWu0a9cuhgwZQllZGaGhocTGxgbEDI6nMZQZ6LKoC+dvnEciqR9an41jN/JIu0f8LZojPB+PIKU8jaVVGIjlje5oXQOHwft2SjCyMiWwTwv/888/6/oD/MEPP/zAsGHDaNOmDceOHWPXrl21UgkAPtr7EedunEMiCSKIzb/bHOhK4BZ6LcufSykPAGjjBL0UVHH8NlcMK7IGWJb37t3LkCFDiI2NZevWrbRo0YL27dv7WyyPYzabGfflOD4/8jkAoUGhtG7QmgdiH/CzZN5BryK8LYSQQCGWKdSeVVznNO17TXa2y87OZvDgwcTExLBt2zZatGjhb5G8wunrp+m7vC8/Ff1ETEQMG8dtJCQoJOCtw9VBryLMk1JuBZtxrSocBu9rX6OwWKlrFAcPHiQ5OZkGDRqwbds2Wrdu7W+RvMLHez/mvzb9F+WynJT4FNb/Zj0hQbU/bMVlO4LV78hL8tgIpMHy0aNH6d+/P2FhYezYscOjGdYCBaPJyLBPh7H19FZCgkJY8qslTHxgor/FchfP2hGEEJOllEuFEIuBRloFjrpGtY5jx46RlJREcHAw27Ztq5VKsP/SfgamDaSgtIC2DduSOSmTVg28HyccSDhr89ZYP3V2jWoVJ0+eJCkpCbPZzI4dO+jYsaO/RfI4s76bxf/u/F8kkondJ7L010t1BezUNhwqgpSyUPvcKoToru1n+0Auv3PmzBmSkpK4desW3333HV26dHFeqAZxo/QG/Vf0J/tyNuEh4ax7ah3D4of5Wyy/odfFYgrQQdtPrMqgpp13FrM8Esu0aoLUYhMCjfPnzzNgwACKi4vZtm0b3bp187dIHmXLqS088fkTlJhK6Na0Gzsn7iQqPMrfYvkVvW1grpTyVSnlq8Dpqi5ylhZemzqN1s4HZOD+xYsXGTBgAPn5+Xz77bd0797d3yJ5DEOZgRFfjGDwysGUmkp5re9rHHzuYJ1XAtA/fRonhNir7TuyHjmMWcaiHKesK2/yy8KCAcHly5dJSkoiLy+PLVu2kJiY6G+RPMap/FN0XtQZkzQhEHz39Hc82u5Rf4sVMOhtEVYDbwOfAFX6B1VCQYXvUVhalwxgkNUb1Yo/XSzOnDnDww8/zIULF9i0aRO9e/f2af3eZNWhVXT+yKIEAOHB4dxdz3ngfl1CV4ugDZqn6bjUmWV5n5N6/OJice3aNTp16oTRaKR58+Y10j26MsxmMyNWj2D9sfUEiSCiI6K5ZboV8IH0/kDvYLk9lqwTEnhFSnmmiksdWpalJW55hvV8oMQlzJ8/H6PRCOCzFIPe5seff+TRvz3Kz4afib0rlp0TdtKyQcsaEUjvF/SEsfFLmGZ74H/cCYVzdfNVqKbJZJJxcXGyXr16NT500sqCfy2whVCOXD0ykEIofYXLz5vewfI+aXHHRgiRoX36xNXC23z11Vfk5uaycuVKunTpUiNDJ62UmkpJTkvm+/PfExYURtrwNEbfp9Zs1IPeLBbfYnGUs7pYnAbaSykf85ZgvvA1klLSp08frly5wvHjxwkODvZqfd7k+3PfM3TVUIqMRcRHx5M5MZOmdzX1t1j+wjsxy9h5n9pqEmKgq5UFGt9//z27d+9m0aJFNVoJZmyZwfx/zUcieb7n83w47EN/i1Tj0DtrtFXPMT1YwzUDIcvd/PnziYmJYeLEmullmW/Ip9/yfhy9epT6ofX5eszXJLVP8rdYNZIq7QhCiAbu3FCzBSRrrhQVz0UBqVqmbL9y7Ngxvv76a37/+9/XuDGBoczAu/96lxbvtuDo1aP0bNGTy/9zWSlBdahqFA3MtdvvbrffzkGZGVh8iABSKzkfZT3vbPP2rNGUKVNkeHi4zMvL82o9nqaotEjeNecuySwks5Bvfvemv0UKRDw6a5SlxSGAxcXiOs7jEZy5WAAkCiGiseQ+vc3Fwldp4fPy8khLS2PChAm2hFs1gbMFZ+m1tBfFZcWAxUL8eKfH/SxV7aBKRZBSrgPWgSUGQWru1y7GIxRUuGcBmuVYCLGFCr5G0keW5UWLFmE0GnnppZe8VYXHWbJvCdPTp1Muy4kIiQAJze9urizEHkLvYDlbZzyCQxcL7Y2/WvrRonzz5k3+8pe/8MQTT9SIQBuT2UTKqhS+zf2WkKAQlv5qKWO6jVEWYg/j6XgEZ8H7q7F0s6L5JfrNpyxfvpz8/Hxefvllf1TvEgcuHSApLYnrpddp3aA1mZMyadPQ0mVMaF6zXUACDj0DCWBgZfve3LwxWC4rK5Pt27eXffr08fi9Pc3s7bOlmCUks5BP/+PpuugmUR285mKhNx4hoPnyyy85ffo0CxYs8LcoVVJsLKb/3/qz79I+woPD+WLUF2pA7AP0ulg0xOJ92gjH3qcew9MuFlJKevXqRWFhIT/++GNAWpK35m7l8c8fx1Bm4L4m97Fjwg6iI6P9LVZNxDsuFlJ/PELAsnPnTrKysli8eHFAKsH09Ol8nPUxAsGMPjOYN2iev0WqU/h8wXHtmgR8vOD4O++8Q5MmTRg/fryvqtTF5eLL9F3Wl1PXT3F32N1s/u1mHmr9kL/FqnPoTmAjhOhunUJ1cI3D4H07krGsj+ATjh49Snp6Os8//zwRERG+qtYhhjIDc3fNpc3CNpy6fop+bfpx5eUrSgn8hKenT51alrXA/QzAZ5HxCxYsICIigunT/brSo43iW8U0W9AMQ5kBgHnJ85jx8Aw/S1W38Wg6l0oosP+iKUaVwf/eCN6/dOkSK1euZNKkSTRu7P8ljo5dPUbb99valCAiOILkuBqbILzWoFcR4oQQDTSP1DgH1+lNC58M9KiYxUJa4poTpZSJTZo00SmaYz744ANMJhMvvviiR+5XHd7b/R5d/9KV/JJ8IkMjiQiOUG4SAYLewfJqLNOn0cAcB9c5C97fr313NH7wGEVFRSxevJjhw4f7NXlvqamUwX8fzK5zuwgNCmXlkyt5ovMTyk0ikHBkbQMma5+LsSjDamCvO5Y7VzdPWJYXLlwoAbl79+5q38tddp/fLRvMbSCZhezwfgd5qeiS32SpQ7j8vDlThIbyTheLB9ypyNWtuopgNBplmzZtZL9+/ap1n+rw6pZXbW4S076Z5jc56iCedbGQdtmw7Y7ViGzYa9eu5dy5cyxatMjndReUFtBvWT8O/3yYyNBI1o9ez6AOPukNKtxEr4uFRxYcd4XquFhIS4tCSUkJR44c8Wm+/w3HNjBq7ShKTaX0aN6D7RO2c1fYXT6rXwF4YcUc64Lj7fUuOB4IbNu2jezsbD755BOfKYHZbGbS15NY8e8VCASzHp3Fm/3f9Endiuqje8FxqTNrhY71EZLRplgrO2+lOi3CkCFDOHDgAGfOnCE8PNyte+jFUGbguzPfMW3DNC7cuECj8EZkjM9Q8QL+xfNOd9KS4c5mRBNCdJfamst31G5xsciQUu7XMlWsrXA+AU1BhBD7Kp73BAcPHmTz5s3MmTPHJ0rQdmFbrpZcBSCpXRKbxm0iLCTMq/UqPI+ufoNmTPuzEOIL4FkHl/bEzpqsPfg2pMWOkKG1GnMrqafaluUFCxZQv359pk3zrrOsyWxi8N8H25QgLCiMdwa/o5SghqK3Az0VSMViMHPlLV5Q8YC0xCtnAHck5ZTVtCyfOHGCVatWMX78eKKjvefTdzjvMLHzY/n+/PcEi2AigiNo1aCVshDXYPQqwmksrhUjgYYOrnMWvD/DLsNdXEUXi+pgMBjo2bMn5eXlbNiwAYPB4Klb38bcXXO5P/V+rpVcY+x9Yyl4pYDMZzI5NP2QshDXYPQG5qzTBs1rcexr5Cx4fy0QpWXB+0J6MJvFDz/8QGFhIQD5+fkeX+PAYDQwYMUA9vy0h3rB9fhsxGc82eVJQAXS1wb0umG3sxs0Vzl7JO3yFgH77Y4v0T5zK57zFJs2bQIgIiKCZs2a0bmz57op289s51ef/oqbZTfp3LgzuybuonGk/z1ZFZ5Db9fIlsfU3Zyo3iQ/P5/U1FRGjhxJZmYmhw4d8lg+0//e9N8MWDGAm2U3ebH3i/z4+x+VEtRC9Hqf9tJmjK5jyWLhtXUR3OH999+nuLiYN954w2NrIl8pvkK/v/Xj+LXj3BV2F5vGbaJvm74eubci8NCrCKlWg1qgrYtQWFjI+++/z5NPPukxJVh7dC3jvhyHsdzIQ60eIuN3GUSGqYFwbcbl9RGcWZgdWZa1WaI46yalfNtVgSuyaNEiCgsLef3116t7K8xmM2O/HMsXR74gSATxdvLbvPxw4GfEU3gAd1xWq9pwnhZ+KhYFAMtSs1FV3UuPG3ZRUZGMiYmRKSkpTq91xsG8g7Lx240ls5BN3m4iD+cdrvY9FX7D5WfX0x5pzizLS6SdbUFWc/r0448/5tq1a8ycObM6t+G93e/xHx//B1cNV4kIieDkH05yb9N7q3VPRc3C266ZBZUd1HySRlVyXLeLhcFgYP78+QwaNIgHH3zQLeGMJiNJK5J4cfMv8cxBIoiT+Sfdup+i5uLpBF96g/crddOQLqyPsGTJEq5cucIbb7zhnqAX9zLo74MovFVI24ZtkVJyreSaWpW+ruJOf6qqDYsSTAUSgJF2x6dqn8lYxgZbsKzd7FaoZklJiWzevLns37+/Wx3I17e+bguhnPzVZFleXi5vGm/KfT/tkzeNNXuxcYWU0tOhmm4oVQGOLcsZQI/q1rNs2TJbviJXuFF6g0f+9gj/zvs3ESERrHtqHUPjhwIQGRSpXCXqMLpCNf1BVYE5RqORe+65h9atW5OZmYnm1+SUTSc2MWL1CEpMJdzf7H52TthJg/CAM5IrPIPLgTm+C+b1EGlpaZw/f56ZM2fqUgKz2czkrycz7NNhlJpKeb3f6xyYdkApgeI2alSLYDKZ6NSpE9HR0ezZs8epIly4cYG+y/pytvAsDes1JGN8BoktfJZyVeE/vLM+QqDw6aefkpuby8KFC50qQdq/03jm62cwmU0qhFLhFI+3CDqD9wdJKV9xdJ+KLUJ5eTldu3YlIiKC7OzsShXBUGbg8JXDzNo+i00nNxEsgvlg6AdM7xkYWbAVPsO/LYKz4H2wzBwJIRzFPVfKmjVrOH78OGvWrKlSCTp90IkLxRcAiK0fS+akTDpE+y/nqaLm4FMXC2dUZVk2m8289dZbdO3aleHDh1da9k9b/2RTgmARzDdjvlFKoNCNt8cIBa5cLKuwLK9fv54jR46watWqOxJ2GYwGkv+ezA8XfgCgXnA9Wt7dkq5Nu1ZXdkUdwucuFq4ipeStt94iPj6e0aNvT3yReS6ToauGUmwsplNMJ/457p/kl+arVOsKl/F012gJkKh1iW4L3rfbT8aSwUJXtyk9PZ3s7Gz+9Kc/3bYa5kubX6Lf8n4UG4v5Q68/kPN8Du0atSOheYJSAoXLBLQdYe/evfTu3ZsrV65w/PhxQkNDuWq4Sr/l/ci5mkP90PpsHLuRR9o94m9xFYFF7bIjbNmyhT179pCamkpoaCj/+PEfjFk3hlvlt3iw5YNsG79NhVAqPEJAtwjh4eGcPXuW48eP88zGZ/js8GcEiSDmJM3h1b6v+ltEReBSe1qEoqIi9u3bx5sL3yT+43guFl0kJiKG7U9v575m9/lbPEUtwx+WZYfnrTRo0EDSEwyPGiiX5Qy7ZxhfjfmKkKCA1V1F4OBf71PNspylxR3csVaSs/P2FIUWUfRIEUIIlj2+jPRx6UoJFF7D15Zl/ZbnekA5ZE3IYuIDEz0qpEJREX9blm87r3WbLDaHYOAT6P7n7qcpI9+FOhsDV1243hNlfV3OH3XWJFkPSyldGkj62rLs8Ly9i4UQIkteki4HDwghsqR0vVx1yvq6nD/qrGmyulrG15blSs8rFP7G18H7lZ5XKPxNIMcsL3F+iUfL+aNOJWuAlAtYy7JC4UsCuUVQ6EQIEWf/6eW6ojy59l2gEBAWKnet0dWJj3Y3fb27C6rrsahrkwhxlcjjTNZUIUQuME9vOe38SCxT2Aku/I3JwLNauGw0MEValg3WW2cykF/x79TxN07FMq6s+Aw4jIHX68ng0ZSP7mw4TyVf6Xln5eyuW+NGnZWmr9dRzpbqEruUli7IOgMtPaYLskZZz7tYLplfUnHGuVDO/tpkN+q0/j4zXCg3z1pXFf/PO4658rtL6fm08O7grjW6OvHR7qavd3dBdaeyam+2DFdl1UgUQiTbB0DpKDfIrt5kveWsv4sQYqS0uMrollW7/jUhxBpu/1td+V9GudAF1H3fQFCEihS4ed5ZOZfvWVX6ekflpIMF1asqJ4RIqOSh0lVWSlmgKW6Gi7JGAblauUEO+v0Vy1np6UzQimW1B/gVLDakeZUVqKLOuUCC9iBXZyX5ive1EQiK4K41ujrx0e6mr3d3QXW99SUDPSo8lM7qnFrFQ+yszn2VlHFFVnfKjpRSZkiLXWmN3ZvdmcdBAbBEa3FzXfhf639GHPWbfLHhPJV8VecdlrPrk+6jQh9aR52Vpq/XUS7Oeo7b+8BOZbXrC6dit6SWzt8nAbs+vwu/zww3y42sKLsLv89IV/+XWrmp1rIV6rzjf+zs2alsU3YEhYLA6BopFH5HKYJCgVIEhQJQihCQCCG2uOouodkRUr0lU21HKUJgoitWQ/P7mQo2Y1V15tjrNAHha6RwD3l7fIeiGihFCBDsHMsysFht12rdI6uTWoG0rC0xA4vzWT6WFmCUlPLZCve6o5zP/pAaiuoaBQ7zgLVW66ndsVxtG2VVAu3BTnTQHbqtnNclrwUoRQh8sqSU+7W3fk80JdHcFPSWUzhBWZYDBK1r9BSQheWNbvXQfBb4QrusAEuXJxdsKW4+AQZicUOw7kfbl5N28QKKylGKoFCgukYKBaAUQaEAlCIoFIBSBIUCUIqgUABKERQKQCmCQgEoRVAoAKUICgUA/x+UzRPfmBn7OgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(1)\n", + "_, df_train = plot.calibration_regression(\n", + " mean_vi_train, sigma_vi_train, y_train, \"Train\", \"black\", ax\n", + ")\n", + "ax.set_title(f\"Train {errors.ace(df_train):.2f}\")\n", + "k = jnp.arange(0, 1.1, 0.1)\n", + "ax.plot(k,k,label='Ideal',color='Green')\n", + "savefig('Calibration_VI.pdf')" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1093,7 +1203,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.4" + "version": "3.7.13" }, "vscode": { "interpreter": { diff --git a/utilities/vi_helper.py b/utilities/vi_helper.py new file mode 100644 index 0000000..fbd4f5a --- /dev/null +++ b/utilities/vi_helper.py @@ -0,0 +1,128 @@ +import logging ## for ignoring check warnings + +logger = logging.getLogger() + + +class CheckTypesFilter(logging.Filter): + def filter(self, record): + return "check_types" not in record.getMessage() + + +logger.addFilter(CheckTypesFilter()) + +from ajax.advi import ADVI +from ajax.utils import train + +import flax.linen as nn +from flax.core.frozen_dict import freeze, unfreeze + +import jax +import jax.numpy as jnp +import optax +import tensorflow_probability.substrates.jax as tfp + +tfd = tfp.distributions +tfb = tfp.bijectors + +from functools import partial + + +class MLP(nn.Module): + features: list + activations: list + + @nn.compact + def __call__(self, X): + if len(self.activations) != len(self.features) - 1: + raise Exception( + f"Length of activations should be equal to {len(self.layers) - 1}" + ) + + for i, feature in enumerate(self.features): + X = nn.Dense(feature, name=f"{i}_Dense")(X) + if i != len(self.features) - 1: + X = self.activations[i](X) + return X.ravel() + + +def vi_model(mlp_features, x_train, y_train, n_epochs=50000, variable_noise=True): + """ + function to return trained VI model in ajax + mlp_features : features for MLP , [dimensions, activations] + x_train : training data + y_train : training output (n_samples,) + """ + + mlp = MLP(*mlp_features) + seed = jax.random.PRNGKey(89) + frozen_params = mlp.init(seed, x_train) + params = unfreeze(frozen_params) + prior = jax.tree_map( + lambda param: tfd.Independent( + tfd.Normal(loc=jnp.zeros(param.shape), scale=jnp.ones(param.shape)), + reinterpreted_batch_ndims=len(param.shape), + ), + params, + ) + + bijector = jax.tree_map(lambda param: tfb.Identity(), params) + + def get_log_likelihood(latent_sample, data, aux, **kwargs): + frozen_params = freeze(latent_sample) + y_pred = mlp.apply(frozen_params, aux["X"]) + scale = jnp.exp(kwargs["log_noise_scale"]) + if variable_noise == False: + scale = 0.1 + return tfd.Normal(loc=y_pred, scale=scale).log_prob(data).sum() + + model = ADVI(prior, bijector, get_log_likelihood, vi_type="mean_field") + + params = model.init(jax.random.PRNGKey(8)) + mean = params["posterior"].mean() + params["posterior"] = tfd.MultivariateNormalDiag( + loc=mean, + scale_diag=jax.random.normal(jax.random.PRNGKey(3), shape=(len(mean),)) - 10, + ) + params["log_noise_scale"] = 0.001 + + tx = optax.adam(learning_rate=0.001) + seed1 = jax.random.PRNGKey(4) + seed2 = jax.random.PRNGKey(5) + + loss_fn = partial( + model.loss_fn, + aux={"X": x_train}, + batch=y_train, + data_size=len(y_train), + n_samples=1, + ) + results = train( + loss_fn, + params, + n_epochs=n_epochs, + optimizer=tx, + seed=seed2, + return_args={"losses"}, + ) + return mlp, model, results + + +def vi_predict(vi_model, results, mlp_model, data): + """ + Function to predict given vi_model in ajax + vi_model : VI model in ajax + results : result obtained after training ajax VI model + mlp_model : mlp model for which ajax VI model was trained + data : data for which prediction is requires + """ + + posterior = vi_model.apply(results["params"]) + seed = jax.random.PRNGKey(4) + weights = posterior.sample(seed, sample_shape=(1000,)) + + def draw_sample(weights): + y_pred = mlp_model.apply(freeze(weights), data) + return y_pred + + y_samples = jax.vmap(draw_sample)(weights) + return y_samples