{ "cells": [ { "cell_type": "markdown", "id": "d58644a3", "metadata": { "papermill": { "duration": 0.007135, "end_time": "2024-09-05T18:30:30.865440", "exception": false, "start_time": "2024-09-05T18:30:30.858305", "status": "completed" }, "tags": [] }, "source": [ "# Defined operations" ] }, { "cell_type": "markdown", "id": "e54b0381", "metadata": { "papermill": { "duration": 0.004333, "end_time": "2024-09-05T18:30:30.875104", "exception": false, "start_time": "2024-09-05T18:30:30.870771", "status": "completed" }, "tags": [] }, "source": [ "Recall that backpropagation (BP) only works if every operation defines how to compute its local gradient. In this section, we will implement a minimal **autograd engine** for building computational graphs. We start with the base `Node` class, which has a `data` attribute for storing the node's output value and a `grad` attribute for storing its global gradient (the derivative of the loss with respect to the node). The base class also defines a `backward` method that computes `grad` for every node the root depends on, as described above. Note that `backward` accumulates into each parent's `grad` with `+=` and does not reset it first, so gradients should be zeroed before calling `backward` again on the same graph." ] }, { "cell_type": "code", "execution_count": 1, "id": "967a8e6d", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:30.884439Z", "iopub.status.busy": "2024-09-05T18:30:30.884096Z", "iopub.status.idle": "2024-09-05T18:30:30.938215Z", "shell.execute_reply": "2024-09-05T18:30:30.937799Z" }, "papermill": { "duration": 0.060396, "end_time": "2024-09-05T18:30:30.939730", "exception": false, "start_time": "2024-09-05T18:30:30.879334", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import math\n",
       "import random\n",
       "random.seed(42)\n",
       "\n",
       "from typing import final\n",
       "from collections import OrderedDict\n",
       "\n",
       "\n",
       "class Node:\n",
       "    def __init__(self, data, parents=()):\n",
       "        self.data = data\n",
       "        self.grad = 0               # ∂loss / ∂self\n",
       "        self._parents = parents     # parent -> self\n",
       "\n",
       "    @final\n",
       "    def sorted_nodes(self):\n",
       "        """Return topologically sorted nodes with self as root."""\n",
       "        topo = OrderedDict()\n",
       "\n",
       "        def dfs(node):\n",
       "            if node not in topo:\n",
       "                for parent in node._parents:\n",
       "                    dfs(parent)\n",
       "\n",
       "                topo[node] = None\n",
       "\n",
       "        dfs(self)\n",
       "        return reversed(topo)\n",
       "\n",
       "\n",
       "    @final\n",
       "    def backward(self):\n",
       "        """Send global grads backward to parent nodes."""\n",
       "        self.grad = 1.0\n",
       "        for node in self.sorted_nodes():\n",
       "            for parent in node._parents:\n",
       "                parent.grad += node.grad * node._local_grad(parent)\n",
       "\n",
       "\n",
       "    def _local_grad(self, parent) -> float:\n",
       "        """Calculate local grads ∂self / ∂parent."""\n",
       "        raise NotImplementedError("Base node has no parents.")\n",
       "\n",
       "\n",
       "    def __add__(self, node):\n",
       "        return BinaryOpNode(self, node, op="+")\n",
       "\n",
       "    def __mul__(self, node):\n",
       "        return BinaryOpNode(self, node, op="*")\n",
       "\n",
       "    def __pow__(self, n):\n",
       "        assert isinstance(n, (int, float)) and n != 1\n",
       "        return PowOp(self, n)\n",
       "\n",
       "    def relu(self):\n",
       "        return ReLUNode(self)\n",
       "\n",
       "    def tanh(self):\n",
       "        return TanhNode(self)\n",
       "\n",
       "    def __neg__(self):\n",
       "        return self * Node(-1)\n",
       "\n",
       "    def __sub__(self, node):\n",
       "        return self + (-node)\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{math}\n", "\\PY{k+kn}{import} \\PY{n+nn}{random}\n", "\\PY{n}{random}\\PY{o}{.}\\PY{n}{seed}\\PY{p}{(}\\PY{l+m+mi}{42}\\PY{p}{)}\n", "\n", "\\PY{k+kn}{from} \\PY{n+nn}{typing} \\PY{k+kn}{import} \\PY{n}{final}\n", "\\PY{k+kn}{from} \\PY{n+nn}{collections} \\PY{k+kn}{import} \\PY{n}{OrderedDict}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Node}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{data}\\PY{p}{,} \\PY{n}{parents}\\PY{o}{=}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{data} \\PY{o}{=} \\PY{n}{data}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{=} \\PY{l+m+mi}{0} \\PY{c+c1}{\\PYZsh{} ∂loss / ∂self}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parents} \\PY{o}{=} \\PY{n}{parents} \\PY{c+c1}{\\PYZsh{} parent \\PYZhy{}\\PYZgt{} self}\n", "\n", " \\PY{n+nd}{@final}\n", " \\PY{k}{def} \\PY{n+nf}{sorted\\PYZus{}nodes}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Return topologically sorted nodes with self as root.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n}{topo} \\PY{o}{=} \\PY{n}{OrderedDict}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{dfs}\\PY{p}{(}\\PY{n}{node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n}{node} \\PY{o+ow}{not} \\PY{o+ow}{in} \\PY{n}{topo}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{parent} \\PY{o+ow}{in} \\PY{n}{node}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{p}{:}\n", " \\PY{n}{dfs}\\PY{p}{(}\\PY{n}{parent}\\PY{p}{)}\n", "\n", " \\PY{n}{topo}\\PY{p}{[}\\PY{n}{node}\\PY{p}{]} \\PY{o}{=} \\PY{k+kc}{None}\n", "\n", " \\PY{n}{dfs}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n+nb}{reversed}\\PY{p}{(}\\PY{n}{topo}\\PY{p}{)}\n", "\n", "\n", " \\PY{n+nd}{@final}\n", " \\PY{k}{def} 
\\PY{n+nf}{backward}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Send global grads backward to parent nodes.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{=} \\PY{l+m+mf}{1.0}\n", " \\PY{k}{for} \\PY{n}{node} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{sorted\\PYZus{}nodes}\\PY{p}{(}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{parent} \\PY{o+ow}{in} \\PY{n}{node}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{p}{:}\n", " \\PY{n}{parent}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{+}\\PY{o}{=} \\PY{n}{node}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{*} \\PY{n}{node}\\PY{o}{.}\\PY{n}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n}{parent}\\PY{p}{)}\n", "\n", "\n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{parent}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{float}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Calculate local grads ∂self / ∂parent.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{k}{raise} \\PY{n+ne}{NotImplementedError}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{Base node has no parents.}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}add\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n}{BinaryOpNode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{node}\\PY{p}{,} \\PY{n}{op}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{+}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}mul\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n}{BinaryOpNode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{node}\\PY{p}{,} \\PY{n}{op}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{*}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\n", " \\PY{k}{def} 
\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}pow\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{n}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{assert} \\PY{n+nb}{isinstance}\\PY{p}{(}\\PY{n}{n}\\PY{p}{,} \\PY{p}{(}\\PY{n+nb}{int}\\PY{p}{,} \\PY{n+nb}{float}\\PY{p}{)}\\PY{p}{)} \\PY{o+ow}{and} \\PY{n}{n} \\PY{o}{!=} \\PY{l+m+mi}{1}\n", " \\PY{k}{return} \\PY{n}{PowOp}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{n}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{relu}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n}{ReLUNode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{tanh}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n}{TanhNode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}neg\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self} \\PY{o}{*} \\PY{n}{Node}\\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}sub\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self} \\PY{o}{+} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{n}{node}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import math\n", "import random\n", "random.seed(42)\n", "\n", "from typing import final\n", "from collections import OrderedDict\n", "\n", "\n", "class Node:\n", " def __init__(self, data, parents=()):\n", " self.data = data\n", " self.grad = 0 # ∂loss / ∂self\n", " self._parents = parents # parent -> self\n", "\n", " @final\n", " def sorted_nodes(self):\n", " \"\"\"Return topologically sorted nodes with self as root.\"\"\"\n", " topo = OrderedDict()\n", "\n", " def dfs(node):\n", " if node not in topo:\n", " for parent in node._parents:\n", " dfs(parent)\n", "\n", " topo[node] = None\n", "\n", " dfs(self)\n", " return reversed(topo)\n", "\n", "\n", " @final\n", " def 
backward(self):\n", " \"\"\"Send global grads backward to parent nodes.\"\"\"\n", " self.grad = 1.0\n", " for node in self.sorted_nodes():\n", " for parent in node._parents:\n", " parent.grad += node.grad * node._local_grad(parent)\n", "\n", "\n", " def _local_grad(self, parent) -> float:\n", " \"\"\"Calculate local grads ∂self / ∂parent.\"\"\"\n", " raise NotImplementedError(\"Base node has no parents.\")\n", "\n", "\n", " def __add__(self, node):\n", " return BinaryOpNode(self, node, op=\"+\")\n", "\n", " def __mul__(self, node):\n", " return BinaryOpNode(self, node, op=\"*\")\n", "\n", " def __pow__(self, n):\n", " assert isinstance(n, (int, float)) and n != 1\n", " return PowOp(self, n)\n", "\n", " def relu(self):\n", " return ReLUNode(self)\n", "\n", " def tanh(self):\n", " return TanhNode(self)\n", "\n", " def __neg__(self):\n", " return self * Node(-1)\n", "\n", " def __sub__(self, node):\n", " return self + (-node)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import math\n", "import random\n", "random.seed(42)\n", "\n", "from typing import final\n", "from collections import OrderedDict\n", "\n", "\n", "class Node:\n", " def __init__(self, data, parents=()):\n", " self.data = data\n", " self.grad = 0 # ∂loss / ∂self\n", " self._parents = parents # parent -> self\n", "\n", " @final\n", " def sorted_nodes(self):\n", " \"\"\"Return topologically sorted nodes with self as root.\"\"\"\n", " topo = OrderedDict()\n", "\n", " def dfs(node):\n", " if node not in topo:\n", " for parent in node._parents:\n", " dfs(parent)\n", "\n", " topo[node] = None\n", "\n", " dfs(self)\n", " return reversed(topo)\n", "\n", "\n", " @final\n", " def backward(self):\n", " \"\"\"Send global grads backward to parent nodes.\"\"\"\n", " self.grad = 1.0\n", " for node in self.sorted_nodes():\n", " for parent in node._parents:\n", " parent.grad += node.grad * node._local_grad(parent)\n", "\n", "\n", " def _local_grad(self, parent) 
-> float:\n", " \"\"\"Calculate local grads ∂self / ∂parent.\"\"\"\n", " raise NotImplementedError(\"Base node has no parents.\")\n", "\n", "\n", " def __add__(self, node):\n", " return BinaryOpNode(self, node, op=\"+\")\n", "\n", " def __mul__(self, node):\n", " return BinaryOpNode(self, node, op=\"*\")\n", "\n", " def __pow__(self, n):\n", " assert isinstance(n, (int, float)) and n != 1\n", " return PowOp(self, n)\n", "\n", " def relu(self):\n", " return ReLUNode(self)\n", "\n", " def tanh(self):\n", " return TanhNode(self)\n", "\n", " def __neg__(self):\n", " return self * Node(-1)\n", "\n", " def __sub__(self, node):\n", " return self + (-node)" ] }, { "cell_type": "markdown", "id": "3eb3b46f", "metadata": { "papermill": { "duration": 0.003248, "end_time": "2024-09-05T18:30:30.946263", "exception": false, "start_time": "2024-09-05T18:30:30.943015", "status": "completed" }, "tags": [] }, "source": [ "Next, we define the **supported operations**. Observe that only a handful are needed to implement a fully-connected neural net:" ] }, { "cell_type": "code", "execution_count": 2, "id": "a9ee7e91", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:30.952894Z", "iopub.status.busy": "2024-09-05T18:30:30.952688Z", "iopub.status.idle": "2024-09-05T18:30:31.067666Z", "shell.execute_reply": "2024-09-05T18:30:31.058235Z" }, "papermill": { "duration": 0.193747, "end_time": "2024-09-05T18:30:31.142917", "exception": false, "start_time": "2024-09-05T18:30:30.949170", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
class BinaryOpNode(Node):\n",
       "    def __init__(self, x, y, op: str):\n",
       "        """Binary operation between two nodes."""\n",
       "        ops = {"+": lambda x, y: x + y, "*": lambda x, y: x * y}\n",
       "        self._op = op\n",
       "        super().__init__(ops[op](x.data, y.data), (x, y))\n",
       "\n",
       "    def _local_grad(self, parent):\n",
       "        if self._op == "+":\n",
       "            return 1.0\n",
       "\n",
       "        elif self._op == "*":\n",
       "            i = self._parents.index(parent)\n",
       "            coparent = self._parents[1 - i]\n",
       "            return coparent.data\n",
       "\n",
       "    def __repr__(self):\n",
       "        return self._op\n",
       "\n",
       "\n",
       "class ReLUNode(Node):\n",
       "    def __init__(self, x):\n",
       "        data = x.data * int(x.data > 0.0)\n",
       "        super().__init__(data, (x,))\n",
       "\n",
       "    def _local_grad(self, parent):\n",
       "        return float(parent.data > 0)\n",
       "\n",
       "    def __repr__(self):\n",
       "        return "relu"\n",
       "\n",
       "\n",
       "class TanhNode(Node):\n",
       "    def __init__(self, x):\n",
       "        data = math.tanh(x.data)\n",
       "        super().__init__(data, (x,))\n",
       "\n",
       "    def _local_grad(self, parent):\n",
       "        return 1 - self.data**2\n",
       "\n",
       "    def __repr__(self):\n",
       "        return "tanh"\n",
       "\n",
       "\n",
       "class PowOp(Node):\n",
       "    def __init__(self, x, n):\n",
       "        self.n = n\n",
       "        data = x.data**self.n\n",
       "        super().__init__(data, (x,))\n",
       "\n",
       "    def _local_grad(self, parent):\n",
       "        return self.n * parent.data ** (self.n - 1)\n",
       "\n",
       "    def __repr__(self):\n",
       "        return f"** {self.n}"\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k}{class} \\PY{n+nc}{BinaryOpNode}\\PY{p}{(}\\PY{n}{Node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{,} \\PY{n}{op}\\PY{p}{:} \\PY{n+nb}{str}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Binary operation between two nodes.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n}{ops} \\PY{o}{=} \\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{+}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k}{lambda} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{:} \\PY{n}{x} \\PY{o}{+} \\PY{n}{y}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{*}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k}{lambda} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{:} \\PY{n}{x} \\PY{o}{*} \\PY{n}{y}\\PY{p}{\\PYZcb{}}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}op} \\PY{o}{=} \\PY{n}{op}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n}{ops}\\PY{p}{[}\\PY{n}{op}\\PY{p}{]}\\PY{p}{(}\\PY{n}{x}\\PY{o}{.}\\PY{n}{data}\\PY{p}{,} \\PY{n}{y}\\PY{o}{.}\\PY{n}{data}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{parent}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}op} \\PY{o}{==} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{+}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+m+mf}{1.0}\n", "\n", " \\PY{k}{elif} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}op} \\PY{o}{==} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{*}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:}\n", " \\PY{n}{i} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{o}{.}\\PY{n}{index}\\PY{p}{(}\\PY{n}{parent}\\PY{p}{)}\n", " \\PY{n}{coparent} \\PY{o}{=} 
\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{p}{[}\\PY{l+m+mi}{1} \\PY{o}{\\PYZhy{}} \\PY{n}{i}\\PY{p}{]}\n", " \\PY{k}{return} \\PY{n}{coparent}\\PY{o}{.}\\PY{n}{data}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}op}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{ReLUNode}\\PY{p}{(}\\PY{n}{Node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{data} \\PY{o}{=} \\PY{n}{x}\\PY{o}{.}\\PY{n}{data} \\PY{o}{*} \\PY{n+nb}{int}\\PY{p}{(}\\PY{n}{x}\\PY{o}{.}\\PY{n}{data} \\PY{o}{\\PYZgt{}} \\PY{l+m+mf}{0.0}\\PY{p}{)}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n}{data}\\PY{p}{,} \\PY{p}{(}\\PY{n}{x}\\PY{p}{,}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{parent}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb}{float}\\PY{p}{(}\\PY{n}{parent}\\PY{o}{.}\\PY{n}{data} \\PY{o}{\\PYZgt{}} \\PY{l+m+mi}{0}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{relu}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{TanhNode}\\PY{p}{(}\\PY{n}{Node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{data} \\PY{o}{=} \\PY{n}{math}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{x}\\PY{o}{.}\\PY{n}{data}\\PY{p}{)}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n}{data}\\PY{p}{,} 
\\PY{p}{(}\\PY{n}{x}\\PY{p}{,}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{parent}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+m+mi}{1} \\PY{o}{\\PYZhy{}} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{data}\\PY{o}{*}\\PY{o}{*}\\PY{l+m+mi}{2}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tanh}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{PowOp}\\PY{p}{(}\\PY{n}{Node}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{,} \\PY{n}{n}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n} \\PY{o}{=} \\PY{n}{n}\n", " \\PY{n}{data} \\PY{o}{=} \\PY{n}{x}\\PY{o}{.}\\PY{n}{data}\\PY{o}{*}\\PY{o}{*}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n}{data}\\PY{p}{,} \\PY{p}{(}\\PY{n}{x}\\PY{p}{,}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}local\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{parent}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n} \\PY{o}{*} \\PY{n}{parent}\\PY{o}{.}\\PY{n}{data} \\PY{o}{*}\\PY{o}{*} \\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n} \\PY{o}{\\PYZhy{}} \\PY{l+m+mi}{1}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{** }\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{\\PYZdq{}}\n", "\\end{Verbatim}\n" ], "text/plain": [ "\n", "class BinaryOpNode(Node):\n", " def __init__(self, x, y, op: str):\n", " 
\"\"\"Binary operation between two nodes.\"\"\"\n", " ops = {\"+\": lambda x, y: x + y, \"*\": lambda x, y: x * y}\n", " self._op = op\n", " super().__init__(ops[op](x.data, y.data), (x, y))\n", "\n", " def _local_grad(self, parent):\n", " if self._op == \"+\":\n", " return 1.0\n", "\n", " elif self._op == \"*\":\n", " i = self._parents.index(parent)\n", " coparent = self._parents[1 - i]\n", " return coparent.data\n", "\n", " def __repr__(self):\n", " return self._op\n", "\n", "\n", "class ReLUNode(Node):\n", " def __init__(self, x):\n", " data = x.data * int(x.data > 0.0)\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return float(parent.data > 0)\n", "\n", " def __repr__(self):\n", " return \"relu\"\n", "\n", "\n", "class TanhNode(Node):\n", " def __init__(self, x):\n", " data = math.tanh(x.data)\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return 1 - self.data**2\n", "\n", " def __repr__(self):\n", " return \"tanh\"\n", "\n", "\n", "class PowOp(Node):\n", " def __init__(self, x, n):\n", " self.n = n\n", " data = x.data**self.n\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return self.n * parent.data ** (self.n - 1)\n", "\n", " def __repr__(self):\n", " return f\"** {self.n}\"" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "\n", "class BinaryOpNode(Node):\n", " def __init__(self, x, y, op: str):\n", " \"\"\"Binary operation between two nodes.\"\"\"\n", " ops = {\"+\": lambda x, y: x + y, \"*\": lambda x, y: x * y}\n", " self._op = op\n", " super().__init__(ops[op](x.data, y.data), (x, y))\n", "\n", " def _local_grad(self, parent):\n", " if self._op == \"+\":\n", " return 1.0\n", "\n", " elif self._op == \"*\":\n", " i = self._parents.index(parent)\n", " coparent = self._parents[1 - i]\n", " return coparent.data\n", "\n", " def __repr__(self):\n", " return self._op\n", "\n", "\n", "class 
ReLUNode(Node):\n", " def __init__(self, x):\n", " data = x.data * int(x.data > 0.0)\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return float(parent.data > 0)\n", "\n", " def __repr__(self):\n", " return \"relu\"\n", "\n", "\n", "class TanhNode(Node):\n", " def __init__(self, x):\n", " data = math.tanh(x.data)\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return 1 - self.data**2\n", "\n", " def __repr__(self):\n", " return \"tanh\"\n", "\n", "\n", "class PowOp(Node):\n", " def __init__(self, x, n):\n", " self.n = n\n", " data = x.data**self.n\n", " super().__init__(data, (x,))\n", "\n", " def _local_grad(self, parent):\n", " return self.n * parent.data ** (self.n - 1)\n", "\n", " def __repr__(self):\n", " return f\"** {self.n}\"" ] }, { "cell_type": "markdown", "id": "8897ade3", "metadata": { "papermill": { "duration": 0.003549, "end_time": "2024-09-05T18:30:31.150529", "exception": false, "start_time": "2024-09-05T18:30:31.146980", "status": "completed" }, "tags": [] }, "source": [ "**Remark.** The circular references are fine: `Node` methods such as `__add__` and `relu` refer to subclasses like `BinaryOpNode` and `ReLUNode` that are only defined later, but these names are looked up when the methods are called, not when they are defined." ] }, { "cell_type": "markdown", "id": "971f3877", "metadata": { "papermill": { "duration": 0.004848, "end_time": "2024-09-05T18:30:31.159065", "exception": false, "start_time": "2024-09-05T18:30:31.154217", "status": "completed" }, "tags": [] }, "source": [ "
\n", "\n", "## Graph visualization" ] }, { "cell_type": "markdown", "id": "b38fce39", "metadata": { "papermill": { "duration": 0.028164, "end_time": "2024-09-05T18:30:31.210379", "exception": false, "start_time": "2024-09-05T18:30:31.182215", "status": "completed" }, "tags": [] }, "source": [ "The next two functions help visualize computational graphs. The `trace` function walks backward through the graph, starting from the root, to collect all nodes and edges. It is used by `draw_graph`, which first draws all nodes and then all edges. For each compute node, we also add a small juncture node labeled with the name of its operation." ] }, { "cell_type": "code", "execution_count": 3, "id": "dbbec280", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:31.223011Z", "iopub.status.busy": "2024-09-05T18:30:31.222512Z", "iopub.status.idle": "2024-09-05T18:30:31.263302Z", "shell.execute_reply": "2024-09-05T18:30:31.262616Z" }, "papermill": { "duration": 0.049435, "end_time": "2024-09-05T18:30:31.265381", "exception": false, "start_time": "2024-09-05T18:30:31.215946", "status": "completed" }, "tags": [ "remove-input", "hide-output" ] }, "outputs": [ { "data": { "text/html": [ "
from graphviz import Digraph\n",
       "\n",
       "\n",
       "def trace(root):\n",
       "    """Builds a set of all nodes and edges in a graph."""\n",
       "    # https://github.com/karpathy/micrograd/blob/master/trace_graph.ipynb\n",
       "\n",
       "    nodes = set()\n",
       "    edges = set()\n",
       "\n",
       "    def build(v):\n",
       "        if v not in nodes:\n",
       "            nodes.add(v)\n",
       "            for parent in v._parents:\n",
       "                edges.add((parent, v))\n",
       "                build(parent)\n",
       "\n",
       "    build(root)\n",
       "    return nodes, edges\n",
       "\n",
       "\n",
       "def draw_graph(root):\n",
       "    """Build diagram of computational graph."""\n",
       "\n",
       "    dot = Digraph(format="svg", graph_attr={"rankdir": "LR"})  # LR = left to right\n",
       "    nodes, edges = trace(root)\n",
       "    for n in nodes:\n",
       "        # Add node to graph\n",
       "        uid = str(id(n))\n",
       "        dot.node(name=uid, label=f"data={n.data:.3f} | grad={n.grad:.4f}", shape="record")\n",
       "\n",
       "        # Connect node to op node if operation\n",
       "        # e.g. if (5) = (2) + (3), then draw (5) as (+) -> (5).\n",
       "        if len(n._parents) > 0:\n",
       "            dot.node(name=uid + str(n), label=str(n))\n",
       "            dot.edge(uid + str(n), uid)\n",
       "\n",
       "    for child, v in edges:\n",
       "        # Connect child to the op node of v\n",
       "        dot.edge(str(id(child)), str(id(v)) + str(v))\n",
       "\n",
       "    return dot\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{from} \\PY{n+nn}{graphviz} \\PY{k+kn}{import} \\PY{n}{Digraph}\n", "\n", "\n", "\\PY{k}{def} \\PY{n+nf}{trace}\\PY{p}{(}\\PY{n}{root}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Builds a set of all nodes and edges in a graph.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{c+c1}{\\PYZsh{} https://github.com/karpathy/micrograd/blob/master/trace\\PYZus{}graph.ipynb}\n", "\n", " \\PY{n}{nodes} \\PY{o}{=} \\PY{n+nb}{set}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{edges} \\PY{o}{=} \\PY{n+nb}{set}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{build}\\PY{p}{(}\\PY{n}{v}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n}{v} \\PY{o+ow}{not} \\PY{o+ow}{in} \\PY{n}{nodes}\\PY{p}{:}\n", " \\PY{n}{nodes}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{v}\\PY{p}{)}\n", " \\PY{k}{for} \\PY{n}{parent} \\PY{o+ow}{in} \\PY{n}{v}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{p}{:}\n", " \\PY{n}{edges}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{p}{(}\\PY{n}{parent}\\PY{p}{,} \\PY{n}{v}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n}{build}\\PY{p}{(}\\PY{n}{parent}\\PY{p}{)}\n", "\n", " \\PY{n}{build}\\PY{p}{(}\\PY{n}{root}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n}{nodes}\\PY{p}{,} \\PY{n}{edges}\n", "\n", "\n", "\\PY{k}{def} \\PY{n+nf}{draw\\PYZus{}graph}\\PY{p}{(}\\PY{n}{root}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Build diagram of computational graph.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", "\n", " \\PY{n}{dot} \\PY{o}{=} \\PY{n}{Digraph}\\PY{p}{(}\\PY{n+nb}{format}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{svg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{graph\\PYZus{}attr}\\PY{o}{=}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{rankdir}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{LR}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{\\PYZcb{}}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} LR = left to right}\n", " \\PY{n}{nodes}\\PY{p}{,} \\PY{n}{edges} \\PY{o}{=} 
\\PY{n}{trace}\\PY{p}{(}\\PY{n}{root}\\PY{p}{)}\n", " \\PY{k}{for} \\PY{n}{n} \\PY{o+ow}{in} \\PY{n}{nodes}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} Add node to graph}\n", " \\PY{n}{uid} \\PY{o}{=} \\PY{n+nb}{str}\\PY{p}{(}\\PY{n+nb}{id}\\PY{p}{(}\\PY{n}{n}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n}{dot}\\PY{o}{.}\\PY{n}{node}\\PY{p}{(}\\PY{n}{name}\\PY{o}{=}\\PY{n}{uid}\\PY{p}{,} \\PY{n}{label}\\PY{o}{=}\\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{data=}\\PY{l+s+si}{\\PYZob{}}\\PY{n}{n}\\PY{o}{.}\\PY{n}{data}\\PY{l+s+si}{:}\\PY{l+s+s2}{.3f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{ | grad=}\\PY{l+s+si}{\\PYZob{}}\\PY{n}{n}\\PY{o}{.}\\PY{n}{grad}\\PY{l+s+si}{:}\\PY{l+s+s2}{.4f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{shape}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{record}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\n", " \\PY{c+c1}{\\PYZsh{} Connect node to op node if operation}\n", " \\PY{c+c1}{\\PYZsh{} e.g. if (5) = (2) + (3), then draw (5) as (+) \\PYZhy{}\\PYZgt{} (5).}\n", " \\PY{k}{if} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{n}\\PY{o}{.}\\PY{n}{\\PYZus{}parents}\\PY{p}{)} \\PY{o}{\\PYZgt{}} \\PY{l+m+mi}{0}\\PY{p}{:}\n", " \\PY{n}{dot}\\PY{o}{.}\\PY{n}{node}\\PY{p}{(}\\PY{n}{name}\\PY{o}{=}\\PY{n}{uid} \\PY{o}{+} \\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{n}\\PY{p}{)}\\PY{p}{,} \\PY{n}{label}\\PY{o}{=}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{n}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n}{dot}\\PY{o}{.}\\PY{n}{edge}\\PY{p}{(}\\PY{n}{uid} \\PY{o}{+} \\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{n}\\PY{p}{)}\\PY{p}{,} \\PY{n}{uid}\\PY{p}{)}\n", "\n", " \\PY{k}{for} \\PY{n}{child}\\PY{p}{,} \\PY{n}{v} \\PY{o+ow}{in} \\PY{n}{edges}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} Connect child to the op node of v}\n", " \\PY{n}{dot}\\PY{o}{.}\\PY{n}{edge}\\PY{p}{(}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n+nb}{id}\\PY{p}{(}\\PY{n}{child}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n+nb}{str}\\PY{p}{(}\\PY{n+nb}{id}\\PY{p}{(}\\PY{n}{v}\\PY{p}{)}\\PY{p}{)} \\PY{o}{+} \\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{v}\\PY{p}{)}\\PY{p}{)}\n", "\n", " 
\\PY{k}{return} \\PY{n}{dot}\n", "\\end{Verbatim}\n" ], "text/plain": [ "\n", "from graphviz import Digraph\n", "\n", "\n", "def trace(root):\n", " \"\"\"Builds a set of all nodes and edges in a graph.\"\"\"\n", " # https://github.com/karpathy/micrograd/blob/master/trace_graph.ipynb\n", "\n", " nodes = set()\n", " edges = set()\n", "\n", " def build(v):\n", " if v not in nodes:\n", " nodes.add(v)\n", " for parent in v._parents:\n", " edges.add((parent, v))\n", " build(parent)\n", "\n", " build(root)\n", " return nodes, edges\n", "\n", "\n", "def draw_graph(root):\n", " \"\"\"Build diagram of computational graph.\"\"\"\n", "\n", " dot = Digraph(format=\"svg\", graph_attr={\"rankdir\": \"LR\"}) # LR = left to right\n", " nodes, edges = trace(root)\n", " for n in nodes:\n", " # Add node to graph\n", " uid = str(id(n))\n", " dot.node(name=uid, label=f\"data={n.data:.3f} | grad={n.grad:.4f}\", shape=\"record\")\n", "\n", " # Connect node to op node if operation\n", " # e.g. if (5) = (2) + (3), then draw (5) as (+) -> (5).\n", " if len(n._parents) > 0:\n", " dot.node(name=uid + str(n), label=str(n))\n", " dot.edge(uid + str(n), uid)\n", "\n", " for child, v in edges:\n", " # Connect child to the op node of v\n", " dot.edge(str(id(child)), str(id(v)) + str(v))\n", "\n", " return dot" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "\n", "from graphviz import Digraph\n", "\n", "\n", "def trace(root):\n", " \"\"\"Builds a set of all nodes and edges in a graph.\"\"\"\n", " # https://github.com/karpathy/micrograd/blob/master/trace_graph.ipynb\n", "\n", " nodes = set()\n", " edges = set()\n", "\n", " def build(v):\n", " if v not in nodes:\n", " nodes.add(v)\n", " for parent in v._parents:\n", " edges.add((parent, v))\n", " build(parent)\n", "\n", " build(root)\n", " return nodes, edges\n", "\n", "\n", "def draw_graph(root):\n", " \"\"\"Build diagram of computational graph.\"\"\"\n", "\n", " dot = Digraph(format=\"svg\", 
graph_attr={\"rankdir\": \"LR\"}) # LR = left to right\n", " nodes, edges = trace(root)\n", " for n in nodes:\n", " # Add node to graph\n", " uid = str(id(n))\n", " dot.node(name=uid, label=f\"data={n.data:.3f} | grad={n.grad:.4f}\", shape=\"record\")\n", "\n", " # Connect node to op node if operation\n", " # e.g. if (5) = (2) + (3), then draw (5) as (+) -> (5).\n", " if len(n._parents) > 0:\n", " dot.node(name=uid + str(n), label=str(n))\n", " dot.edge(uid + str(n), uid)\n", "\n", " for child, v in edges:\n", " # Connect child to the op node of v\n", " dot.edge(str(id(child)), str(id(v)) + str(v))\n", "\n", " return dot" ] }, { "cell_type": "markdown", "id": "79342ba0", "metadata": { "papermill": { "duration": 0.006336, "end_time": "2024-09-05T18:30:31.278730", "exception": false, "start_time": "2024-09-05T18:30:31.272394", "status": "completed" }, "tags": [] }, "source": [ "Create the graph for a dense unit. Observe that `x` has a degree of 2 since it has two children (it feeds both `w0 * x` and `w1 * x`)." ] }, { "cell_type": "code", "execution_count": 4, "id": "bbdb1782", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:31.293915Z", "iopub.status.busy": "2024-09-05T18:30:31.293543Z", "iopub.status.idle": "2024-09-05T18:30:31.652324Z", "shell.execute_reply": "2024-09-05T18:30:31.651500Z" }, "papermill": { "duration": 0.373569, "end_time": "2024-09-05T18:30:31.659509", "exception": false, "start_time": "2024-09-05T18:30:31.285940", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "4379506720\n", "\n", "data=4.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379411648+\n", "\n", "+\n", "\n", "\n", "\n", "4379506720->4379411648+\n", "\n", "\n", "\n", "\n", "\n", "4379506768\n", "\n", "data=-1.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379415200*\n", "\n", "*\n", "\n", "\n", "\n", "4379506768->4379415200*\n", "\n", "\n", "\n", "\n", "\n", "4379415200\n", "\n", "data=-2.000\n", "\n", 
"grad=0.0000\n", "\n", "\n", "\n", "4379411744+\n", "\n", "+\n", "\n", "\n", "\n", "4379415200->4379411744+\n", "\n", "\n", "\n", "\n", "\n", "4379415200*->4379415200\n", "\n", "\n", "\n", "\n", "\n", "4379411648\n", "\n", "data=6.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379412944relu\n", "\n", "relu\n", "\n", "\n", "\n", "4379411648->4379412944relu\n", "\n", "\n", "\n", "\n", "\n", "4379411648+->4379411648\n", "\n", "\n", "\n", "\n", "\n", "4379414720\n", "\n", "data=4.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379414720->4379411744+\n", "\n", "\n", "\n", "\n", "\n", "4379414720*\n", "\n", "*\n", "\n", "\n", "\n", "4379414720*->4379414720\n", "\n", "\n", "\n", "\n", "\n", "4379411744\n", "\n", "data=2.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379411744->4379411648+\n", "\n", "\n", "\n", "\n", "\n", "4379411744+->4379411744\n", "\n", "\n", "\n", "\n", "\n", "4379424560\n", "\n", "data=2.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379424560->4379415200*\n", "\n", "\n", "\n", "\n", "\n", "4379424560->4379414720*\n", "\n", "\n", "\n", "\n", "\n", "4379412944\n", "\n", "data=6.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379412944relu->4379412944\n", "\n", "\n", "\n", "\n", "\n", "4379506672\n", "\n", "data=2.000\n", "\n", "grad=0.0000\n", "\n", "\n", "\n", "4379506672->4379414720*\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w0 = Node(-1.0)\n", "w1 = Node(2.0)\n", "b = Node(4.0)\n", "x = Node(2.0)\n", "t = Node(3.0)\n", "\n", "z = w0 * x + w1 * x + b\n", "u = z.tanh()\n", "y = z.relu()\n", "draw_graph(y)" ] }, { "cell_type": "markdown", "id": "699d87d8", "metadata": { "papermill": { "duration": 0.005196, "end_time": "2024-09-05T18:30:31.671643", "exception": false, "start_time": "2024-09-05T18:30:31.666447", "status": "completed" }, "tags": [] }, "source": [ "Gradients all check out:" ] }, { "cell_type": "code", 
"execution_count": 5, "id": "7e3ede9e", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:31.681248Z", "iopub.status.busy": "2024-09-05T18:30:31.680772Z", "iopub.status.idle": "2024-09-05T18:30:31.816814Z", "shell.execute_reply": "2024-09-05T18:30:31.816224Z" }, "papermill": { "duration": 0.143049, "end_time": "2024-09-05T18:30:31.818431", "exception": false, "start_time": "2024-09-05T18:30:31.675382", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "4379506720\n", "\n", "data=4.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379411648+\n", "\n", "+\n", "\n", "\n", "\n", "4379506720->4379411648+\n", "\n", "\n", "\n", "\n", "\n", "4379506768\n", "\n", "data=-1.000\n", "\n", "grad=2.0000\n", "\n", "\n", "\n", "4379415200*\n", "\n", "*\n", "\n", "\n", "\n", "4379506768->4379415200*\n", "\n", "\n", "\n", "\n", "\n", "4379415200\n", "\n", "data=-2.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379411744+\n", "\n", "+\n", "\n", "\n", "\n", "4379415200->4379411744+\n", "\n", "\n", "\n", "\n", "\n", "4379415200*->4379415200\n", "\n", "\n", "\n", "\n", "\n", "4379411648\n", "\n", "data=6.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379412944relu\n", "\n", "relu\n", "\n", "\n", "\n", "4379411648->4379412944relu\n", "\n", "\n", "\n", "\n", "\n", "4379411648+->4379411648\n", "\n", "\n", "\n", "\n", "\n", "4379414720\n", "\n", "data=4.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379414720->4379411744+\n", "\n", "\n", "\n", "\n", "\n", "4379414720*\n", "\n", "*\n", "\n", "\n", "\n", "4379414720*->4379414720\n", "\n", "\n", "\n", "\n", "\n", "4379411744\n", "\n", "data=2.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379411744->4379411648+\n", "\n", "\n", "\n", "\n", "\n", "4379411744+->4379411744\n", "\n", "\n", "\n", "\n", "\n", "4379424560\n", "\n", "data=2.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379424560->4379415200*\n", "\n", "\n", 
"\n", "\n", "\n", "4379424560->4379414720*\n", "\n", "\n", "\n", "\n", "\n", "4379412944\n", "\n", "data=6.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379412944relu->4379412944\n", "\n", "\n", "\n", "\n", "\n", "4379506672\n", "\n", "data=2.000\n", "\n", "grad=2.0000\n", "\n", "\n", "\n", "4379506672->4379414720*\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.backward()\n", "draw_graph(y)" ] }, { "cell_type": "markdown", "id": "7ab79a91", "metadata": { "papermill": { "duration": 0.00413, "end_time": "2024-09-05T18:30:31.827553", "exception": false, "start_time": "2024-09-05T18:30:31.823423", "status": "completed" }, "tags": [] }, "source": [ "Note that `u` is not shown in the graph and `u.grad` is zero since `y` has no dependence on `u`:" ] }, { "cell_type": "code", "execution_count": 6, "id": "8f8eac44", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:31.838625Z", "iopub.status.busy": "2024-09-05T18:30:31.838394Z", "iopub.status.idle": "2024-09-05T18:30:31.842133Z", "shell.execute_reply": "2024-09-05T18:30:31.841774Z" }, "papermill": { "duration": 0.010971, "end_time": "2024-09-05T18:30:31.843340", "exception": false, "start_time": "2024-09-05T18:30:31.832369", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u.grad" ] }, { "cell_type": "markdown", "id": "4d2b65dc", "metadata": { "papermill": { "duration": 0.003324, "end_time": "2024-09-05T18:30:31.850726", "exception": false, "start_time": "2024-09-05T18:30:31.847402", "status": "completed" }, "tags": [] }, "source": [ "Moreover, gradients on shared parameters **accumulate** with multiple inputs:" ] }, { "cell_type": "code", "execution_count": 7, "id": "58aa8949", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:31.859213Z", 
"iopub.status.busy": "2024-09-05T18:30:31.859005Z", "iopub.status.idle": "2024-09-05T18:30:31.976780Z", "shell.execute_reply": "2024-09-05T18:30:31.975991Z" }, "papermill": { "duration": 0.123475, "end_time": "2024-09-05T18:30:31.978153", "exception": false, "start_time": "2024-09-05T18:30:31.854678", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "4379142672\n", "\n", "data=1.700\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379501920+\n", "\n", "+\n", "\n", "\n", "\n", "4379142672->4379501920+\n", "\n", "\n", "\n", "\n", "\n", "4379142672+\n", "\n", "+\n", "\n", "\n", "\n", "4379142672+->4379142672\n", "\n", "\n", "\n", "\n", "\n", "4379506720\n", "\n", "data=4.000\n", "\n", "grad=2.0000\n", "\n", "\n", "\n", "4379506720->4379501920+\n", "\n", "\n", "\n", "\n", "\n", "4379142192\n", "\n", "data=1.700\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379141856*\n", "\n", "*\n", "\n", "\n", "\n", "4379142192->4379141856*\n", "\n", "\n", "\n", "\n", "\n", "4379141472*\n", "\n", "*\n", "\n", "\n", "\n", "4379142192->4379141472*\n", "\n", "\n", "\n", "\n", "\n", "4379506768\n", "\n", "data=-1.000\n", "\n", "grad=3.7000\n", "\n", "\n", "\n", "4379506768->4379141472*\n", "\n", "\n", "\n", "\n", "\n", "4379142000\n", "\n", "data=5.700\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379142000relu\n", "\n", "relu\n", "\n", "\n", "\n", "4379142000relu->4379142000\n", "\n", "\n", "\n", "\n", "\n", "4379141856\n", "\n", "data=3.400\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379141856->4379142672+\n", "\n", "\n", "\n", "\n", "\n", "4379141856*->4379141856\n", "\n", "\n", "\n", "\n", "\n", "4379141472\n", "\n", "data=-1.700\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4379141472->4379142672+\n", "\n", "\n", "\n", "\n", "\n", "4379141472*->4379141472\n", "\n", "\n", "\n", "\n", "\n", "4379501920\n", "\n", "data=5.700\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", 
"4379501920->4379142000relu\n", "\n", "\n", "\n", "\n", "\n", "4379501920+->4379501920\n", "\n", "\n", "\n", "\n", "\n", "4379506672\n", "\n", "data=2.000\n", "\n", "grad=3.7000\n", "\n", "\n", "\n", "4379506672->4379141856*\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1 = Node(1.7)\n", "z1 = w0 * x1 + w1 * x1 + b\n", "y1 = z1.relu()\n", "y1.backward()\n", "draw_graph(y1)" ] }, { "cell_type": "code", "execution_count": null, "id": "320a04cc", "metadata": { "papermill": { "duration": 0.00499, "end_time": "2024-09-05T18:30:31.987557", "exception": false, "start_time": "2024-09-05T18:30:31.982567", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 2.246177, "end_time": "2024-09-05T18:30:32.210832", "environment_variables": {}, "exception": null, "input_path": "00a-compute-nodes.ipynb", "output_path": "00a-compute-nodes.ipynb", "parameters": {}, "start_time": "2024-09-05T18:30:29.964655", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }