{ "cells": [ { "cell_type": "markdown", "id": "3fc274aa-b43d-45e1-87c0-340408efdc9b", "metadata": {}, "source": [ "# Classifying data with Neural Networks\n", "\n", "In this notebook we will see how to classify data with a neural network. We will use the famous IRIS dataset and train our network to predict the species of an iris flower based on four features of that flower.\n", "\n", "## The IRIS dataset\n", "\n", "The IRIS dataset is a well-known and frequently used dataset in the field of machine learning and statistics. It is often used as a benchmark for classification tasks. The dataset is named after the iris flower, as it contains measurements of various attributes of three different species of iris flowers.\n", "\n", "The IRIS dataset consists of ***150 samples***, with each sample representing an individual iris flower. Each flower sample is described by ***four features*** or attributes:\n", "\n", "1. **Sepal length**: It represents the length of the sepal, which is part of the outermost whorl of the flower. It is measured in centimeters.\n", "2. **Sepal width**: It denotes the width of the sepal, measured in centimeters.\n", "3. **Petal length**: It represents the length of the petal, which is part of the whorl just inside the sepals. It is measured in centimeters.\n", "4. **Petal width**: It denotes the width of the petal, measured in centimeters.\n", "\n", "Based on these four features, the task is to classify each iris flower into one of ***three species***:\n", "\n", "1. **Setosa**: Iris setosa is the most distinctive of the three species. Its petals are much smaller than those of the other two species, and its sepals are short but comparatively wide.\n", "2. **Versicolor**: Iris versicolor is the second species. Its petal sizes and sepal lengths lie between those of the other two species.\n", "3. **Virginica**: Iris virginica is the third species in the dataset. It typically has the longest sepals and the largest petals of the three species.\n", "\n", "The IRIS dataset is widely used for tasks such as classification, clustering, and data visualization. Its simplicity, small size, and well-defined class labels make it an ideal starting point for exploring and evaluating various machine learning algorithms and techniques." ] }, { "cell_type": "code", "execution_count": 1, "id": "cf289899-758c-4df3-ab5f-20b4692a1aaa", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "\n", "from sklearn.datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 2, "id": "688bc589-8167-45e5-8d42-2812ff4d01e5", "metadata": {}, "outputs": [], "source": [ "# First we load the iris dataset. It contains measurements of flowers from three iris species,\n", "# namely the sepal length, the sepal width, the petal length and the petal width\n", "iris = load_iris(as_frame=True)\n", "X = iris['data']\n", "y = iris['target']" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3da6a88-95af-43dc-9965-fd8a9692a304", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", "0 5.1 3.5 1.4 0.2\n", "1 4.9 3.0 1.4 0.2\n", "2 4.7 3.2 1.3 0.2\n", "3 4.6 3.1 1.5 0.2\n", "4 5.0 3.6 1.4 0.2" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at the four features of the dataset\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 4, "id": "a4c1e4e0-4113-4b6f-a6ff-422c4c940ba4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 0\n", "1 0\n", "2 0\n", "3 0\n", "4 0\n", "Name: target, dtype: int32" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# And now let's have a look at the labels\n", "y.head()" ] }, { "cell_type": "code", "execution_count": 5, "id": "2778f942-80d9-4d9f-899b-cce4aeca54a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0, 1, 2]), array([50, 50, 50], dtype=int64))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's have a look at how many samples of each species are in the dataset\n", "np.unique(y, return_counts=True)" ] }, { "cell_type": "markdown", "id": "4ebdc70b-faae-4502-b896-39313a919f52", "metadata": {}, "source": [ "So there are 50 samples of each species in the dataset. The species are encoded by the numbers 0 (setosa), 1 (versicolor) and 2 (virginica). Before we build a dataset class for the IRIS dataset, let's prepare its features." ] }, { "cell_type": "markdown", "id": "9859a946-244b-4158-a747-d0feb61193f0", "metadata": {}, "source": [ "## Feature Engineering\n", "\n", "For training neural networks, as well as many other machine learning algorithms, it is important to bring the features of our dataset onto a comparable scale. Two common approaches are rescaling each feature to the range between 0 and 1 (MinMaxScaler) or standardizing each feature to zero mean and unit variance (StandardScaler).\n", "\n", "The scikit-learn library already implements many scaling techniques in the `sklearn.preprocessing` package. You can have a look at them here: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing\n", "\n", "- **StandardScaler**: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html#sklearn.preprocessing.StandardScaler\n", "- **MinMaxScaler**: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html#sklearn.preprocessing.MinMaxScaler\n", "\n", "Your task is to apply standard scaling to all four features of the IRIS dataset." ] }, { "cell_type": "code", "execution_count": 6, "id": "b424f9ff-1650-4694-9c63-0e548886587a", "metadata": {}, "outputs": [], "source": [ "# TODO Let's create an instance of StandardScaler\n", "scaler = _\n", "# TODO Apply standard scaling to our features\n", "X_scaled = _" ] }, { "cell_type": "code", "execution_count": 7, "id": "f74c3dec-e568-478c-8acd-8b6b857a93f7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " 0 1 2 3\n", "0 -0.900681 1.019004 -1.340227 -1.315444\n", "1 -1.143017 -0.131979 -1.340227 -1.315444\n", "2 -1.385353 0.328414 -1.397064 -1.315444\n", "3 -1.506521 0.098217 -1.283389 -1.315444\n", "4 -1.021849 1.249201 -1.340227 -1.315444" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Now let's again have a look at our dataset after it has been scaled\n", "pd.DataFrame(X_scaled).head()" ] }, { "cell_type": "code", "execution_count": 8, "id": "4273153e-0f3e-4b3e-8b1c-e1405ce7a0e8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'After scaling')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGzCAYAAACPa3XZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAuSklEQVR4nO3de1iUdd7H8Q8oDJgcQkEkkZMV5SlX05RKKVckdZdSMzuBlrmGlamV7LOl5pNsbZllRuq2YLu4lqWmHTTPZmmmZW1tGhhqecBTgGKiMvfzR5fzMIEIOPxG4P26rvu6nN/9m/v+DjpfP9yHGQ/LsiwBAAAY4unuAgAAQMNC+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfioR/72t78pOjpajRo10jXXXOPuctyiV69e6tWrl+Pxrl275OHhoaysLLfVBFxs/vnPfyo2NlZeXl4KDAx0dzlGVNQLJk2aJA8PD/cV1YARPgzLysqSh4eH0xISEqL4+Hh9+OGHNd7uRx99pMcff1xxcXHKzMzU1KlTXVg1gLri1VdflYeHh7p161bh+u3btyslJUUxMTGaM2eOZs+erRMnTmjSpElau3at2WLRYDV2dwEN1dNPP62oqChZlqX8/HxlZWXplltu0dKlS9W/f/9qb2/16tXy9PTU66+/Lm9v71qouG6KiIjQL7/8Ii8vL3eXAhiRnZ2tyMhIbd68Wbm5uWrTpo3T+rVr18put+ull15yrDt8+LAmT54sSU5HDuu7v/zlL5owYYK7y2iQOPLhJomJibr77rt1zz33aPz48fr444/l5eWlf//73zXa3sGDB+Xr6+uy4GFZln755ReXbMudPDw85OPjo0aNGrm7FKDW5eXl6dNPP9W0adMUHBys7OzscnMOHjwoSUZOtxQXF9f6Pi5E48aN5ePj4+4yGiTCx0UiMDBQvr6+atzY+WCU3W7X9OnT1bZtW/n4+KhFixYaOXKkfv75Z8ccDw8PZWZmqri42HEq5+x5zTNnzmjKlCmKiYmRzWZTZGSk/vznP6ukpMRpP5GRkerfv7+WL1+uLl26yNfXV7NmzZIkFRQUaMyYMQoPD5fNZlObNm307LPPym63n/d1bdmyRQkJCWrevLl8fX0VFRWl4cOHl3uNL730ktq3by8fHx8FBwerb9++2rJli2NOZmambrrpJoWEhMhms+nqq69WRkbGefdf0XnelJQ
UNW3aVHv37lVSUpKaNm2q4OBgjR8/XqWlpU7PP3LkiO655x75+/srMDBQycnJ+uqrr7iOBBel7OxsXXrpperXr58GDRpULnxERkZq4sSJkqTg4GB5eHgoJSVFwcHBkqTJkyc7esikSZMcz9u+fbsGDRqkoKAg+fj4qEuXLlqyZInTts+eUl63bp0efPBBhYSEqFWrVpXWO2PGDLVt21ZNmjTRpZdeqi5dumjevHlOc/bu3av77rtPYWFhstlsioqK0qhRo3Tq1ClJ0tGjRzV+/Hi1b99eTZs2lb+/vxITE/XVV1+d9+dV0TUfHh4eGj16tBYvXqx27drJZrOpbdu2WrZsWbnnr127Vl26dJGPj49iYmI0a9YsriOpIk67uElhYaEOHz4sy7J08OBBzZgxQ8ePH9fdd9/tNG/kyJHKysrSsGHD9PDDDysvL0+vvPKKvvzyS33yySfy8vLSP//5T82ePVubN2/W3//+d0lSjx49JEn333+/5s6dq0GDBmncuHH67LPPlJ6eru+++06LFi1y2teOHTs0dOhQjRw5UiNGjNCVV16pEydOqGfPntq7d69Gjhyp1q1b69NPP1VaWpr279+v6dOnn/M1Hjx4UH369FFwcLAmTJigwMBA7dq1SwsXLnSad9999ykrK0uJiYm6//77debMGX388cfatGmTunTpIknKyMhQ27Zt9Yc//EGNGzfW0qVL9eCDD8putys1NbXaP//S0lIlJCSoW7duev7557Vy5Uq98MILiomJ0ahRoyT9GooGDBigzZs3a9SoUYqNjdW7776r5OTkau8PMCE7O1u33XabvL29NXToUGVkZOjzzz/XtddeK0maPn263njjDS1atEgZGRlq2rSp2rdvr+uuu06jRo3Srbfeqttuu02S1KFDB0nSt99+q7i4OF122WWaMGGCLrnkEr311ltKSkrSO++8o1tvvdWphgcffFDBwcF66qmnKj3yMWfOHD388MMaNGiQHnnkEZ08eVJff/21PvvsM915552SpH379qlr164qKCjQAw88oNjYWO3du1dvv/22Tpw4IW9vb/3www9avHixBg8erKioKOXn52vWrFnq2bOn/vvf/yosLKzaP8cNGzZo4cKFevDBB+Xn56eXX35ZAwcO1J49e9SsWTNJ0pdffqm+ffuqZcuWmjx5skpLS/X00087ghzOw4JRmZmZlqRyi81ms7Kyspzmfvzxx5YkKzs722l82bJl5caTk5OtSy65xGnetm3bLEnW/fff7zQ+fvx4S5K1evVqx1hERIQlyVq2bJnT3ClTpliXXHKJ9f333zuNT5gwwWrUqJG1Z8+ec77WRYsWWZKszz///JxzVq9ebUmyHn744XLr7Ha7488nTpwotz4hIcGKjo52GuvZs6fVs2dPx+O8vDxLkpWZmekYS05OtiRZTz/9tNNzO3XqZHXu3Nnx+J133rEkWdOnT3eMlZaWWjfddFO5bQLutmXLFkuStWLFCsuyfn3/tGrVynrkkUec5k2cONGSZB06dMgxdujQIUuSNXHixHLbvfnmm6327dtbJ0+edIzZ7XarR48e1uWXX+4YO9vbrr/+euvMmTPnrfePf/yj1bZt20rn3HvvvZanp2eFPeRsfzh58qRVWlrqtC4vL8+y2WxO7/GKesHZn0VZkixvb28rNzfXMfbVV19ZkqwZM2Y4xgYMGGA1adLE2rt3r2MsJyfHaty4cbltojxOu7jJzJkztWLFCq1YsUL/+te/FB8fr/vvv9/pqMCCBQsUEBCg3//+9zp8+LBj6dy5s5o2bao1a9ZUuo8PPvhAkjR27Fin8XHjxkmS3n//fafxqKgoJSQkOI0tWLBAN9xwgy699FKnGnr37q3S0lKtX7/+nPs/e075vffe0+nTpyuc884778jDw8NxKLissocufX19HX8+e9SoZ8+e+uGHH1RYWHjOGirzpz/9yenxDTfcoB9++MHxeNmyZfLy8tKIESMcY56enjU
60gLUtuzsbLVo0ULx8fGSfn3/DBkyRPPnzy93OrGqjh49qtWrV+v222/XsWPHHO//I0eOKCEhQTk5Odq7d6/Tc0aMGFGla6wCAwP1008/6fPPP69wvd1u1+LFizVgwADHEdCyzvYHm80mT89f/ysrLS3VkSNH1LRpU1155ZX64osvqvuSJUm9e/dWTEyM43GHDh3k7+/v6A+lpaVauXKlkpKSnI6stGnTRomJiTXaZ0PDaRc36dq1q9MbaujQoerUqZNGjx6t/v37y9vbWzk5OSosLFRISEiF2zh74di57N69W56enuWudg8NDVVgYKB2797tNB4VFVVuGzk5Ofr666/PeSixshp69uypgQMHavLkyXrxxRfVq1cvJSUl6c4775TNZpMk7dy5U2FhYQoKCqr0tXzyySeaOHGiNm7cqBMnTjitKywsVEBAQKXP/62z15aUdemllzpdS7N79261bNlSTZo0cZr3258n4G6lpaWaP3++4uPjlZeX5xjv1q2bXnjhBa1atUp9+vSp9nZzc3NlWZaefPJJPfnkkxXOOXjwoC677DLH44r6SEWeeOIJrVy5Ul27dlWbNm3Up08f3XnnnYqLi5MkHTp0SEVFRWrXrl2l2zl7zdirr76qvLw8p6B19hRJdbVu3brcWNn+cPDgQf3yyy8V9gL6Q9UQPi4Snp6eio+P10svvaScnBy1bdtWdrtdISEhFV6xLqnK5xarevFT2aMLZ9ntdv3+97/X448/XuFzrrjiikr3+/bbb2vTpk1aunSpli9fruHDh+uFF17Qpk2b1LRp0yrVtXPnTt18882KjY3VtGnTFB4eLm9vb33wwQd68cUXq3Th629x9wvqk9WrV2v//v2aP3++5s+fX259dnZ2jcLH2ffW+PHjyx0VPeu3/9lW1EcqctVVV2nHjh167733tGzZMr3zzjt69dVX9dRTTzlu+62KqVOn6sknn9Tw4cM1ZcoUBQUFydPTU2PGjKlRb5DO3R8sy6rR9lAe4eMicubMGUnS8ePHJUkxMTFauXKl4uLiqvyGLisiIkJ2u105OTm66qqrHOP5+fkqKChQRETEebcRExOj48ePq3fv3tXe/1nXXXedrrvuOj3zzDOaN2+e7rrrLs2fP1/333+/YmJitHz5ch09evScRz+WLl2qkpISLVmyxOk3kvOddrpQERERWrNmjU6cOOF09CM3N7dW9wtUV3Z2tkJCQjRz5sxy6xYuXKhFixbptddeO2cfOdcvKNHR0ZIkLy+vC+oB53LJJZdoyJAhGjJkiE6dOqXbbrtNzzzzjNLS0hQcHCx/f3998803lW7j7bffVnx8vF5//XWn8YKCAjVv3tzlNUtSSEiIfHx8KuwF9Ieq4ZqPi8Tp06f10Ucfydvb2xEUbr/9dpWWlmrKlCnl5p85c0YFBQWVbvOWW26RpHJ3pEybNk2S1K9fv/PWdfvtt2vjxo1avnx5uXUFBQWOwFSRn3/+udxvCmc/9v3srb4DBw6UZVkV/qZz9rlnfwspu63CwkJlZmaet/4LkZCQoNOnT2vOnDmOMbvdXmGDB9zll19+0cKFC9W/f38NGjSo3DJ69GgdO3as3K2xZZ0N17/tKSEhIerVq5dmzZql/fv3l3veoUOHalz3kSNHnB57e3vr6quvlmVZOn36tDw9PZWUlKSlS5c63XZ/Vtn+8Ns+s2DBgnLXorhSo0aN1Lt3by1evFj79u1zjOfm5l7QJ1U3JBz5cJMPP/xQ27dvl/Tr+cN58+YpJydHEyZMkL+/v6Rfr5kYOXKk0tPTtW3bNvXp00deXl7KycnRggUL9NJLL2nQoEHn3EfHjh2VnJys2bNnq6CgQD179tTmzZs1d+5cJSUlOS5Mq8xjjz2mJUuWqH///kpJSVHnzp1VXFys//znP3r77be1a9euc/52MXfuXL366qu69dZbFRMTo2PHjmnOnDny9/d3BKP4+Hjdc889evnll5WTk6O+ffvKbrfr448
/Vnx8vEaPHq0+ffrI29tbAwYM0MiRI3X8+HHNmTNHISEhFTZEV0lKSlLXrl01btw45ebmKjY2VkuWLNHRo0clVf10FlCblixZomPHjukPf/hDheuvu+46xweODRkypMI5vr6+uvrqq/Xmm2/qiiuuUFBQkNq1a6d27dpp5syZuv7669W+fXuNGDFC0dHRys/P18aNG/XTTz9V6fM0KtKnTx+FhoYqLi5OLVq00HfffadXXnlF/fr1k5+fn6RfT6l89NFH6tmzpx544AFdddVV2r9/vxYsWKANGzYoMDBQ/fv319NPP61hw4apR48e+s9//qPs7GzHUZvaMmnSJH300UeKi4vTqFGjVFpaqldeeUXt2rXTtm3banXf9YLb7rNpoCq61dbHx8e65pprrIyMDKfbS8+aPXu21blzZ8vX19fy8/Oz2rdvbz3++OPWvn37HHMqutXWsizr9OnT1uTJk62oqCjLy8vLCg8Pt9LS0pxum7OsX2+17devX4U1Hzt2zEpLS7PatGljeXt7W82bN7d69OhhPf/889apU6fO+Vq/+OILa+jQoVbr1q0tm81mhYSEWP3797e2bNniNO/MmTPW3/72Nys2Ntby9va2goODrcTERGvr1q2OOUuWLLE6dOhg+fj4WJGRkdazzz5r/eMf/7AkWXl5eY55Vb3VtqKfVUW33R06dMi68847LT8/PysgIMBKSUmxPvnkE0uSNX/+/HO+dsCUAQMGWD4+PlZxcfE556SkpFheXl7W4cOHK7zV1rIs69NPP7U6d+5seXt7l7vtdufOnda9995rhYaGWl5eXtZll11m9e/f33r77bcdc872tspurS9r1qxZ1o033mg1a9bMstlsVkxMjPXYY49ZhYWFTvN2795t3XvvvVZwcLBls9ms6OhoKzU11SopKbEs69dbbceNG2e1bNnS8vX1teLi4qyNGzdWqRec61bb1NTUcvVGRERYycnJTmOrVq2yOnXqZHl7e1sxMTHW3//+d2vcuHGWj49PlX4GDZmHZXEFDVAdixcv1q233qoNGzY4rswHAOnXI6bffvutcnJy3F3KRY1rPoBK/Pb7bUpLSzVjxgz5+/vrd7/7nZuqAnAx+G1/yMnJ0QcffNCgvpyvprjmA6jEQw89pF9++UXdu3dXSUmJFi5cqE8//VRTp06t0R1IAOqP6OhopaSkKDo6Wrt371ZGRoa8vb3P+dEE+H+cdgEqMW/ePL3wwgvKzc3VyZMn1aZNG40aNUqjR492d2kA3GzYsGFas2aNDhw4IJvNpu7du2vq1KkcFa0CwgcAADCKaz4AAIBRhA8AAGDURXfBqd1u1759++Tn58eHOAFuYlmWjh07prCwMMc3hl7s6B2Ae1Wnb1x04WPfvn0KDw93dxkAJP34449q1aqVu8uoEnoHcHGoSt+46MLH2Y/V/fHHHx0fMw7ArKKiIoWHhzvej3UBvQNwr+r0jYsufJw9XOrv708DAdysLp2+oHcAF4eq9I26cTIXAADUG4QPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgVGN3F4Dqi5zwvsu3ueuv/Vy+TQAXF1f3DvoGaoojHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKP4YjnUCr78DgBwLhz5AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEZVK3ykp6fr2muvlZ+fn0JCQpSUlKQdO3Y4zenVq5c8PDyclj/96U8uLRpA3ULvAFBWtcLHunXrlJqaqk2bNmnFihU6ffq0+vTpo+LiYqd5I0aM0P79+x3Lc88959KiAdQt9A4AZVXru12WLVvm9DgrK0shISHaunWrbrzxRsd
4kyZNFBoaWqVtlpSUqKSkxPG4qKioOiUBqAPoHQDKuqBrPgoLCyVJQUFBTuPZ2dlq3ry52rVrp7S0NJ04ceKc20hPT1dAQIBjCQ8Pv5CSANQB9A6gYavxt9ra7XaNGTNGcXFxateunWP8zjvvVEREhMLCwvT111/riSee0I4dO7Rw4cIKt5OWlqaxY8c6HhcVFdFEgHqM3gGgxuEjNTVV33zzjTZs2OA0/sADDzj+3L59e7Vs2VI333yzdu7cqZiYmHLbsdlsstlsNS0DQB1D7wBQo9Muo0eP1nvvvac1a9aoVatWlc7t1q2bJCk3N7cmuwJQj9A7AEjVPPJhWZYeeughLVq0SGvXrlVUVNR5n7Nt2zZJUsuWLWtUIIC6j94BoKxqhY/U1FTNmzdP7777rvz8/HTgwAFJUkBAgHx9fbVz507NmzdPt9xyi5o1a6avv/5ajz76qG688UZ16NChVl4AgIsfvQNAWdUKHxkZGZJ+/TCgsjIzM5WSkiJvb2+tXLlS06dPV3FxscLDwzVw4ED95S9/cVnBAOoeegeAsqp92qUy4eHhWrdu3QUVBKD+oXcAKIvvdgEAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUY3dXcDFJnLC+y7f5q6/9nP5NgFcXFzdO+gbqM848gEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCKL5YDANQIX8SJmuLIBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIyqVvhIT0/XtddeKz8/P4WEhCgpKUk7duxwmnPy5EmlpqaqWbNmatq0qQYOHKj8/HyXFg2gbqF3ACirWuFj3bp1Sk1N1aZNm7RixQqdPn1affr0UXFxsWPOo48+qqVLl2rBggVat26d9u3bp9tuu83lhQOoO+gdAMpqXJ3Jy5Ytc3qclZWlkJAQbd26VTfeeKMKCwv1+uuva968ebrpppskSZmZmbrqqqu0adMmXXfdda6rHECdQe8AUNYFXfNRWFgoSQoKCpIkbd26VadPn1bv3r0dc2JjY9W6dWtt3Lixwm2UlJSoqKjIaQFQv9E7gIatWkc+yrLb7RozZozi4uLUrl07SdKBAwfk7e2twMBAp7ktWrTQgQMHKtxOenq6Jk+eXNMyANQx9A5UJnLC+y7d3q6/9nPp9uAaNT7ykZqaqm+++Ubz58+/oALS0tJUWFjoWH788ccL2h6Aixu9A0CNjnyMHj1a7733ntavX69WrVo5xkNDQ3Xq1CkVFBQ4/QaTn5+v0NDQCrdls9lks9lqUgaAOobeAUCq5pEPy7I0evRoLVq0SKtXr1ZUVJTT+s6dO8vLy0urVq1yjO3YsUN79uxR9+7dXVMxgDqH3gGgrGod+UhNTdW8efP07rvvys/Pz3EuNiAgQL6+vgoICNB9992nsWPHKigoSP7+/nrooYfUvXt3rlYHGjB6B4CyqhU+MjIyJEm9evVyGs/MzFRKSook6cUXX5Snp6cGDhyokpISJSQk6NVXX3VJsQDqJnoHgLKqFT4syzrvHB8fH82cOVMzZ86scVEA6hd6B4Cy+G4XAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEY1dncBQFV
FTnjfpdvb9dd+Lt0egIuPq/uGRO9wBY58AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAoxq7uwBcHCInvO/uEgAADQRHPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGVTt8rF+/XgMGDFBYWJg8PDy0ePFip/UpKSny8PBwWvr27euqegHUQfQNAGVVO3wUFxerY8eOmjlz5jnn9O3bV/v373cs//73vy+oSAB1G30DQFnV/nj1xMREJSYmVjrHZrMpNDS0xkUBqF/oGwDKqpVrPtauXauQkBBdeeWVGjVqlI4cOXLOuSUlJSoqKnJaADQ81ekbEr0DqMtcHj769u2rN954Q6tWrdKzzz6rdevWKTExUaWlpRXOT09PV0BAgGMJDw93dUkALnLV7RsSvQOoy1z+rbZ33HGH48/t27dXhw4dFBMTo7Vr1+rmm28uNz8tLU1jx451PC4qKqKJAA1MdfuGRO8A6rJav9U2OjpazZs3V25uboXrbTab/P39nRYADdv5+oZE7wDqsloPHz/99JOOHDmili1b1vauANQT9A2gfqv2aZfjx487/TaSl5enbdu2KSgoSEFBQZo8ebIGDhyo0NBQ7dy5U48//rjatGmjhIQElxYOoO6gbwAoq9rhY8uWLYqPj3c8PnvONTk5WRkZGfr66681d+5cFRQUKCwsTH369NGUKVNks9lcVzWAOoW+AaCsaoePXr16ybKsc65fvnz5BRUEoP6hbwAoi+92AQAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRLv9uF5QXOeF9d5cAoI6hb6A+48gHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKMIHwAAwCjCBwAAMIrwAQAAjCJ8AAAAo6odPtavX68BAwYoLCxMHh4eWrx4sdN6y7L01FNPqWXLlvL19VXv3r2Vk5PjqnoB1EH0DQBlVTt8FBcXq2PHjpo5c2aF65977jm9/PLLeu211/TZZ5/pkksuUUJCgk6ePHnBxQKom+gbAMpqXN0nJCYmKjExscJ1lmVp+vTp+stf/qI//vGPkqQ33nhDLVq00OLFi3XHHXdcWLUA6iT6BoCyXHrNR15eng4cOKDevXs7xgICAtStWzdt3LixwueUlJSoqKjIaQHQcNSkb0j0DqAuc2n4OHDggCSpRYsWTuMtWrRwrPut9PR0BQQEOJbw8HBXlgTgIleTviHRO4C6zO13u6SlpamwsNCx/Pjjj+4uCUAdQO8A6i6Xho/Q0FBJUn5+vtN4fn6+Y91v2Ww2+fv7Oy0AGo6a9A2J3gHUZS4NH1FRUQoNDdWqVascY0VFRfrss8/UvXt3V+4KQD1B3wAanmrf7XL8+HHl5uY6Hufl5Wnbtm0KCgpS69atNWbMGP3v//6vLr/8ckVFRenJJ59UWFiYkpKSXFk3gDqEvgGgrGqHjy1btig+Pt7xeOzYsZKk5ORkZWVl6fHHH1dxcbEeeOABFRQU6Prrr9eyZcvk4+PjuqoB1Cn0DQBleViWZbm
7iLKKiooUEBCgwsJCt5zDjZzwvvF9wj12/bWfu0u4aLn7fVgT7q6Z3tFw0DsqVp33oNvvdgEAAA0L4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGFXtz/m42HB7G4Dqom8A7sWRDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGNXY3QUA7hI54X2Xb3PXX/u5dHt1oUagoXH1+7I23pMXe40c+QAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAY5fLwMWnSJHl4eDgtsbGxrt4NgHqEvgE0LLXyCadt27bVypUr/38njfkgVQCVo28ADUetvLsbN26s0NDQ2tg0gHqKvgE0HLVyzUdOTo7CwsIUHR2tu+66S3v27Dnn3JKSEhUVFTktABqe6vQNid4B1GUuP/LRrVs3ZWVl6corr9T+/fs1efJk3XDDDfrmm2/k5+dXbn56eromT57s6jIAt6iNL4JrCKrbNyR6B+qPhtg3PCzLsmpzBwUFBYqIiNC0adN03333lVtfUlKikpISx+OioiKFh4ersLBQ/v7+591+Q/xLAy5EVb6dsqioSAEBAVV+H7ra+fqGdGG9g74BVI+r+0atX9EVGBioK664Qrm5uRWut9lsstlstV0GgDrkfH1DoncAdVmtf87H8ePHtXPnTrVs2bK2dwWgnqBvAPWby8PH+PHjtW7dOu3atUuffvqpbr31VjVq1EhDhw519a4A1BP0DaBhcflpl59++klDhw7VkSNHFBwcrOuvv16bNm1ScHCwq3cFoJ6gbwANi8vDx/z58129SQD1HH0DaFj4bhcAAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGEX4AAAARhE+AACAUYQPAABgFOEDAAAYRfgAAABGET4AAIBRhA8AAGAU4QMAABhF+AAAAEYRPgAAgFGEDwAAYBThAwAAGFVr4WPmzJmKjIyUj4+PunXrps2bN9fWrgDUE/QNoGGolfDx5ptvauzYsZo4caK++OILdezYUQkJCTp48GBt7A5APUDfABqOWgkf06ZN04gRIzRs2DBdffXVeu2119SkSRP94x//qI3dAagH6BtAw9HY1Rs8deqUtm7dqrS0NMeYp6enevfurY0bN5abX1JSopKSEsfjwsJCSVJRUVGV9mcvOXGBFQMNS1XeW2fnWJZV2+VIqn7fkC6sd9A3gOpxdd9wefg4fPiwSktL1aJFC6fxFi1aaPv27eXmp6ena/LkyeXGw8PDXV0aAEkB06s+99ixYwoICKi1Ws6qbt+Q6B2ASa7uGy4PH9WVlpamsWPHOh7b7XYdPXpUzZo1k4eHhxsruzBFRUUKDw/Xjz/+KH9/f3eXUyt4jfVDRa/RsiwdO3ZMYWFhbq7u3M7XO+ry3x21uwe1X5jq9A2Xh4/mzZurUaNGys/PdxrPz89XaGhoufk2m00
2m81pLDAw0NVluY2/v3+d+0dcXbzG+uG3r9HEEY+zqts3pKr3jrr8d0ft7kHtNVfVvuHyC069vb3VuXNnrVq1yjFmt9u1atUqde/e3dW7A1AP0DeAhqVWTruMHTtWycnJ6tKli7p27arp06eruLhYw4YNq43dAagH6BtAw1Er4WPIkCE6dOiQnnrqKR04cEDXXHONli1bVu5isvrMZrNp4sSJ5Q4L1ye8xvrhYnmNru4bF8vrqglqdw9qN8fDMnUvHQAAgPhuFwAAYBjhAwAAGEX4AAAARhE+AACAUYQPAABgFOHDxSZNmiQPDw+nJTY21t1ludzevXt19913q1mzZvL19VX79u21ZcsWd5flMpGRkeX+Hj08PJSamuru0lymtLRUTz75pKKiouTr66uYmBhNmTLF2JfJmbJr1y7dd999Tq9z4sSJOnXqlLtLq5JnnnlGPXr0UJMmTerEpz/PnDlTkZGR8vHxUbdu3bR582Z3l3Re69ev14ABAxQWFiYPDw8tXrzY3SVVWXp6uq699lr5+fkpJCRESUlJ2rFjh7vLOi/CRy1o27at9u/f71g2bNjg7pJc6ueff1ZcXJy8vLz04Ycf6r///a9eeOEFXXrppe4uzWU+//xzp7/DFStWSJIGDx7s5spc59lnn1VGRoZeeeUVfffdd3r22Wf13HPPacaMGe4uzaW2b98uu92uWbNm6dtvv9WLL76o1157TX/+85/dXVqVnDp1SoMHD9aoUaPcXcp5vfnmmxo7dqwmTpyoL774Qh07dlRCQoIOHjzo7tIqVVxcrI4dO2rmzJnuLqXa1q1bp9TUVG3atEkrVqzQ6dOn1adPHxUXF7u7tMpZcKmJEydaHTt2dHcZteqJJ56wrr/+eneXYdQjjzxixcTEWHa73d2luEy/fv2s4cOHO43ddttt1l133eWmisx57rnnrKioKHeXUS2ZmZlWQECAu8uoVNeuXa3U1FTH49LSUissLMxKT093Y1XVI8latGiRu8uosYMHD1qSrHXr1rm7lEpx5KMW5OTkKCwsTNHR0brrrru0Z88ed5fkUkuWLFGXLl00ePBghYSEqFOnTpozZ467y6o1p06d0r/+9S8NHz68Tn/T8m/16NFDq1at0vfffy9J+uqrr7RhwwYlJia6ubLaV1hYqKCgIHeXUa+cOnVKW7duVe/evR1jnp6e6t27tzZu3OjGyhqWwsJCSbro/30TPlysW7duysrK0rJly5SRkaG8vDzdcMMNOnbsmLtLc5kffvhBGRkZuvzyy7V8+XKNGjVKDz/8sObOnevu0mrF4sWLVVBQoJSUFHeX4lITJkzQHXfcodjYWHl5ealTp04aM2aM7rrrLneXVqtyc3M1Y8YMjRw50t2l1CuHDx9WaWlpuY/Db9GihQ4cOOCmqhoWu92uMWPGKC4uTu3atXN3OZUifLhYYmKiBg8erA4dOighIUEffPCBCgoK9NZbb7m7NJex2+363e9+p6lTp6pTp0564IEHNGLECL322mvuLq1WvP7660pMTFRYWJi7S3Gpt956S9nZ2Zo3b56++OILzZ07V88//3ydCZETJkyo8KLgssv27dudnrN371717dtXgwcP1ogRI9xUec1qB84nNTVV33zzjebPn+/uUs6rVr5YDv8vMDBQV1xxhXJzc91disu0bNlSV199tdPYVVddpXfeecdNFdWe3bt3a+XKlVq4cKG7S3G5xx57zHH0Q5Lat2+v3bt3Kz09XcnJyW6u7vzGjRt33qNR0dHRjj/v27dP8fHx6tGjh2bPnl3L1VWuurXXBc2bN1ejRo2Un5/vNJ6fn6/Q0FA3VdVwjB49Wu+9957Wr1+vVq1aubuc8yJ81LLjx49r586duueee9xdisvExcWVu5Xr+++/V0REhJsqqj2ZmZkKCQlRv3793F2Ky504cUKens4HPxs1aiS73e6miqonODhYwcHBVZq7d+9excfHq3PnzsrMzCz3uk2rTu11hbe
3tzp37qxVq1YpKSlJ0q9HSVetWqXRo0e7t7h6zLIsPfTQQ1q0aJHWrl2rqKgod5dUJYQPFxs/frwGDBigiIgI7du3TxMnTlSjRo00dOhQd5fmMo8++qh69OihqVOn6vbbb9fmzZs1e/Zst/826Wp2u12ZmZlKTk5W48b1760yYMAAPfPMM2rdurXatm2rL7/8UtOmTdPw4cPdXZpL7d27V7169VJERISef/55HTp0yLGuLvxGvmfPHh09elR79uxRaWmptm3bJklq06aNmjZt6t7ifmPs2LFKTk5Wly5d1LVrV02fPl3FxcUaNmyYu0ur1PHjx52OTufl5Wnbtm0KCgpS69at3VjZ+aWmpmrevHl699135efn57i+JiAgQL6+vm6urhLuvt2mvhkyZIjVsmVLy9vb27rsssusIUOGWLm5ue4uy+WWLl1qtWvXzrLZbFZsbKw1e/Zsd5fkcsuXL7ckWTt27HB3KbWiqKjIeuSRR6zWrVtbPj4+VnR0tPU///M/VklJibtLc6nMzExLUoVLXZCcnFxh7WvWrHF3aRWaMWOG1bp1a8vb29vq2rWrtWnTJneXdF5r1qyp8GecnJzs7tLO61z/tjMzM91dWqU8LKuefZwhAAC4qHG3CwAAMIrwAQAAjCJ8AAAAowgfAADAKMIHAAAwivABAACMInwAAACjCB8AAMAowgcAADCK8AEAAIwifAAAAKP+Dy2n8yOK19I8AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(1, 2)\n", "ax[0].hist(X.iloc[:, 0])\n", "ax[0].set_title(\"Before scaling\")\n", "ax[1].hist(X_scaled[:, 0])\n", "ax[1].set_title(\"After scaling\")" ] }, { "cell_type": "markdown", "id": "af926567-41a6-44ad-b319-d93a97eba10d", "metadata": {}, "source": [ "The features have been standardizes, that means they have been shifted by the value of their mean to the left and then divided by their own standard deviation. We can see this effect by comparing the mean and the variance of a feature before and after applying the standard scaler." ] }, { "cell_type": "code", "execution_count": 9, "id": "9a6b0ca6-cc37-46f1-b72a-47a31e6d36c6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean before scaling: 5.843333333333334 , Variance before scaling: 0.6856935123042507\n", "Mean after scaling: -4.736951571734001e-16 , Variance after scaling: 1.0\n" ] } ], "source": [ "print(\"Mean before scaling:\", X.iloc[:, 0].mean(), \", Variance before scaling: \", X.iloc[:, 0].var())\n", "print(\"Mean after scaling:\", X_scaled[:, 0].mean(), \", Variance after scaling: \", X_scaled[:, 0].var())" ] }, { "cell_type": "markdown", "id": "2e78307a-7de2-4dca-a4bb-ff71ad05efbb", "metadata": {}, "source": [ "## Building a dataset class and a data loader\n", "\n", "Now that we know how to load and prpepare the IRIS dataset for training it is time to build a dataset class." 
 ] }, { "cell_type": "code", "execution_count": 10, "id": "e394289d-bbf1-4a51-83b0-d2272243bb72", "metadata": {}, "outputs": [], "source": [ "class IRISDataset(Dataset):\n", "    \"\"\"\n", "    This class loads the IRIS data\n", "    \"\"\"\n", "    def __init__(self, is_train_dataset=True):\n", "        \"\"\"\n", "        Initialize the dataset class\n", "        :param is_train_dataset: True if this class should use the training data, False otherwise\n", "        \"\"\"\n", "        # Load the IRIS dataset\n", "        iris = load_iris(as_frame=False)\n", "\n", "        # TODO Extract features and labels from the dataset\n", "        X = _\n", "        y = _\n", "\n", "        # TODO Apply standard scaling to the features. This makes model training more stable.\n", "        # Note: fitting the scaler before the split lets test-set statistics leak into training;\n", "        # a stricter setup fits the scaler on the training split only.\n", "        scaler = _\n", "        X_scaled = _\n", "\n", "        # TODO Split the data set into training and test data\n", "        X_train, X_test, y_train, y_test = _(X_scaled, y, test_size=0.2, random_state=2)\n", "\n", "        # Check whether the training or test data should be loaded\n", "        if is_train_dataset:\n", "            self.data = X_train\n", "            self.labels = y_train\n", "        else:\n", "            self.data = X_test\n", "            self.labels = y_test\n", "\n", "    def __len__(self):\n", "        \"\"\"\n", "        This function returns the total number of items in the dataset.\n", "        We are using a numpy array in this dataset, which has an attribute named shape.\n", "        The first dimension of shape is equal to the number of items in the dataset.\n", "        :return: The number of samples in the dataset\n", "        \"\"\"\n", "        # TODO return the size of the dataset\n", "        return _\n", "\n", "    def __getitem__(self, idx):\n", "        \"\"\"\n", "        This function returns a single tuple from the dataset.\n", "        :param idx: The index of the tuple that should be returned.\n", "        :return: Tuple of a feature vector and its label\n", "        \"\"\"\n", "        # TODO return a tuple of data points and labels\n", "        return _, _" ] }, { "cell_type": "code", "execution_count": 11, "id": "13c29b96-c04e-464f-b396-0c160117a464", "metadata": {}, "outputs": [], "source": [ "# TODO Create two datasets: one for training and another one for testing\n", "dataset_train = _\n", "dataset_test = _" ] }, { "cell_type": "code", "execution_count": 12, "id": "18bb5256-2cd7-48e3-b75a-d4aa3b99602d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train dataset size: 120 , Test dataset size 30\n" ] } ], "source": [ "# Let's check how many items are in the training and the test dataset\n", "print(\"Train dataset size:\", len(dataset_train), \", Test dataset size\", len(dataset_test))" ] }, { "cell_type": "code", "execution_count": 13, "id": "7dec1b3a-e725-4646-9fd4-d84f0aabc41c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 0.4321654 , -0.59237301, 0.59224599, 0.79067065]), 2)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Let's look at the first item of the training dataset.\n", "# You will get a tuple consisting of the feature vector and the label of the item.\n", "dataset_train[0]" ] }, { "cell_type": "markdown", "id": "6165da74-d663-49be-9f7b-1cf0d14b56ac", "metadata": {}, "source": [ "In PyTorch, a data loader wraps a dataset and draws batches of samples from it. The **batch size** defines how many samples are combined into one batch. With `shuffle=True`, the samples are drawn in random order. As we have a training and a test dataset, we need to create one data loader for each of them."
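, "\n",
 "\n",
 "For illustration, here is a minimal sketch that uses random toy tensors instead of the IRIS data. The data sizes and the batch size of 4 are arbitrary assumptions. It wraps the tensors in a `TensorDataset` and lets a `DataLoader` draw shuffled mini-batches from it:\n",
 "\n",
 "```python\n",
 "import torch\n",
 "from torch.utils.data import DataLoader, TensorDataset\n",
 "\n",
 "# Hypothetical toy data: 10 samples with 4 features each and labels from 3 classes\n",
 "features = torch.randn(10, 4)\n",
 "labels = torch.randint(0, 3, (10,))\n",
 "\n",
 "# shuffle=True draws the samples in a new random order every time the loader is iterated\n",
 "loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)\n",
 "\n",
 "for batch_features, batch_labels in loader:\n",
 "    # The 10 samples are split into batches of 4, 4 and 2 (the last batch is smaller)\n",
 "    print(batch_features.shape, batch_labels.shape)\n",
 "```\n",
 "\n",
 "Passing `drop_last=True` to the `DataLoader` would discard the incomplete last batch instead."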
] }, { "cell_type": "code", "execution_count": 14, "id": "fb0a820f-63ce-411d-850c-fc5a56fd84bf", "metadata": {}, "outputs": [], "source": [ "# TODO Create data loaders for the IRIS dataset\n", "dataloader_train = _\n", "dataloader_test = _" ] }, { "cell_type": "code", "execution_count": 15, "id": "1a638a3e-2b8a-4b59-9350-fe226878b59c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[ 0.6745, -0.5924, 1.0469, 1.1856],\n", " [ 0.3110, -0.1320, 0.6491, 0.7907],\n", " [-1.2642, 0.7888, -1.0560, -1.3154],\n", " [-1.0218, 0.3284, -1.4539, -1.3154],\n", " [-0.4160, -1.2830, 0.1375, 0.1325],\n", " [ 1.0380, -1.2830, 1.1606, 0.7907],\n", " [-0.9007, 1.7096, -1.0560, -1.0522],\n", " [ 1.6438, 0.3284, 1.2743, 0.7907],\n", " [-1.6277, -1.7434, -1.3971, -1.1838],\n", " [ 1.0380, 0.5586, 1.1038, 1.7121],\n", " [-0.5372, 1.9398, -1.1697, -1.0522],\n", " [ 0.6745, 0.3284, 0.8764, 1.4488],\n", " [-1.2642, -0.1320, -1.3402, -1.4471],\n", " [-1.5065, 0.0982, -1.2834, -1.3154],\n", " [ 0.7957, -0.5924, 0.4786, 0.3958],\n", " [-0.2948, -0.8226, 0.2512, 0.1325],\n", " [-1.3854, 0.3284, -1.2266, -1.3154],\n", " [-0.1737, -0.3622, 0.2512, 0.1325],\n", " [ 0.6745, -0.3622, 0.3081, 0.1325],\n", " [ 1.0380, -0.1320, 0.7059, 0.6590],\n", " [-1.7489, 0.3284, -1.3971, -1.3154],\n", " [ 1.6438, -0.1320, 1.1606, 0.5274],\n", " [-1.2642, -0.1320, -1.3402, -1.1838],\n", " [-0.1737, -1.2830, 0.7059, 1.0539],\n", " [-1.0218, -0.1320, -1.2266, -1.3154],\n", " [ 0.1898, 0.7888, 0.4217, 0.5274],\n", " [-0.6583, 1.4794, -1.2834, -1.3154],\n", " [ 0.5533, -1.7434, 0.3649, 0.1325],\n", " [-1.3854, 0.3284, -1.3971, -1.3154],\n", " [-1.5065, 0.7888, -1.3402, -1.1838]], dtype=torch.float64),\n", " tensor([2, 2, 0, 0, 1, 2, 0, 2, 0, 2, 0, 2, 0, 0, 1, 1, 0, 1, 1, 1, 0, 2, 0, 2,\n", " 0, 1, 0, 1, 0, 0], dtype=torch.int32)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We can get a random sample from a data loader by wrapping it into 
next(iter()).\n", "# You will get a tuple that contains a batch of feature vectors and their corresponding labels.\n", "# If you run this code multiple times you always will get another random sample.\n", "next(iter(dataloader_test))" ] }, { "cell_type": "markdown", "id": "4d1513ed-7c7d-4638-91a4-a5843732641c", "metadata": {}, "source": [ "## Building a neural network for classification\n", "\n", "Great! We now know how to load the IRIS dataset, preprocess it, split it into training and test data and build Dataset and DataLoader classes. The next step is to define a neural network that is able to consume our data and predict the species based on the four features of the data." ] }, { "cell_type": "code", "execution_count": 16, "id": "3f185b58-fce3-4837-8de8-cdef4ff8f728", "metadata": {}, "outputs": [], "source": [ "class IRISClassificationNetwork(nn.Module):\n", " def __init__(self):\n", " \"\"\"\n", " Here we define the layers of our neural network.\n", " \"\"\"\n", " super(IRISClassificationNetwork, self).__init__()\n", " # Our data has four features, so the first linear layer has to have four input dimensions.\n", " # TODO add the linear layer\n", " self.layer1 = _\n", " # The first hidden layer need to have the same input dimension as layer1 has outputs. 
\n", "        # TODO add the linear layer\n", "        self.layer2 = _\n", "        # We have three different classes in our data, so the last linear layer must have 3 output dimensions.\n", "        # TODO add the linear layer\n", "        self.layer3 = _\n", "        # TODO Add a ReLU layer\n", "        self.activation = _\n", "        # The outputs of the last linear layer need to be mapped to a probability distribution.\n", "        # This can be done by running the vectors through a softmax function.\n", "        # Note: nn.CrossEntropyLoss applies log-softmax internally and expects raw scores (logits),\n", "        # so when training with it, the softmax is often applied only at prediction time.\n", "        # TODO add the softmax function\n", "        self.classification = _\n", "\n", "    def forward(self, x):\n", "        \"\"\"\n", "        The forward function takes a batch of data vectors and runs them through the layers of our neural network.\n", "        :return: The forward function returns a tensor of shape (batch_size, 3) which contains the\n", "            probabilities of all three classes for each data vector in the batch.\n", "        \"\"\"\n", "        # TODO Run the input through the first linear layer and then through the activation function.\n", "        x = _\n", "        # TODO Run the outputs of layer 1 through layer 2.\n", "        x = _\n", "        # TODO Run the outputs of layer 2 through the third linear layer and then through the softmax classification function.\n", "        x = _\n", "        return x" ] }, { "cell_type": "code", "execution_count": 17, "id": "10f4d593-6e16-406c-8408-bb56ec7747f4", "metadata": {}, "outputs": [], "source": [ "# TODO Now that we have defined the network class, we need to create an instance of it\n", "net = _" ] }, { "cell_type": "code", "execution_count": 18, "id": "566d562e-811c-470f-9740-8c40874daf3e", "metadata": {}, "outputs": [], "source": [ "def get_accuracy(net, dataloader):\n", "    \"\"\"\n", "    This function computes the accuracy of the neural network by drawing one batch from a\n", "    data loader, running it through the network and computing the fraction of correct predictions.\n", "    :param net: The neural network\n", "    :param dataloader: A DataLoader instance\n", "    :return: The accuracy as a float between 0 and 1\n", "    \"\"\"\n", "    # torch.no_grad() disables gradient tracking while we run data through the network.\n", "    # Evaluating on test data must not influence training, so we don't need gradients here.\n", "    with torch.no_grad():\n", "        X_test, y_test = next(iter(dataloader))\n", "        y_pred = net(X_test.to(torch.float32))\n", "        correct = (torch.argmax(y_pred, dim=1) == y_test).type(torch.float32)\n", "    return correct.mean().item()" ] }, { "cell_type": "code", "execution_count": 19, "id": "7c238b0a-cae6-4c7a-9561-ab93fb7c8137", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy before training: 0.36666667461395264\n" ] } ], "source": [ "# Let's check the accuracy before training the network\n", "print(\"Accuracy before training:\", get_accuracy(net, dataloader_test))" ] }, { "cell_type": "markdown", "id": "a14de1e1-c98e-4b96-81f2-e1ce6b614d51", "metadata": {}, "source": [ "The accuracy of the untrained network is poor (about 37%, close to the 33% expected from random guessing among three classes), because the parameters of the network are initialized randomly. Let's train the network to find parameters that let us make better predictions for the species of our flowers." ] }, { "cell_type": "markdown", "id": "58df85a8-d098-4a7f-a3a9-807e6fba0ee3", "metadata": {}, "source": [ "## Training the network" ] }, { "cell_type": "code", "execution_count": 20, "id": "d7c16a77-d20f-4375-8071-d6dc1ca44658", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\tilof\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch\\autograd\\__init__.py:200: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. 
(Triggered internally at ..\\c10\\cuda\\CUDAFunctions.cpp:109.)\n", " Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch [10/250], Loss: 1.0342, Accuracy on test data: 0.4000000059604645\n", "Epoch [20/250], Loss: 1.0147, Accuracy on test data: 0.800000011920929\n", "Epoch [30/250], Loss: 0.9456, Accuracy on test data: 0.7666666507720947\n", "Epoch [40/250], Loss: 0.8926, Accuracy on test data: 0.800000011920929\n", "Epoch [50/250], Loss: 0.8386, Accuracy on test data: 0.8333333134651184\n", "Epoch [60/250], Loss: 0.8800, Accuracy on test data: 0.8666666746139526\n", "Epoch [70/250], Loss: 0.7504, Accuracy on test data: 0.8666666746139526\n", "Epoch [80/250], Loss: 0.7091, Accuracy on test data: 0.8333333134651184\n", "Epoch [90/250], Loss: 0.6661, Accuracy on test data: 0.8666666746139526\n", "Epoch [100/250], Loss: 0.7385, Accuracy on test data: 0.8666666746139526\n", "Epoch [110/250], Loss: 0.6871, Accuracy on test data: 0.8999999761581421\n", "Epoch [120/250], Loss: 0.7192, Accuracy on test data: 0.8999999761581421\n", "Epoch [130/250], Loss: 0.6934, Accuracy on test data: 0.8999999761581421\n", "Epoch [140/250], Loss: 0.6540, Accuracy on test data: 0.8999999761581421\n", "Epoch [150/250], Loss: 0.6473, Accuracy on test data: 0.8999999761581421\n", "Epoch [160/250], Loss: 0.5812, Accuracy on test data: 0.9333333373069763\n", "Epoch [170/250], Loss: 0.6593, Accuracy on test data: 0.9333333373069763\n", "Epoch [180/250], Loss: 0.6610, Accuracy on test data: 0.9333333373069763\n", "Epoch [190/250], Loss: 0.5954, Accuracy on test data: 0.9333333373069763\n", "Epoch [200/250], Loss: 0.5859, Accuracy on test data: 0.9333333373069763\n", "Epoch [210/250], Loss: 0.6564, Accuracy on test data: 0.9333333373069763\n", "Epoch [220/250], Loss: 0.6279, Accuracy on test data: 0.9333333373069763\n", "Epoch [230/250], Loss: 0.5801, Accuracy on test data: 
0.9333333373069763\n", "Epoch [240/250], Loss: 0.6023, Accuracy on test data: 0.9333333373069763\n", "Epoch [250/250], Loss: 0.6225, Accuracy on test data: 0.9333333373069763\n" ] } ], "source": [ "# Here we define how long we want to train the network\n", "num_epochs = 250\n", "# TODO This is our loss function. Which one do we need for classification: MSELoss or CrossEntropyLoss?\n", "criterion = _\n", "# TODO This is the algorithm used for optimizing our neural network parameters.\n", "optimizer = optim.Adam(_, lr=0.001)\n", "\n", "for epoch in range(num_epochs):\n", "    # TODO Draw a batch of data from the training data loader\n", "    X, Y = _\n", "\n", "    # TODO Forward pass\n", "    outputs = _\n", "\n", "    # TODO Compute the loss between the predicted outputs and the true labels\n", "    loss = _\n", "\n", "    # TODO First reset the gradients\n", "    _\n", "\n", "    # TODO Then compute the new gradients (backpropagation)\n", "    _\n", "\n", "    # TODO And finally update the network parameters (optimizer step)\n", "    _\n", "\n", "    # Print some metrics about the learning progress\n", "    if (epoch + 1) % 10 == 0:\n", "        accuracy = get_accuracy(net, dataloader_test)\n", "        print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Accuracy on test data:\", accuracy)" ] }, { "cell_type": "code", "execution_count": 21, "id": "109bbc85-19ed-4d6f-8ee2-8bc88c2b7500", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy after training: 0.9333333373069763\n" ] } ], "source": [ "# Let's check the accuracy after training the network\n", "print(\"Accuracy after training:\", get_accuracy(net, dataloader_test))" ] }, { "cell_type": "markdown", "id": "d0f98318-5f5b-4624-859e-89b1d838a23f", "metadata": {}, "source": [ "The accuracy has improved considerably, from about 37% before training to about 93% on the test data after training." 
] }, { "cell_type": "markdown", "id": "265a3be6-ec01-4bb1-a77a-27610176cf07", "metadata": {}, "source": [ "## Making predictions\n", "\n", "Now that we have a trained network, we can use it to make predictions for individual data vectors." ] }, { "cell_type": "code", "execution_count": 22, "id": "5e2c019c-88a5-4b92-88a4-e79ad2d9bebf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This is our data vector: tensor([-1.5065, 0.7888, -1.3402, -1.1838])\n", "And this is the corresponding label: 0\n" ] } ], "source": [ "# First, let's take the first data vector from our test dataset.\n", "X, y = dataset_test[0]\n", "# Create torch tensors\n", "X = torch.tensor(X).to(torch.float32)\n", "y = torch.tensor(y)\n", "print(\"This is our data vector:\", X)\n", "print(\"And this is the corresponding label:\", y.item())" ] }, { "cell_type": "code", "execution_count": 23, "id": "1c172493-bd66-44e7-8a4d-b06c889b836f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This is the prediction of the neural network: [[9.9939942e-01 5.9499586e-04 5.6224708e-06]]\n" ] } ], "source": [ "# Run the data vector through the network\n", "with torch.no_grad():\n", "    y_pred = net(X.reshape(1, -1))\n", "\n", "# Convert the prediction back to a NumPy array\n", "y_pred = y_pred.numpy()\n", "\n", "print(\"This is the prediction of the neural network:\", y_pred)" ] }, { "cell_type": "code", "execution_count": 24, "id": "c3a8f761-4182-4dda-ac70-78338e968d3f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeNklEQVR4nO3de3DV5Z348U8CJtHRBFlKAjS78dKirRYoSDa4rTpNzbYMu/yxsyy6wjJeVod20Oy2Ja3CWneNdioys6bL1i11p11HelntTmFxaJQ61lRWLjNq0S7eoNYEWLYJhi5pk+/vD38eG0mQkwt5El6vmfMHX57ne57zzJmT93zPOUlBlmVZAAAkpHCkFwAA8G4CBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOSMH+kFnIienp745S9/GWeddVYUFBSM9HIAgBOQZVkcPnw4pk6dGoWF+V0TGRWB8stf/jIqKytHehkAwADs27cv3v/+9+c1Z1QEyllnnRURbz3A0tLSEV4NAHAiOjo6orKyMvdzPB+jIlDefluntLRUoADAKDOQj2f4kCwAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACQn70B54oknYsGCBTF16tQoKCiIRx555D3nbN26NT760Y9GcXFxnH/++fHAAw8MYKkAwKki70Dp7OyMGTNmRFNT0wmNf+WVV2L+/PlxxRVXxK5du+Lmm2+O6667Lh599NG8FwsAnBry/mOBn/rUp+JTn/rUCY9ft25dnHPOOXHPPfdERMSFF14YTz75ZNx7771RV1eX790DAKeAYf8MSktLS9TW1vY6VldXFy0tLf3OOXr0aHR0dPS6AQCnjryvoOSrtbU1ysvLex0rLy+Pjo6O+PWvfx2nn376MXMaGxvj9ttvH+6lRURE1cqNJ+V+SNerd80f6SUA8C5JfounoaEh2tvbc7d9+/aN9JIAgJNo2K+gVFRURFtbW69jbW1tUVpa2ufVk4iI4uLiKC4uHu6lAQCJGvYrKDU1NdHc3Nzr2JYtW6Kmpma47xoAGKXyDpQ333wzdu3aFbt27YqIt75GvGvXrti7d29EvPX2zJIlS3Ljb7zxxnj55Zfj85//fLzwwgvxta99Lb7zne/ELbfcMjSPAAAYc/IOlGeeeSZmzZoVs2bNioiI+vr6mDVrVqxatSoiIt54441crEREnHPOObFx48bYsmVLzJgxI+655574l3/5F18xBgD6VZBlWTbSi3gvHR0dUVZWFu3t7VFaWjqk5/YtHnyLB2B4DObnd5Lf4gEATm0CBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkjOgQGlqaoqqqqooKSmJ6urq2LZt23HHr127NqZPnx6nn356VFZWxi233BL/93//N6AFAwBjX96BsmHDhqivr4/Vq1fHjh07YsaMGVFXVxf79+/vc/yDDz4YK1eujNWrV8fu3bvjG9/4RmzYsCG++MUvDnrxAMDYlHegrFmzJq6//vpYtmxZfOhDH4p169bFGWecEevXr+9z/FNPPRWXXnppXHXVVVFVVRVXXnllLF68+D2vugAAp668AqWrqyu2b98etbW175ygsDBqa2ujpaWlzznz5s2L7du354Lk5Zdfjk2bNsW
nP/3pQSwbABjLxucz+ODBg9Hd3R3l5eW9jpeXl8cLL7zQ55yrrroqDh48GH/0R38UWZbFb3/727jxxhuP+xbP0aNH4+jRo7l/d3R05LNMAGCUG/Zv8WzdujXuvPPO+NrXvhY7duyIf//3f4+NGzfGHXfc0e+cxsbGKCsry90qKyuHe5kAQELyuoIyadKkGDduXLS1tfU63tbWFhUVFX3Oue222+Kaa66J6667LiIiLr744ujs7IwbbrghvvSlL0Vh4bGN1NDQEPX19bl/d3R0iBQAOIXkdQWlqKgoZs+eHc3NzbljPT090dzcHDU1NX3OOXLkyDERMm7cuIiIyLKszznFxcVRWlra6wYAnDryuoISEVFfXx9Lly6NOXPmxNy5c2Pt2rXR2dkZy5Yti4iIJUuWxLRp06KxsTEiIhYsWBBr1qyJWbNmRXV1dezZsyduu+22WLBgQS5UAAB+V96BsmjRojhw4ECsWrUqWltbY+bMmbF58+bcB2f37t3b64rJrbfeGgUFBXHrrbfG66+/Hu973/tiwYIF8Q//8A9D9ygAgDGlIOvvfZaEdHR0RFlZWbS3tw/52z1VKzcO6fkYfV69a/5ILwFgTBrMz29/iwcASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AwqUpqamqKqqipKSkqiuro5t27Ydd/yvfvWrWL58eUyZMiWKi4vjgx/8YGzatGlACwYAxr7x+U7YsGFD1NfXx7p166K6ujrWrl0bdXV18eKLL8bkyZOPGd/V1RWf/OQnY/LkyfG9730vpk2bFq+99lpMmDBhKNYPAIxBeQfKmjVr4vrrr49ly5ZFRMS6deti48aNsX79+li5cuUx49evXx+HDh2Kp556Kk477bSIiKiqqhrcqgGAMS2vt3i6urpi+/btUVtb+84JCgujtrY2Wlpa+pzzH//xH1FTUxPLly+P8vLyuOiii+LOO++M7u7ufu/n6NGj0dHR0esGAJw68gqUgwcPRnd3d5SXl/c6Xl5eHq2trX3Oefnll+N73/tedHd3x6ZNm+K2226Le+65J/7+7/++3/tpbGyMsrKy3K2ysjKfZQIAo9ywf4unp6cnJk+eHF//+tdj9uzZsWjRovjSl74U69at63dOQ0NDtLe352779u0b7mUCAAnJ6zMokyZNinHjxkVbW1uv421tbVFRUdHnnClTpsRpp50W48aNyx278MILo7W1Nbq6uqKoqOiYOcXFxVFcXJzP0gCAMSSvKyhFRUUxe/bsaG5uzh3r6emJ5ubmqKmp6XPOpZdeGnv27Imenp7csZ///OcxZcqUPuMEACDvt3jq6+vj/vvvj3/913+N3bt3x0033RSdnZ25b/UsWbIkGhoacuNvuummOHToUKxYsSJ+/vOfx8aNG+POO++M5cuXD92jAADGlLy/Zrxo0aI4cOBArFq1KlpbW2PmzJmxefPm3Adn9+7dG4WF73RPZWVlPProo3HLLbfERz7ykZg2bVqsWLEivvCFLwzdowAAxpSCLMuykV7Ee+no6IiysrJob2+P0tLSIT131cqNQ3o+Rp9X75o/0ksAGJMG8/Pb3+IBAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQB
IjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkDCpSmpqaoqqqKkpKSqK6ujm3btp3QvIceeigKCgpi4cKFA7lbAOAUkXegbNiwIerr62P16tWxY8eOmDFjRtTV1cX+/fuPO+/VV1+Nv/3bv42PfexjA14sAHBqyDtQ1qxZE9dff30sW7YsPvShD8W6devijDPOiPXr1/c7p7u7O66++uq4/fbb49xzzx3UggGAsS+vQOnq6ort27dHbW3tOycoLIza2tpoaWnpd96Xv/zlmDx5clx77bUndD9Hjx6Njo6OXjcA4NSRV6AcPHgwuru7o7y8vNfx8vLyaG1t7XPOk08+Gd/4xjfi/vvvP+H7aWxsjLKystytsrIyn2UCAKPcsH6L5/Dhw3HNNdfE/fffH5MmTTrheQ0NDdHe3p677du3bxhXCQCkZnw+gydNmhTjxo2Ltra2Xsfb2tqioqLimPEvvfRSvPrqq7FgwYLcsZ6enrfuePz4ePHFF+O88847Zl5xcXEUFxfnszQAYAzJ6wpKUVFRzJ49O5qbm3PHenp6orm5OWpqao4Zf8EFF8Szzz4bu3btyt3+5E/+JK644orYtWuXt24AgD7ldQUlIqK+vj6WLl0ac+bMiblz58batWujs7Mzli1bFhERS5YsiWnTpkVjY2OUlJTERRdd1Gv+hAkTIiKOOQ4A8La8A2XRokVx4MCBWLVqVbS2tsbMmTNj8+bNuQ/O7t27NwoL/YJaAGDgCrIsy0Z6Ee+lo6MjysrKor29PUpLS4f03FUrNw7p+Rh9Xr1r/kgvAWBMGszPb5c6AIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkjOgQGlqaoqqqqooKSmJ6urq2LZtW79j77///vjYxz4WZ599dpx99tlRW1t73PEAAHkHyoYNG6K+vj5Wr14dO3bsiBkzZkRdXV3s37+/z/Fbt26NxYsXx+OPPx4tLS1RWVkZV155Zbz++uuDXjwAMDYVZFmW5TOhuro6LrnkkrjvvvsiIqKnpycqKyvjs5/9bKxcufI953d3d8fZZ58d9913XyxZsuSE7rOjoyPKysqivb09SktL81nue6pauXFIz8fo8+pd80d6CQBj0mB+fud1BaWrqyu2b98etbW175ygsDBqa2ujpaXlhM5x5MiR+M1vfhMTJ07sd8zRo0ejo6Oj1w0AOHXkFSgHDx6M7u7uKC8v73W8vLw8WltbT+gcX/jCF2Lq1Km9IufdGhsbo6ysLHerrKzMZ5kAwCh3Ur/Fc9ddd8VDDz0UDz/8cJSUlPQ7rqGhIdrb23O3ffv2ncRVAgAjbXw+gydNmhTjxo2Ltra2Xsfb2tqioqLiuHO/+tWvxl133RU/+tGP4iMf+chxxxYXF0dxcXE
+SwMAxpC8rqAUFRXF7Nmzo7m5OXesp6cnmpubo6ampt95X/nKV+KOO+6IzZs3x5w5cwa+WgDglJDXFZSIiPr6+li6dGnMmTMn5s6dG2vXro3Ozs5YtmxZREQsWbIkpk2bFo2NjRERcffdd8eqVaviwQcfjKqqqtxnVc4888w488wzh/ChAABjRd6BsmjRojhw4ECsWrUqWltbY+bMmbF58+bcB2f37t0bhYXvXJj5p3/6p+jq6oo/+7M/63We1atXx9/93d8NbvUAwJiU9+9BGQl+DwrDye9BARgeJ+33oAAAnAwCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkjOgQGlqaoqqqqooKSmJ6urq2LZt23HHf/e7340LLrggSkpK4uKLL45NmzYNaLEAwKkh70DZsGFD1NfXx+rVq2PHjh0xY8aMqKuri/379/c5/qmnnorFixfHtddeGzt37oyFCxfGwoUL47nnnhv04gGAsakgy7IsnwnV1dVxySWXxH333RcRET09PVFZWRmf/exnY+XKlceMX7RoUXR2dsYPf/jD3LE//MM/jJkzZ8a6detO6D47OjqirKws2tvbo7S0NJ/lvqeqlRuH9HyMPq/eNX+klwAwJg3m5/f4fAZ3dXXF9u3bo6GhIXessLAwamtro6Wlpc85LS0tUV9f3+tYXV1dPPLII/3ez9GjR+Po0aO5f7e3t0fEWw90qPUcPTLk52R0GY7nFQDvvL7meS0kIvIMlIMHD0Z3d3eUl5f3Ol5eXh4vvPBCn3NaW1v7HN/a2trv/TQ2Nsbtt99+zPHKysp8lgsnpGztSK8AYGw7fPhwlJWV5TUnr0A5WRoaGnpddenp6YlDhw7F7/3e70VBQUHueEdHR1RWVsa+ffuG/K2fU4U9HBz7N3j2cHDs3+DZw8E53v5lWRaHDx+OqVOn5n3evAJl0qRJMW7cuGhra+t1vK2tLSoqKvqcU1FRkdf4iIji4uIoLi7udWzChAn9ji8tLfWkGiR7ODj2b/Ds4eDYv8Gzh4PT3/7le+XkbXl9i6eoqChmz54dzc3NuWM9PT3R3NwcNTU1fc6pqanpNT4iYsuWLf2OBwDI+y2e+vr6WLp0acyZMyfmzp0ba9eujc7Ozli2bFlERCxZsiSmTZsWjY2NERGxYsWKuOyyy+Kee+6J+fPnx0MPPRTPPPNMfP3rXx/aRwIAjBl5B8qiRYviwIEDsWrVqmhtbY2ZM2fG5s2bcx+E3bt3bxQWvnNhZt68efHggw/GrbfeGl/84hfjAx/4QDzyyCNx0UUXDXrxxcXFsXr16mPeDuLE2cPBsX+DZw8Hx/4Nnj0cnOHav7x/DwoAwHDzt3gAgOQIFAAgOQIFAEiOQAEAkjPqAuXQoUNx9dVXR2lpaUyYMCGuvfbaePPNN4875/LLL4+CgoJetxtvvPEkrXjkNTU1RVVVVZSUlER1dXVs27btuOO/+93vxgUXXBAlJSVx8cUXx6ZNm07SStOUz/498MADxzzXSkpKTuJq0/LEE0/EggULYurUqVFQUHDcv8H1tq1bt8ZHP/rRKC4ujvPPPz8eeOCBYV9nyvLdw61btx7zHCwoKDjunxcZyxobG+OSSy6Js846KyZPnhwLFy6MF1988T3neR18y0D2b6heB0ddoFx99dXx/PP
Px5YtW+KHP/xhPPHEE3HDDTe857zrr78+3njjjdztK1/5yklY7cjbsGFD1NfXx+rVq2PHjh0xY8aMqKuri/379/c5/qmnnorFixfHtddeGzt37oyFCxfGwoUL47nnnjvJK09DvvsX8dZvU/zd59prr712Elecls7OzpgxY0Y0NTWd0PhXXnkl5s+fH1dccUXs2rUrbr755rjuuuvi0UcfHeaVpivfPXzbiy++2Ot5OHny5GFaYdp+/OMfx/Lly+OnP/1pbNmyJX7zm9/ElVdeGZ2dnf3O8Tr4joHsX8QQvQ5mo8jPfvazLCKy//qv/8od+8///M+soKAge/311/udd9lll2UrVqw4CStMz9y5c7Ply5fn/t3d3Z1NnTo1a2xs7HP8n//5n2fz58/vday6ujr767/+62FdZ6ry3b9vfvObWVlZ2Ula3egSEdnDDz983DGf//znsw9/+MO9ji1atCirq6sbxpWNHieyh48//ngWEdn//u//npQ1jTb79+/PIiL78Y9/3O8Yr4P9O5H9G6rXwVF1BaWlpSUmTJgQc+bMyR2rra2NwsLCePrpp48799/+7d9i0qRJcdFFF0VDQ0McOXJkuJc74rq6umL79u1RW1ubO1ZYWBi1tbXR0tLS55yWlpZe4yMi6urq+h0/lg1k/yIi3nzzzfiDP/iDqKysjD/90z+N559//mQsd0zw/Bs6M2fOjClTpsQnP/nJ+MlPfjLSy0lGe3t7RERMnDix3zGeh/07kf2LGJrXwVEVKK2trcdcphw/fnxMnDjxuO+vXnXVVfHtb387Hn/88WhoaIhvfetb8Zd/+ZfDvdwRd/Dgweju7s79lt+3lZeX97tfra2teY0fywayf9OnT4/169fHD37wg/j2t78dPT09MW/evPjFL35xMpY86vX3/Ovo6Ihf//rXI7Sq0WXKlCmxbt26+P73vx/f//73o7KyMi6//PLYsWPHSC9txPX09MTNN98cl1566XF/m7nXwb6d6P4N1etg3r/qfjisXLky7r777uOO2b1794DP/7ufUbn44otjypQp8YlPfCJeeumlOO+88wZ8Xni3mpqaXn8Ic968eXHhhRfGP//zP8cdd9wxgivjVDF9+vSYPn167t/z5s2Ll156Ke6999741re+NYIrG3nLly+P5557Lp588smRXsqodKL7N1Svg0kEyt/8zd/EX/3VXx13zLnnnhsVFRXHfDjxt7/9bRw6dCgqKipO+P6qq6sjImLPnj1jOlAmTZoU48aNi7a2tl7H29ra+t2vioqKvMaPZQPZv3c77bTTYtasWbFnz57hWOKY09/zr7S0NE4//fQRWtXoN3fu3FP+h/JnPvOZ3Bcr3v/+9x93rNfBY+Wzf+820NfBJN7ied/73hcXXHDBcW9FRUVRU1MTv/rVr2L79u25uY899lj09PTkouNE7Nq1KyLeuhQ6lhUVFcXs2bOjubk5d6ynpyeam5t71e3vqqmp6TU+ImLLli39jh/LBrJ/79bd3R3PPvvsmH+uDRXPv+Gxa9euU/Y5mGVZfOYzn4mHH344HnvssTjnnHPec47n4TsGsn/vNuDXwUF/zPYk++M//uNs1qxZ2dNPP509+eST2Qc+8IFs8eLFuf//xS9+kU2fPj17+umnsyzLsj179mRf/vKXs2eeeSZ75ZVXsh/84AfZueeem3384x8fqYdwUj300ENZcXFx9sADD2Q/+9nPshtuuCGbMGFC1trammVZll1zzTXZypUrc+N/8pOfZOPHj8+++tWvZrt3785Wr16dnXbaadmzzz47Ug9hROW7f7fffnv26KOPZi+99FK2ffv27C/+4i+ykpKS7Pnnnx+phzCiDh8+nO3cuTPbuXNnFhHZmjVrsp07d2avvfZalmVZtnLlyuyaa67JjX/55ZezM844I/vc5z6X7d69O2tqasrGjRuXbd68eaQewojLdw/vvffe7JFHHsn++7//O3v22WezFStWZIWFhdmPfvS
jkXoII+qmm27KysrKsq1bt2ZvvPFG7nbkyJHcGK+D/RvI/g3V6+CoC5T/+Z//yRYvXpydeeaZWWlpabZs2bLs8OHDuf9/5ZVXsojIHn/88SzLsmzv3r3Zxz/+8WzixIlZcXFxdv7552ef+9znsvb29hF6BCffP/7jP2a///u/nxUVFWVz587NfvrTn+b+77LLLsuWLl3aa/x3vvOd7IMf/GBWVFSUffjDH842btx4kleclnz27+abb86NLS8vzz796U9nO3bsGIFVp+Htr7y++/b2ni1dujS77LLLjpkzc+bMrKioKDv33HOzb37zmyd93SnJdw/vvvvu7LzzzstKSkqyiRMnZpdffnn22GOPjcziE9DX3kVEr+eV18H+DWT/hup1sOD/LwAAIBlJfAYFAOB3CRQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkvP/APocnaEGSKbaAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Let's plot the predicted class probabilities as a bar chart\n", "possible_classes = [0, 1, 2]\n", "# We plot the possible classes on the x-axis and the probabilities on the y-axis\n", "plt.bar(possible_classes, y_pred.squeeze())" ] }, { "cell_type": "markdown", "id": "65b51e21-84af-4587-a3f4-bb0500879768", "metadata": {}, "source": [ "As you can see, the network assigns label 0 a probability of nearly 100% (about 99.94%), which matches the true label. The other probabilities are so small that they can't even be seen in the bar chart." ] }, { "cell_type": "code", "execution_count": null, "id": "f89c655d-d522-4193-8af5-eef27e6a32a9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }