Skip to content
Snippets Groups Projects
GINN_P3.ipynb 119 KiB
Newer Older
  • Learn to ignore specific revisions
  • Friedrich Heitzer's avatar
    Friedrich Heitzer committed
     "nbformat": 4,
     "nbformat_minor": 0,
     "metadata": {
      "colab": {
       "provenance": [],
       "gpuType": "T4",
       "toc_visible": true
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
      "kernelspec": {
       "name": "python3",
       "language": "python",
       "display_name": "Python 3 (ipykernel)"
      },
      "language_info": {
       "name": "python"
      },
      "accelerator": "GPU"
     },
     "cells": [
      {
       "cell_type": "markdown",
       "source": [
        "# 0.  Imports und Helper"
       ],
       "metadata": {
        "id": "194EMZeTSLIk"
       }
      },
      {
       "cell_type": "code",
       "source": [
        "!pip install torchinfo"
       ],
       "metadata": {
        "id": "6IoTrfAlzktH",
        "colab": {
         "base_uri": "https://localhost:8080/"
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
        "outputId": "f9bdc25e-4895-436b-91ab-00d3a38af883",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:19:19.732145Z",
         "start_time": "2023-11-20T22:19:18.131897Z"
        }
       },
       "execution_count": 1,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "Collecting torchinfo\r\n",
          "  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)\r\n",
          "Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)\r\n",
          "Installing collected packages: torchinfo\r\n",
          "Successfully installed torchinfo-1.8.0\r\n"
         ]
        }
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 4,
       "metadata": {
        "id": "dPt0DPgfLjEQ",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:29:17.347227Z",
         "start_time": "2023-11-20T22:29:15.452701Z"
        }
       },
       "outputs": [],
       "source": [
        "# Imports\n",
        "import copy\n",
        "\n",
        "import ipywidgets as widgets\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "import time\n",
        "import torch\n",
        "import torchvision\n",
        "#import torchvision.datasets as datasets\n",
        "import torch.nn.functional as F\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import tqdm\n",
        "\n",
        "import random\n",
        "import keras.datasets.imdb\n",
        "\n",
        "from torch.autograd import Variable\n",
        "from tqdm.auto import tqdm as tqdmauto"
       ]
      },
      {
       "cell_type": "code",
       "source": [
        "def set_seed(seed=None, seed_torch=True):\n",
        "  \"\"\"\n",
        "  Handles variability by controlling sources of randomness\n",
        "  through set seed values\n",
        "\n",
        "  Args:\n",
        "    seed: Integer\n",
        "      Set the seed value to given integer.\n",
        "      If no seed, set seed value to random integer in the range 2^32\n",
        "    seed_torch: Bool\n",
        "      Seeds the random number generator for all devices to\n",
        "      offer some guarantees on reproducibility\n",
        "\n",
        "  Returns:\n",
        "    Nothing\n",
        "  \"\"\"\n",
        "  if seed is None:\n",
        "    seed = np.random.choice(2 ** 32)\n",
        "  random.seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  if seed_torch:\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "  print(f'Random seed {seed} has been set.')\n",
        "SEED = 2021\n",
        "set_seed(seed=SEED)\n",
        "DEVICE = \"mps\"\n",
        "\n",
        "def zero_grad(params):\n",
        "  \"\"\"\n",
        "  Clear gradients as they accumulate on successive backward calls\n",
        "\n",
        "  Args:\n",
        "    params: an iterator over tensors\n",
        "      i.e., updating the Weights and biases\n",
        "\n",
        "  Returns:\n",
        "    Nothing\n",
        "  \"\"\"\n",
        "  for par in params:\n",
        "    if not(par.grad is None):\n",
        "      par.grad.data.zero_()\n",
        "\n",
        "\n",
        "def print_params(model):\n",
        "  \"\"\"\n",
        "  Lists the name and current value of the model's\n",
        "  named parameters\n",
        "\n",
        "  Args:\n",
        "    model: an nn.Module inherited model\n",
        "      Represents the ML/DL model\n",
        "\n",
        "  Returns:\n",
        "    Nothing\n",
        "  \"\"\"\n",
        "  for name, param in model.named_parameters():\n",
        "    if param.requires_grad:\n",
        "      print(name, param.data)\n",
        "\n",
        "def sample_minibatch(input_data, target_data, num_points=100):\n",
        "  \"\"\"\n",
        "  Sample a minibatch of size num_point from the provided input-target data\n",
        "\n",
        "  Args:\n",
        "    input_data: Tensor\n",
        "      Multi-dimensional tensor containing the input data\n",
        "    target_data: Tensor\n",
        "      1D tensor containing the class labels\n",
        "    num_points: Integer\n",
        "      Number of elements to be included in minibatch with default=100\n",
        "\n",
        "  Returns:\n",
        "    batch_inputs: Tensor\n",
        "      Minibatch inputs\n",
        "    batch_targets: Tensor\n",
        "      Minibatch targets\n",
        "  \"\"\"\n",
        "  # Sample a collection of IID indices from the existing data\n",
        "  batch_indices = np.random.choice(len(input_data), num_points)\n",
        "  # Use batch_indices to extract entries from the input and target data tensors\n",
        "  batch_inputs = input_data[batch_indices, :]\n",
        "  batch_targets = target_data[batch_indices]\n",
        "\n",
        "  return batch_inputs, batch_targets\n",
        "\n",
        "\n",
        "def gradient_update(loss, params, lr=1e-3):\n",
        "  \"\"\"\n",
        "  Perform a gradient descent update on a given loss over a collection of parameters\n",
        "\n",
        "  Args:\n",
        "    loss: Tensor\n",
        "      A scalar tensor containing the loss through which the gradient will be computed\n",
        "    params: List of iterables\n",
        "      Collection of parameters with respect to which we compute gradients\n",
        "    lr: Float\n",
        "      Scalar specifying the learning rate or step-size for the update\n",
        "\n",
        "  Returns:\n",
        "    Nothing\n",
        "  \"\"\"\n",
        "  # Clear up gradients as Pytorch automatically accumulates gradients from\n",
        "  # successive backward calls\n",
        "  zero_grad(params)\n",
        "\n",
        "  # Compute gradients on given objective\n",
        "  loss.backward()\n",
        "\n",
        "  with torch.no_grad():\n",
        "    for par in params:\n",
        "      # Here we work with the 'data' attribute of the parameter rather than the\n",
        "      # parameter itself.\n",
        "      # Hence - use the learning rate and the parameter's .grad.data attribute to perform an update\n",
        "      par.data -= lr * par.grad.data"
       ],
       "metadata": {
        "id": "U6niQp1RNHxp",
        "colab": {
         "base_uri": "https://localhost:8080/"
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
        "outputId": "f07ba931-b6c5-4157-df21-91a429ff70a1",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:30:27.238079Z",
         "start_time": "2023-11-20T22:30:27.232238Z"
        }
       },
       "execution_count": 7,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "Random seed 2021 has been set.\n"
         ]
        }
       ]
      },
      {
       "cell_type": "code",
       "source": [
        "(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=10000, maxlen=250,)\n",
        "def vectorize_sequences(sequences, dimension=10000):\n",
        "    # all zero matrix of shape (len(sequences), dimension)\n",
        "    result = np.zeros((len(sequences), dimension))\n",
        "    for i,sequence in enumerate(sequences):\n",
        "        result[i, sequence] = 1\n",
        "    return result\n",
        "\n",
        "\n",
        "x_train = vectorize_sequences(x_train)\n",
        "x_test = vectorize_sequences(x_test)\n",
        "#x_train = np.expand_dims(x_train, -1)\n",
        "#x_test = np.expand_dims(x_test, -1)\n",
        "\n",
        "x_train = Variable(torch.from_numpy(x_train)).float().to(DEVICE)\n",
        "y_train = Variable(torch.from_numpy(y_train)).long().to(DEVICE)\n",
        "x_test  = Variable(torch.from_numpy(x_test)).float().to(DEVICE)\n",
        "y_test  = Variable(torch.from_numpy(y_test)).long().to(DEVICE)\n",
        "\n",
        "print(\"x_train shape:\", x_train.shape)\n",
        "print(\"y_train shape:\", y_train.shape)\n",
        "print(x_train.shape[0], \"train samples\")\n",
        "print(x_test.shape[0], \"test samples\")"
       ],
       "metadata": {
        "id": "tNDNF10dyqUm",
        "colab": {
         "base_uri": "https://localhost:8080/"
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
        "outputId": "f654597c-411a-45ac-f694-aba79e025aa8",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:30:31.684372Z",
         "start_time": "2023-11-20T22:30:29.176089Z"
        }
       },
       "execution_count": 8,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "x_train shape: torch.Size([17121, 10000])\n",
          "y_train shape: torch.Size([17121])\n",
          "17121 train samples\n",
          "17588 test samples\n"
         ]
        }
       ]
      },
      {
       "cell_type": "markdown",
       "source": [
        "# 1.  Softmax Implementieren\n",
        "Implementieren Sie die Softmax Funktion mit Numpy und stellen Sie zunächst sicher, dass diese die selben Ergebnisse liefert wie die Pytorch-Funktion im Beispiel. Vergleichen Sie dann Ihre Implementierungen mit anderen Gruppen und diskutieren Sie auch über Performance und numerische Stabilität. Erstellen Sie ein kleines Benchmark, um Performance und numerische Stabilität zu testen."
       ],
       "metadata": {
        "id": "tKJZz5YsSSyT"
       }
      },
      {
       "cell_type": "markdown",
       "source": [],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "source": [
        "# If values get too large this function becomes numerically unstable\n",
        "def softmax(x):\n",
        "    return np.exp(x)/np.exp(x).sum()\n",
        "\n",
        "def softmax_stable(x):\n",
        "    return np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()"
       ],
       "metadata": {
        "id": "_80I03V8ogds",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:37:21.783849Z",
         "start_time": "2023-11-20T22:37:21.779265Z"
        }
       },
       "execution_count": 9,
       "outputs": []
      },
      {
       "cell_type": "code",
       "source": [
        "array_small = np.array([2, 11, 7])\n",
        "array_large = np.array([555, 999, 111])\n",
        "\n",
        "print(softmax(array_small).round(3))\n",
        "print(softmax(array_large))\n",
        "print(softmax_stable(array_large).round(2))"
       ],
       "metadata": {
        "id": "x0VacAxQu5JS",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:38:06.327953Z",
         "start_time": "2023-11-20T22:38:06.325348Z"
        }
       },
       "execution_count": 15,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "[0.    0.982 0.018]\n",
          "[ 0. nan  0.]\n",
          "[0. 1. 0.]\n"
         ]
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stderr",
         "output_type": "stream",
         "text": [
          "/var/folders/_v/gtvs0k6x72bf23jbq0yd1wh40000gn/T/ipykernel_78955/2076069751.py:3: RuntimeWarning: overflow encountered in exp\n",
          "  return np.exp(x)/np.exp(x).sum()\n",
          "/var/folders/_v/gtvs0k6x72bf23jbq0yd1wh40000gn/T/ipykernel_78955/2076069751.py:3: RuntimeWarning: invalid value encountered in divide\n",
          "  return np.exp(x)/np.exp(x).sum()\n"
         ]
        }
       ]
      },
      {
       "cell_type": "markdown",
       "source": [
        "# 2.  Regularisierung Implementieren\n",
        "\n",
        "Unten finden Sie einen Pytorch-SGD Schritt mit eingebauter L2-Regularisierung und ohne. Interpretieren Sie die unterschiedlichen Ausgaben. Modifizieren Sie den ersten Codabschnitt mit einer eigenen L2-Regularisierung so, dass identische Ergebnisse erzeugt werden. Sie können dazu die noch nicht verwendete und noch falsch definierte Variable \"regtermwrong\" umdefinieren und zu einem späteren Zeitpunkt im Code darauf zurückgreifen. ACHTUNG: weight_decay*2=lambda."
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "source": [
        "#Datendefinition\n",
        "np.random.seed(123)\n",
        "np.set_printoptions(8, suppress=True)\n",
        "\n",
        "x_numpy = np.random.random((3, 4)).astype(np.double)\n",
        "w_numpy = np.random.random((4, 5)).astype(np.double)\n",
        "w_numpy[0,0] =9.9\n",
        "x_torch = torch.tensor(x_numpy, requires_grad=True)\n"
       ],
       "metadata": {
        "id": "S5XEpjWFTTzi",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:39:36.082809Z",
         "start_time": "2023-11-20T22:39:36.080252Z"
        }
       },
       "execution_count": 16,
       "outputs": []
      },
      {
       "cell_type": "code",
       "source": [
        "# mit Regularisierung\n",
        "w_torch = torch.tensor(w_numpy, requires_grad=True)\n",
        "print('Original weights', w_torch)\n",
        "\n",
        "lr = 0.1\n",
        "sgd = torch.optim.SGD([w_torch], lr=lr, weight_decay=0)\n",
        "#regtermwrong = max(p.max() for p in w_torch)\n",
        "regterm = sum(sum(p) for p in w_torch)\n",
        "y_torch = torch.matmul(x_torch, w_torch)\n",
        "loss = y_torch.sum() + 4 * regterm\n",
        "\n",
        "\n",
        "sgd.zero_grad()\n",
        "loss.backward()\n",
        "sgd.step()\n",
        "\n",
        "w_grad = w_torch.grad.data.numpy()\n",
        "print('0 weight decay', w_torch)\n"
       ],
       "metadata": {
        "id": "WiQW-Y4VkH7v",
        "ExecuteTime": {
         "end_time": "2023-11-20T22:39:38.992708Z",
         "start_time": "2023-11-20T22:39:38.699008Z"
        }
       },
       "execution_count": 17,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "Original weights tensor([[9.9000, 0.0597, 0.3980, 0.7380, 0.1825],\n",
          "        [0.1755, 0.5316, 0.5318, 0.6344, 0.8494],\n",
          "        [0.7245, 0.6110, 0.7224, 0.3230, 0.3618],\n",
          "        [0.2283, 0.2937, 0.6310, 0.0921, 0.4337]], dtype=torch.float64,\n",
          "       requires_grad=True)\n"
         ]
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "ename": "TypeError",
         "evalue": "cannot inherit non-frozen dataclass from a frozen one",
         "output_type": "error",
         "traceback": [
          "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
          "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
          "Cell \u001B[0;32mIn[17], line 6\u001B[0m\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mOriginal weights\u001B[39m\u001B[38;5;124m'\u001B[39m, w_torch)\n\u001B[1;32m      5\u001B[0m lr \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0.1\u001B[39m\n\u001B[0;32m----> 6\u001B[0m sgd \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43moptim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mSGD\u001B[49m\u001B[43m(\u001B[49m\u001B[43m[\u001B[49m\u001B[43mw_torch\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweight_decay\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      7\u001B[0m regtermwrong \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mmax\u001B[39m(p\u001B[38;5;241m.\u001B[39mmax() \u001B[38;5;28;01mfor\u001B[39;00m p \u001B[38;5;129;01min\u001B[39;00m w_torch)\n\u001B[1;32m      8\u001B[0m y_torch \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mmatmul(x_torch, w_torch)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/sgd.py:26\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, lr, momentum, dampening, weight_decay, nesterov, maximize, foreach, differentiable)\u001B[0m\n\u001B[1;32m     21\u001B[0m defaults \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m(lr\u001B[38;5;241m=\u001B[39mlr, momentum\u001B[38;5;241m=\u001B[39mmomentum, dampening\u001B[38;5;241m=\u001B[39mdampening,\n\u001B[1;32m     22\u001B[0m                 weight_decay\u001B[38;5;241m=\u001B[39mweight_decay, nesterov\u001B[38;5;241m=\u001B[39mnesterov,\n\u001B[1;32m     23\u001B[0m                 maximize\u001B[38;5;241m=\u001B[39mmaximize, foreach\u001B[38;5;241m=\u001B[39mforeach,\n\u001B[1;32m     24\u001B[0m                 differentiable\u001B[38;5;241m=\u001B[39mdifferentiable)\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m nesterov \u001B[38;5;129;01mand\u001B[39;00m (momentum \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m dampening \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m):\n\u001B[0;32m---> 26\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNesterov momentum requires a momentum and zero dampening\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39m\u001B[38;5;21m__init__\u001B[39m(params, defaults)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/optimizer.py:266\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, defaults)\u001B[0m\n\u001B[1;32m    262\u001B[0m \u001B[38;5;129m@staticmethod\u001B[39m\n\u001B[1;32m    263\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mprofile_hook_step\u001B[39m(func):\n\u001B[1;32m    265\u001B[0m     \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m--> 266\u001B[0m     \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrapper\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m    267\u001B[0m         \u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m_ \u001B[38;5;241m=\u001B[39m args\n\u001B[1;32m    268\u001B[0m         profile_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mOptimizer.step#\u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.step\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__class__\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_compile.py:22\u001B[0m, in \u001B[0;36minner\u001B[0;34m(*args, **kwargs)\u001B[0m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m allowed_functions, convert_frame, eval_frame, resume_execution\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m list_backends, register_backend\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconvert_frame\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m replay\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:29\u001B[0m\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mguards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CheckFunctionManager, GuardedCode\n\u001B[1;32m     28\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mhooks\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Hooks\n\u001B[0;32m---> 29\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01moutput_graph\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m OutputGraph\n\u001B[1;32m     30\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mreplay_record\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ExecutionRecord\n\u001B[1;32m     31\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_convert\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InstructionTranslator\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:23\u001B[0m\n\u001B[1;32m     14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_guards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m     15\u001B[0m     Checkpointable,\n\u001B[1;32m     16\u001B[0m     Guard,\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m     19\u001B[0m     TracingContext,\n\u001B[1;32m     20\u001B[0m )\n\u001B[1;32m     21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfx\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexperimental\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_shapes\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ShapeEnv\n\u001B[0;32m---> 23\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config, logging \u001B[38;5;28;01mas\u001B[39;00m torchdynamo_logging, variables\n\u001B[1;32m     24\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CompiledFn, CompilerFn\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbytecode_transformation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m create_instruction, Instruction, unique_id\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m VariableTracker\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbuiltin\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BuiltinVariable\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconstant\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ConstantVariable, EnumVariable\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/base.py:6\u001B[0m\n\u001B[1;32m      4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m variables\n\u001B[1;32m      5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m unimplemented\n\u001B[0;32m----> 6\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msource\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AttrSource, Source\n\u001B[1;32m      7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m dict_values, identity, istype, odict_values\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mMutableLocal\u001B[39;00m:\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/source.py:49\u001B[0m\n\u001B[1;32m     39\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mis_input_source\u001B[39m(source):\n\u001B[1;32m     40\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m source\u001B[38;5;241m.\u001B[39mguard_source() \u001B[38;5;129;01min\u001B[39;00m [\n\u001B[1;32m     41\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL,\n\u001B[1;32m     42\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL,\n\u001B[1;32m     43\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL_NN_MODULE,\n\u001B[1;32m     44\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL_NN_MODULE,\n\u001B[1;32m     45\u001B[0m     ]\n\u001B[1;32m     48\u001B[0m \u001B[38;5;129;43m@dataclasses\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdataclass\u001B[49m\n\u001B[0;32m---> 49\u001B[0m \u001B[38;5;28;43;01mclass\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43;01mLocalSource\u001B[39;49;00m\u001B[43m(\u001B[49m\u001B[43mSource\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n\u001B[1;32m     50\u001B[0m \u001B[43m    \u001B[49m\u001B[43mlocal_name\u001B[49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mstr\u001B[39;49m\n\u001B[1;32m     52\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43;01mdef\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43mreconstruct\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcodegen\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1184\u001B[0m, in \u001B[0;36mdataclass\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m   1181\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m wrap\n\u001B[1;32m   1183\u001B[0m \u001B[38;5;66;03m# We're called as @dataclass without parens.\u001B[39;00m\n\u001B[0;32m-> 1184\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mwrap\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1175\u001B[0m, in \u001B[0;36mdataclass.<locals>.wrap\u001B[0;34m(cls)\u001B[0m\n\u001B[1;32m   1174\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrap\u001B[39m(\u001B[38;5;28mcls\u001B[39m):\n\u001B[0;32m-> 1175\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_process_class\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minit\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mrepr\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43meq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43morder\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43munsafe_hash\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   1176\u001B[0m \u001B[43m                          \u001B[49m\u001B[43mfrozen\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmatch_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkw_only\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mslots\u001B[49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:985\u001B[0m, in \u001B[0;36m_process_class\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m    982\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_dataclass_bases:\n\u001B[1;32m    983\u001B[0m     \u001B[38;5;66;03m# Raise an exception if any of our bases are frozen, but we're not.\u001B[39;00m\n\u001B[1;32m    984\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m frozen:\n\u001B[0;32m--> 985\u001B[0m         \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mcannot inherit non-frozen dataclass from a \u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m    986\u001B[0m                         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfrozen one\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m    988\u001B[0m     \u001B[38;5;66;03m# Raise an exception if we're frozen, but none of our bases are.\u001B[39;00m\n\u001B[1;32m    989\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m frozen:\n",
          "\u001B[0;31mTypeError\u001B[0m: cannot inherit non-frozen dataclass from a frozen one"
         ]
        }
       ]
      },
      {
       "cell_type": "code",
       "execution_count": 18,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "name": "stdout",
         "output_type": "stream",
         "text": [
          "Reset Original weights tensor([[9.9000, 0.0597, 0.3980, 0.7380, 0.1825],\n",
          "        [0.1755, 0.5316, 0.5318, 0.6344, 0.8494],\n",
          "        [0.7245, 0.6110, 0.7224, 0.3230, 0.3618],\n",
          "        [0.2283, 0.2937, 0.6310, 0.0921, 0.4337]], dtype=torch.float64,\n",
          "       requires_grad=True)\n"
         ]
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "ename": "TypeError",
         "evalue": "cannot inherit non-frozen dataclass from a frozen one",
         "output_type": "error",
         "traceback": [
          "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
          "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
          "Cell \u001B[0;32mIn[18], line 8\u001B[0m\n\u001B[1;32m      4\u001B[0m w_torch \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mtensor(w_numpy, requires_grad\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[1;32m      6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mReset Original weights\u001B[39m\u001B[38;5;124m'\u001B[39m, w_torch)\n\u001B[0;32m----> 8\u001B[0m sgd \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43moptim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mSGD\u001B[49m\u001B[43m(\u001B[49m\u001B[43m[\u001B[49m\u001B[43mw_torch\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mlr\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweight_decay\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m     10\u001B[0m y_torch \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mmatmul(x_torch, w_torch)\n\u001B[1;32m     11\u001B[0m loss \u001B[38;5;241m=\u001B[39m y_torch\u001B[38;5;241m.\u001B[39msum()\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/sgd.py:26\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, lr, momentum, dampening, weight_decay, nesterov, maximize, foreach, differentiable)\u001B[0m\n\u001B[1;32m     21\u001B[0m defaults \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m(lr\u001B[38;5;241m=\u001B[39mlr, momentum\u001B[38;5;241m=\u001B[39mmomentum, dampening\u001B[38;5;241m=\u001B[39mdampening,\n\u001B[1;32m     22\u001B[0m                 weight_decay\u001B[38;5;241m=\u001B[39mweight_decay, nesterov\u001B[38;5;241m=\u001B[39mnesterov,\n\u001B[1;32m     23\u001B[0m                 maximize\u001B[38;5;241m=\u001B[39mmaximize, foreach\u001B[38;5;241m=\u001B[39mforeach,\n\u001B[1;32m     24\u001B[0m                 differentiable\u001B[38;5;241m=\u001B[39mdifferentiable)\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m nesterov \u001B[38;5;129;01mand\u001B[39;00m (momentum \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m dampening \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m):\n\u001B[0;32m---> 26\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNesterov momentum requires a momentum and zero dampening\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39m\u001B[38;5;21m__init__\u001B[39m(params, defaults)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/optimizer.py:266\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, defaults)\u001B[0m\n\u001B[1;32m    262\u001B[0m \u001B[38;5;129m@staticmethod\u001B[39m\n\u001B[1;32m    263\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mprofile_hook_step\u001B[39m(func):\n\u001B[1;32m    265\u001B[0m     \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m--> 266\u001B[0m     \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrapper\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m    267\u001B[0m         \u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m_ \u001B[38;5;241m=\u001B[39m args\n\u001B[1;32m    268\u001B[0m         profile_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mOptimizer.step#\u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.step\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__class__\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_compile.py:22\u001B[0m, in \u001B[0;36minner\u001B[0;34m(*args, **kwargs)\u001B[0m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m allowed_functions, convert_frame, eval_frame, resume_execution\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m list_backends, register_backend\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconvert_frame\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m replay\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:29\u001B[0m\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mguards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CheckFunctionManager, GuardedCode\n\u001B[1;32m     28\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mhooks\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Hooks\n\u001B[0;32m---> 29\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01moutput_graph\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m OutputGraph\n\u001B[1;32m     30\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mreplay_record\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ExecutionRecord\n\u001B[1;32m     31\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_convert\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InstructionTranslator\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:23\u001B[0m\n\u001B[1;32m     14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_guards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m     15\u001B[0m     Checkpointable,\n\u001B[1;32m     16\u001B[0m     Guard,\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m     19\u001B[0m     TracingContext,\n\u001B[1;32m     20\u001B[0m )\n\u001B[1;32m     21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfx\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexperimental\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_shapes\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ShapeEnv\n\u001B[0;32m---> 23\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config, logging \u001B[38;5;28;01mas\u001B[39;00m torchdynamo_logging, variables\n\u001B[1;32m     24\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CompiledFn, CompilerFn\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbytecode_transformation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m create_instruction, Instruction, unique_id\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m VariableTracker\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbuiltin\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BuiltinVariable\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconstant\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ConstantVariable, EnumVariable\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/base.py:6\u001B[0m\n\u001B[1;32m      4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m variables\n\u001B[1;32m      5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m unimplemented\n\u001B[0;32m----> 6\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msource\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AttrSource, Source\n\u001B[1;32m      7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m dict_values, identity, istype, odict_values\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mMutableLocal\u001B[39;00m:\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/source.py:49\u001B[0m\n\u001B[1;32m     39\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mis_input_source\u001B[39m(source):\n\u001B[1;32m     40\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m source\u001B[38;5;241m.\u001B[39mguard_source() \u001B[38;5;129;01min\u001B[39;00m [\n\u001B[1;32m     41\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL,\n\u001B[1;32m     42\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL,\n\u001B[1;32m     43\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL_NN_MODULE,\n\u001B[1;32m     44\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL_NN_MODULE,\n\u001B[1;32m     45\u001B[0m     ]\n\u001B[1;32m     48\u001B[0m \u001B[38;5;129;43m@dataclasses\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdataclass\u001B[49m\n\u001B[0;32m---> 49\u001B[0m \u001B[38;5;28;43;01mclass\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43;01mLocalSource\u001B[39;49;00m\u001B[43m(\u001B[49m\u001B[43mSource\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n\u001B[1;32m     50\u001B[0m \u001B[43m    \u001B[49m\u001B[43mlocal_name\u001B[49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mstr\u001B[39;49m\n\u001B[1;32m     52\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43;01mdef\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43mreconstruct\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcodegen\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1184\u001B[0m, in \u001B[0;36mdataclass\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m   1181\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m wrap\n\u001B[1;32m   1183\u001B[0m \u001B[38;5;66;03m# We're called as @dataclass without parens.\u001B[39;00m\n\u001B[0;32m-> 1184\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mwrap\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1175\u001B[0m, in \u001B[0;36mdataclass.<locals>.wrap\u001B[0;34m(cls)\u001B[0m\n\u001B[1;32m   1174\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrap\u001B[39m(\u001B[38;5;28mcls\u001B[39m):\n\u001B[0;32m-> 1175\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_process_class\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minit\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mrepr\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43meq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43morder\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43munsafe_hash\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   1176\u001B[0m \u001B[43m                          \u001B[49m\u001B[43mfrozen\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmatch_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkw_only\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mslots\u001B[49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:985\u001B[0m, in \u001B[0;36m_process_class\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m    982\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_dataclass_bases:\n\u001B[1;32m    983\u001B[0m     \u001B[38;5;66;03m# Raise an exception if any of our bases are frozen, but we're not.\u001B[39;00m\n\u001B[1;32m    984\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m frozen:\n\u001B[0;32m--> 985\u001B[0m         \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mcannot inherit non-frozen dataclass from a \u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m    986\u001B[0m                         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfrozen one\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m    988\u001B[0m     \u001B[38;5;66;03m# Raise an exception if we're frozen, but none of our bases are.\u001B[39;00m\n\u001B[1;32m    989\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m frozen:\n",
          "\u001B[0;31mTypeError\u001B[0m: cannot inherit non-frozen dataclass from a frozen one"
         ]
        }
       ],
       "source": [
        "# With regularization: SGD(weight_decay=2) adds 2 * w to the gradient inside step() before updating w\n",
        "\n",
        "# Reset w to the values from w_numpy\n",
        "w_torch = torch.tensor(w_numpy, requires_grad=True)\n",
        "\n",
        "print('Reset Original weights', w_torch)\n",
        "\n",
        "sgd = torch.optim.SGD([w_torch], lr=lr, weight_decay=2)\n",
        "\n",
        "y_torch = torch.matmul(x_torch, w_torch)\n",
        "loss = y_torch.sum()\n",
        "\n",
        "sgd.zero_grad()\n",
        "loss.backward()\n",
        "sgd.step()\n",
        "\n",
        "# w_torch.grad holds only dloss/dw; the weight-decay term is added internally by step().\n",
        "# .detach() is the recommended replacement for the legacy .data attribute.\n",
        "w_grad = w_torch.grad.detach().numpy()\n",
        "print('1 weight decay', w_torch)"
       ],
       "metadata": {
        "collapsed": false,
        "ExecuteTime": {
         "end_time": "2023-11-20T22:39:58.623547Z",
         "start_time": "2023-11-20T22:39:58.475069Z"
        }
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "# 3.&nbsp; Einfaches MLP in PyTorch\n",
        "Machen Sie sich ein wenig mit dem `IMDB`-Datensatz und den für Sie erstellten Datenstrukturen `x_train`, `y_train`, `x_test` und `y_test` vertraut."
       ],
       "metadata": {
        "id": "mLeOKvUxMunF"
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "## 3.1 Modell erstellen und Angaben zur Modellgröße verstehen\n",
        "Definieren Sie ein PyTorch-Multilayer-Perzeptron mit der Größe des IMDB-Dictionaries für one-hot-encodierte Wörter als Eingabe (Sigmoid-Aktivierung), 50 Neuronen im Hidden Layer und 2 Ausgabeneuronen. Layer 1 und 2 Ihres Netzes verwenden die Sigmoid-Aktivierungsfunktion, Layer 3 die Softmax-Aktivierungsfunktion.\n",
        "\n",
        "Generieren Sie eine Modell-Summary mit torchinfo und erklären Sie, was die ausgegebenen Werte bedeuten und wie diese zustande kommen."
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "execution_count": 24,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "data": {
          "text/plain": "=================================================================\nLayer (type:depth-idx)                   Param #\n=================================================================\nModel                                    --\n├─Linear: 1-1                            500,050\n├─Linear: 1-2                            102\n├─Sigmoid: 1-3                           --\n├─Softmax: 1-4                           --\n=================================================================\nTotal params: 500,152\nTrainable params: 500,152\nNon-trainable params: 0\n================================================================="
         },
         "execution_count": 24,
         "metadata": {},
         "output_type": "execute_result"
        }
       ],
       "source": [
        "from torchinfo import summary\n",
        "\n",
        "\n",
        "class Model(nn.Module):\n",
        "    \"\"\"Two-layer MLP: Linear(n_input, n_hidden) -> Sigmoid -> Linear(n_hidden, n_out) -> Softmax.\n",
        "\n",
        "    n_input defaults to 10000, presumably the one-hot IMDB vocabulary size (TODO confirm against x_train.shape[1]).\n",
        "    forward() returns class probabilities that sum to 1 along the last dimension.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 n_input: int = 10000,\n",
        "                 n_hidden: int = 50,\n",
        "                 n_out: int = 2):\n",
        "        super().__init__()\n",
        "        self.hidden = nn.Linear(n_input, n_hidden)\n",
        "        self.out = nn.Linear(n_hidden, n_out)\n",
        "        self.sigmoid = nn.Sigmoid()\n",
        "        # Explicit dim: the implicit-dim Softmax is deprecated and warns; -1 matches the old\n",
        "        # implicit choice for both 1-D (single sample) and 2-D (batch) inputs.\n",
        "        self.softmax = nn.Softmax(dim=-1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.sigmoid(self.hidden(x))\n",
        "        return self.softmax(self.out(x))\n",
        "\n",
        "\n",
        "model = Model()\n",
        "# Summarize the model created in this cell, not a leftover variable from an earlier run\n",
        "summary(model)"
       ],
       "metadata": {
        "collapsed": false,
        "ExecuteTime": {
         "end_time": "2023-11-20T23:09:52.169718Z",
         "start_time": "2023-11-20T23:09:52.161860Z"
        }
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "## 3.2 Modell trainieren und Performancekurven interpretieren\n",
        "Nutzen Sie den untenstehenden Code, um Ihr Modell zu trainieren. Interpretieren und diskutieren Sie die entstehenden Performancekurven. Falls Sie einen unerwarteten Anstieg Ihres Losses beobachten, recherchieren Sie, wie Sie diesen mit dem Einbau einer einzelnen Verbesserung innerhalb des gegebenen SGD-Lernverfahrens beheben können. ACHTUNG: Wenn Sie Ihr Modell nicht oben neu initialisieren, optimieren Sie weiter auf den schon veränderten Parametern."
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "execution_count": 25,
       "outputs": [
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
         "ename": "TypeError",
         "evalue": "cannot inherit non-frozen dataclass from a frozen one",
         "output_type": "error",
         "traceback": [
          "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
          "\u001B[0;31mTypeError\u001B[0m                                 Traceback (most recent call last)",
          "Cell \u001B[0;32mIn[25], line 3\u001B[0m\n\u001B[1;32m      1\u001B[0m EPOCHS  \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m291\u001B[39m \u001B[38;5;66;03m#@param {type:\"slider\", min:2, max:1000, step:1}\u001B[39;00m\n\u001B[1;32m      2\u001B[0m RATE \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0.902\u001B[39m \u001B[38;5;66;03m#@param {type:\"slider\", min:0.001, max:2, step:0.001}\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m optimizer \u001B[38;5;241m=\u001B[39m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43moptim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mSGD\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mparameters\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlr\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mRATE\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mweight_decay\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m      4\u001B[0m loss_fn   \u001B[38;5;241m=\u001B[39m nn\u001B[38;5;241m.\u001B[39mCrossEntropyLoss()\n\u001B[1;32m      5\u001B[0m loss_list     \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((EPOCHS,))\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/sgd.py:26\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, lr, momentum, dampening, weight_decay, nesterov, maximize, foreach, differentiable)\u001B[0m\n\u001B[1;32m     21\u001B[0m defaults \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m(lr\u001B[38;5;241m=\u001B[39mlr, momentum\u001B[38;5;241m=\u001B[39mmomentum, dampening\u001B[38;5;241m=\u001B[39mdampening,\n\u001B[1;32m     22\u001B[0m                 weight_decay\u001B[38;5;241m=\u001B[39mweight_decay, nesterov\u001B[38;5;241m=\u001B[39mnesterov,\n\u001B[1;32m     23\u001B[0m                 maximize\u001B[38;5;241m=\u001B[39mmaximize, foreach\u001B[38;5;241m=\u001B[39mforeach,\n\u001B[1;32m     24\u001B[0m                 differentiable\u001B[38;5;241m=\u001B[39mdifferentiable)\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m nesterov \u001B[38;5;129;01mand\u001B[39;00m (momentum \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;129;01mor\u001B[39;00m dampening \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m):\n\u001B[0;32m---> 26\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mNesterov momentum requires a momentum and zero dampening\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28msuper\u001B[39m()\u001B[38;5;241m.\u001B[39m\u001B[38;5;21m__init__\u001B[39m(params, defaults)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/optim/optimizer.py:266\u001B[0m, in \u001B[0;36m__init__\u001B[0;34m(self, params, defaults)\u001B[0m\n\u001B[1;32m    262\u001B[0m \u001B[38;5;129m@staticmethod\u001B[39m\n\u001B[1;32m    263\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mprofile_hook_step\u001B[39m(func):\n\u001B[1;32m    265\u001B[0m     \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m--> 266\u001B[0m     \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrapper\u001B[39m(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m    267\u001B[0m         \u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39m_ \u001B[38;5;241m=\u001B[39m args\n\u001B[1;32m    268\u001B[0m         profile_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mOptimizer.step#\u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m.step\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__class__\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m)\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_compile.py:22\u001B[0m, in \u001B[0;36minner\u001B[0;34m(*args, **kwargs)\u001B[0m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m allowed_functions, convert_frame, eval_frame, resume_execution\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m list_backends, register_backend\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconvert_frame\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m replay\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:29\u001B[0m\n\u001B[1;32m     27\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mguards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CheckFunctionManager, GuardedCode\n\u001B[1;32m     28\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mhooks\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Hooks\n\u001B[0;32m---> 29\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01moutput_graph\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m OutputGraph\n\u001B[1;32m     30\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mreplay_record\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ExecutionRecord\n\u001B[1;32m     31\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_convert\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m InstructionTranslator\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:23\u001B[0m\n\u001B[1;32m     14\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m_guards\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m (\n\u001B[1;32m     15\u001B[0m     Checkpointable,\n\u001B[1;32m     16\u001B[0m     Guard,\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m     19\u001B[0m     TracingContext,\n\u001B[1;32m     20\u001B[0m )\n\u001B[1;32m     21\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mfx\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexperimental\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msymbolic_shapes\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ShapeEnv\n\u001B[0;32m---> 23\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m config, logging \u001B[38;5;28;01mas\u001B[39;00m torchdynamo_logging, variables\n\u001B[1;32m     24\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbackends\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mregistry\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m CompiledFn, CompilerFn\n\u001B[1;32m     25\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbytecode_transformation\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m create_instruction, Instruction, unique_id\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/__init__.py:1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbase\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m VariableTracker\n\u001B[1;32m      2\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mbuiltin\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m BuiltinVariable\n\u001B[1;32m      3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mconstant\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ConstantVariable, EnumVariable\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/variables/base.py:6\u001B[0m\n\u001B[1;32m      4\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m variables\n\u001B[1;32m      5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mexc\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m unimplemented\n\u001B[0;32m----> 6\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01msource\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m AttrSource, Source\n\u001B[1;32m      7\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m dict_values, identity, istype, odict_values\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mMutableLocal\u001B[39;00m:\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/site-packages/torch/_dynamo/source.py:49\u001B[0m\n\u001B[1;32m     39\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mis_input_source\u001B[39m(source):\n\u001B[1;32m     40\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m source\u001B[38;5;241m.\u001B[39mguard_source() \u001B[38;5;129;01min\u001B[39;00m [\n\u001B[1;32m     41\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL,\n\u001B[1;32m     42\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL,\n\u001B[1;32m     43\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mLOCAL_NN_MODULE,\n\u001B[1;32m     44\u001B[0m         GuardSource\u001B[38;5;241m.\u001B[39mGLOBAL_NN_MODULE,\n\u001B[1;32m     45\u001B[0m     ]\n\u001B[1;32m     48\u001B[0m \u001B[38;5;129;43m@dataclasses\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdataclass\u001B[49m\n\u001B[0;32m---> 49\u001B[0m \u001B[38;5;28;43;01mclass\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43;01mLocalSource\u001B[39;49;00m\u001B[43m(\u001B[49m\u001B[43mSource\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n\u001B[1;32m     50\u001B[0m \u001B[43m    \u001B[49m\u001B[43mlocal_name\u001B[49m\u001B[43m:\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mstr\u001B[39;49m\n\u001B[1;32m     52\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43;01mdef\u001B[39;49;00m\u001B[43m \u001B[49m\u001B[38;5;21;43mreconstruct\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcodegen\u001B[49m\u001B[43m)\u001B[49m\u001B[43m:\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1184\u001B[0m, in \u001B[0;36mdataclass\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m   1181\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m wrap\n\u001B[1;32m   1183\u001B[0m \u001B[38;5;66;03m# We're called as @dataclass without parens.\u001B[39;00m\n\u001B[0;32m-> 1184\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mwrap\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:1175\u001B[0m, in \u001B[0;36mdataclass.<locals>.wrap\u001B[0;34m(cls)\u001B[0m\n\u001B[1;32m   1174\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mwrap\u001B[39m(\u001B[38;5;28mcls\u001B[39m):\n\u001B[0;32m-> 1175\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43m_process_class\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minit\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mrepr\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43meq\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43morder\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43munsafe_hash\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m   1176\u001B[0m \u001B[43m                          \u001B[49m\u001B[43mfrozen\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmatch_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkw_only\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mslots\u001B[49m\u001B[43m)\u001B[49m\n",
          "File \u001B[0;32m~/miniconda3/envs/ginn/lib/python3.10/dataclasses.py:985\u001B[0m, in \u001B[0;36m_process_class\u001B[0;34m(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)\u001B[0m\n\u001B[1;32m    982\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_dataclass_bases:\n\u001B[1;32m    983\u001B[0m     \u001B[38;5;66;03m# Raise an exception if any of our bases are frozen, but we're not.\u001B[39;00m\n\u001B[1;32m    984\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m frozen:\n\u001B[0;32m--> 985\u001B[0m         \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mcannot inherit non-frozen dataclass from a \u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m    986\u001B[0m                         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfrozen one\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m    988\u001B[0m     \u001B[38;5;66;03m# Raise an exception if we're frozen, but none of our bases are.\u001B[39;00m\n\u001B[1;32m    989\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m any_frozen_base \u001B[38;5;129;01mand\u001B[39;00m frozen:\n",
          "\u001B[0;31mTypeError\u001B[0m: cannot inherit non-frozen dataclass from a frozen one"
         ]
    
    Friedrich Heitzer's avatar
    Friedrich Heitzer committed
       ],
       "source": [
        "EPOCHS  = 291 #@param {type:\"slider\", min:2, max:1000, step:1}\n",
        "RATE = 0.902 #@param {type:\"slider\", min:0.001, max:2, step:0.001}\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=RATE, weight_decay=0)\n",
        "loss_fn   = nn.CrossEntropyLoss()\n",
        "loss_list     = np.zeros((EPOCHS,))\n",
        "accuracy_list = np.zeros((EPOCHS,))\n",
        "accuracy_list_test = np.zeros((EPOCHS,))\n",
        "\n",
        "\n",
        "for epoch in tqdm.trange(EPOCHS):\n",
        "    y_pred = model(x_train)\n",
        "    #loss = loss_fn(y_pred, y_train)\n",
        "    loss_list[epoch] = loss.item()\n",
        "    loss = loss_fn(y_pred, y_train)# + 0.01 *l2_reg(model)\n",
        "\n",
        "    # Zero gradients\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    loss.backward()\n",
        "    #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)#, args.clip)\n",
        "    optimizer.step()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        y_pred = model(x_train)\n",
        "        correct = (torch.argmax(y_pred, dim=1) == y_train).type(torch.FloatTensor)\n",
        "        accuracy_list[epoch] = correct.mean()\n",
        "        y_pred = model(x_test)\n",
        "        correct = (torch.argmax(y_pred, dim=1) == y_test).type(torch.FloatTensor)\n",
        "        accuracy_list_test[epoch] = correct.mean()\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(12, 6), sharex=True)\n",
        "\n",
        "ax1.plot(accuracy_list)\n",
        "ax1.set_ylabel(\"train accuracy\")\n",
        "ax2.plot(loss_list)\n",
        "ax2.set_ylabel(\"train loss\")\n",
        "ax3.plot(accuracy_list_test)\n",
        "ax3.set_ylabel(\"test acc\")\n",
        "ax3.set_xlabel(\"epochs\");"
       ],
       "metadata": {
        "collapsed": false,
        "ExecuteTime": {
         "end_time": "2023-11-20T23:10:04.077035Z",
         "start_time": "2023-11-20T23:10:03.912932Z"
        }
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "## 3.3.&nbsp;Momentum Implementieren\n",
        "Vervollständigen Sie Methode momentum_update. Überlegen Sie sich, wie Sie die Korrektheit mit einem Durchlauf inkl. Momentum Update auf Ihrem oben definierten Modell prüfen können"
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "outputs": [],
       "source": [
        "def momentum_update(loss, params, grad_vel, lr=1e-3, beta=0.8):\n",
        "  # Clear up gradients as Pytorch automatically accumulates gradients from\n",
        "  # successive backward calls\n",
        "  zero_grad(params)\n",
        "  # Compute gradients on given objective\n",
        "  loss.backward()\n",
        "\n",
        "  with torch.no_grad():\n",
        "    for (par, vel) in zip(params, grad_vel):\n",
        "      w_u = lr * par.grad + beta * grad_vel\n",
        "      grad_vel = w_u\n",
        "      par -= w_u"
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "## 3.4 Experimente zum Lernverhalten mit Momentum und Batch Size\n",
        "Im folgenden können Sie für ein festgelegtes Zeitbudget schauen, wie sich der Loss Ihres neuronalen Netzes innerhalb dieser Zeit entwickelt.\n",
        "Experimentieren Sie zunächst mit den Voreinstellungen mit und ohne Momentum, probieren Sie dann eigene Einstellungen aus. Diskutieren Sie das visualisierte Lernverhalten insbesondere bzgl. unterschiedlicher Batch Sizes."
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "code",
       "execution_count": null,
       "outputs": [],
       "source": [
        "@widgets.interact_manual\n",
        "def minibatch_experiment(batch_sizes='1, 20, 500, 17000',\n",
        "                         lrs='0.9, 0.9, 0.9, 0.9',\n",
        "                         time_budget=widgets.Dropdown(options=[\"0.05\", \"0.5\",  \"2.0\", \"5.0\", \"7.5\"],\n",
        "                                                      value=\"5.0\"),\n",
        "                         use_momentum = widgets.ToggleButton(value=True)):\n",
        "\n",
        "  \"\"\"\n",
        "  Demonstration of minibatch experiment\n",
        "\n",
        "  Args:\n",
        "    batch_sizes: String\n",
        "      Size of minibatches\n",
        "    lrs: String\n",
        "      Different learning rates\n",
        "    time_budget: widget dropdown instance\n",
        "      Different time budgets with default=2.5s\n",
        "\n",
        "  Returns:\n",
        "    Nothing\n",
        "  \"\"\"\n",
        "  batch_sizes = [int(s) for s in batch_sizes.split(',')]\n",
        "  lrs = [float(s) for s in lrs.split(',')]\n",
        "\n",
        "  LOSS_HIST = {_:[] for _ in batch_sizes}\n",
        "\n",
        "  #X, y = train_set.data, train_set.targets\n",
        "  base_model = Model(x_train.shape[1]).to(DEVICE)\n",
        "  #base_model = MLP(in_dim=784, out_dim=10, hidden_dims=[100, 100])\n",
        "\n",
        "  for id, batch_size in enumerate(tqdm.auto.tqdm(batch_sizes)):\n",
        "    start_time = time.time()\n",
        "    # Create a new copy of the model for each batch size\n",
        "    model = copy.deepcopy(base_model)\n",
        "    params = list(model.parameters())\n",
        "    lr = lrs[id]\n",
        "    # Fixed budget per choice of batch size\n",
        "    #initial_vel = [torch.randn_like(p) for p in model.parameters()]\n",
        "    aux_tensors = [torch.zeros_like(_) for _ in params]\n",
        "    while (time.time() - start_time) < float(time_budget):\n",
        "      data, labels = sample_minibatch(x_train, y_train, batch_size)\n",
        "      loss = loss_fn(model(data), labels)\n",
        "      if use_momentum:\n",
        "        momentum_update(loss, params, grad_vel=aux_tensors, lr=lr, beta=0.5)\n",
        "      else:\n",
        "        gradient_update(loss, params, lr=lr)\n",
        "      LOSS_HIST[batch_size].append([time.time() - start_time,\n",
        "                                    loss.item()])\n",
        "\n",
        "  fig, axs = plt.subplots(1, len(batch_sizes), figsize=(10, 3))\n",
        "  for ax, batch_size in zip(axs, batch_sizes):\n",
        "    plot_data = np.array(LOSS_HIST[batch_size])\n",
        "    ax.plot(plot_data[:, 0], plot_data[:, 1], label=batch_size,\n",
        "            alpha=0.8)\n",
        "    #ax.set_title('Batch size: ' + str(batch_size) + ' #: ' + str(batch_size*len(LOSS_HIST[batch_size])))\n",
        "    ax.set_title(' #: ' + str(batch_size*len(LOSS_HIST[batch_size])))\n",
        "    ax.set_xlabel('Seconds')\n",
        "    ax.set_ylabel('Loss')\n",
        "  plt.show()\n",
        "  #return(LOSS_HIST)\n"
       ],
       "metadata": {
        "collapsed": false
       }
      },
      {
       "cell_type": "markdown",
       "source": [
        "![image.png]()"
       ],
       "metadata": {
        "collapsed": false
       }
      }
     ]
    }