{ "cells": [ { "cell_type": "markdown", "id": "d58644a3", "metadata": { "papermill": { "duration": 0.007135, "end_time": "2024-09-05T18:30:30.865440", "exception": false, "start_time": "2024-09-05T18:30:30.858305", "status": "completed" }, "tags": [] }, "source": [ "# Defined operations" ] }, { "cell_type": "markdown", "id": "e54b0381", "metadata": { "papermill": { "duration": 0.004333, "end_time": "2024-09-05T18:30:30.875104", "exception": false, "start_time": "2024-09-05T18:30:30.870771", "status": "completed" }, "tags": [] }, "source": [ "Recall that all operations must be defined with specific local gradient computation for BP to work. In this section, we will implement a minimal **autograd engine** for creating computational graphs. This starts with the base `Node` class which has a `data` attribute for storing output and a `grad` attribute for storing the global gradient. Furthermore, the base class defines a `backward` method to solve for `grad` as described above." ] }, { "cell_type": "code", "execution_count": 1, "id": "967a8e6d", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:30.884439Z", "iopub.status.busy": "2024-09-05T18:30:30.884096Z", "iopub.status.idle": "2024-09-05T18:30:30.938215Z", "shell.execute_reply": "2024-09-05T18:30:30.937799Z" }, "papermill": { "duration": 0.060396, "end_time": "2024-09-05T18:30:30.939730", "exception": false, "start_time": "2024-09-05T18:30:30.879334", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import math\n",
"import random\n",
"random.seed(42)\n",
"\n",
"from typing import final\n",
"from collections import OrderedDict\n",
"\n",
"\n",
class Node:
    """A vertex in a computational graph.

    Stores the forward-pass output in `data` and the accumulated global
    gradient ∂loss/∂self in `grad`. Operation subclasses provide the
    local gradient via `_local_grad`.
    """

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0  # ∂loss / ∂self, accumulated by backward()
        self._parents = parents  # nodes feeding into this one (parent -> self)

    @final
    def sorted_nodes(self):
        """Return topologically sorted nodes with self as root."""
        ordering = OrderedDict()

        def visit(node):
            if node in ordering:
                return
            # Insert parents first so each node appears after its ancestors.
            for parent in node._parents:
                visit(parent)
            ordering[node] = None

        visit(self)
        # Reversed: root first, leaves last — the order backprop needs.
        return reversed(ordering)

    @final
    def backward(self):
        """Send global grads backward to parent nodes."""
        self.grad = 1.0  # ∂self / ∂self
        for node in self.sorted_nodes():
            for parent in node._parents:
                # Chain rule: ∂loss/∂parent += ∂loss/∂node * ∂node/∂parent.
                parent.grad += node.grad * node._local_grad(parent)

    def _local_grad(self, parent) -> float:
        """Calculate local grads ∂self / ∂parent (overridden by op nodes)."""
        raise NotImplementedError("Base node has no parents.")

    def __add__(self, node):
        return BinaryOpNode(self, node, op="+")

    def __mul__(self, node):
        return BinaryOpNode(self, node, op="*")

    def __pow__(self, n):
        assert isinstance(n, (int, float)) and n != 1
        return PowOp(self, n)

    def relu(self):
        return ReLUNode(self)

    def tanh(self):
        return TanhNode(self)

    def __neg__(self):
        return self * Node(-1)

    def __sub__(self, node):
        return self + (-node)
"
class BinaryOpNode(Node):
    """Node produced by applying `+` or `*` to two parent nodes."""

    def __init__(self, x, y, op: str):
        """Binary operation between two nodes."""
        self._op = op
        compute = {"+": lambda a, b: a + b, "*": lambda a, b: a * b}[op]
        super().__init__(compute(x.data, y.data), (x, y))

    def _local_grad(self, parent):
        # ∂(a+b)/∂a = 1; ∂(a*b)/∂a = b (value of the other operand).
        if self._op == "+":
            return 1.0
        elif self._op == "*":
            other = self._parents[1 - self._parents.index(parent)]
            return other.data

    def __repr__(self):
        return self._op
"\n",
"\n",
class ReLUNode(Node):
    """Node applying the rectified linear unit to a single parent."""

    def __init__(self, x):
        # Multiply by the 0/1 indicator instead of branching; zeroes
        # out the output whenever the input is non-positive.
        super().__init__(x.data * int(x.data > 0.0), (x,))

    def _local_grad(self, parent):
        # Gradient passes through (1.0) only where the input was positive.
        return float(parent.data > 0)

    def __repr__(self):
        return "relu"
"\n",
"\n",
class TanhNode(Node):
    """Node applying the hyperbolic tangent to a single parent."""

    def __init__(self, x):
        super().__init__(math.tanh(x.data), (x,))

    def _local_grad(self, parent):
        # d/dx tanh(x) = 1 - tanh(x)^2; reuses the cached forward output.
        return 1 - self.data**2

    def __repr__(self):
        return "tanh"
"\n",
"\n",
class PowOp(Node):
    """Node raising a single parent to a fixed scalar power n."""

    def __init__(self, x, n):
        self.n = n  # exponent, a plain int/float (not a Node)
        super().__init__(x.data**n, (x,))

    def _local_grad(self, parent):
        # Power rule: d/dx x^n = n * x^(n-1).
        return self.n * parent.data ** (self.n - 1)

    def __repr__(self):
        return f"** {self.n}"
"
from graphviz import Digraph\n",
"\n",
"\n",
def trace(root):
    """Builds a set of all nodes and edges in a graph.

    Iterative DFS from `root`, following `_parents` links. Returns
    (nodes, edges) where each edge is a (parent, child) pair.
    """
    # Adapted from:
    # https://github.com/karpathy/micrograd/blob/master/trace_graph.ipynb
    nodes, edges = set(), set()
    stack = [root]
    while stack:
        v = stack.pop()
        if v in nodes:
            continue  # already expanded this node
        nodes.add(v)
        for parent in v._parents:
            edges.add((parent, v))
            stack.append(parent)

    return nodes, edges
"\n",
"\n",
def draw_graph(root):
    """Build diagram of computational graph rooted at `root`."""
    dot = Digraph(format="svg", graph_attr={"rankdir": "LR"})  # LR = left to right

    nodes, edges = trace(root)
    for node in nodes:
        # One record-shaped box per value node, keyed by object identity.
        uid = str(id(node))
        dot.node(name=uid, label=f"data={node.data:.3f} | grad={node.grad:.4f}", shape="record")

        # Non-leaf nodes get a small op node feeding into them,
        # e.g. if (5) = (2) + (3), then draw (5) as (+) -> (5).
        if len(node._parents) > 0:
            op_uid = uid + str(node)
            dot.node(name=op_uid, label=str(node))
            dot.edge(op_uid, uid)

    for parent, node in edges:
        # Each parent feeds into the op node of the node it produces.
        dot.edge(str(id(parent)), str(id(node)) + str(node))

    return dot
"