{ "cells": [ { "cell_type": "markdown", "id": "1d0665ce", "metadata": { "papermill": { "duration": 0.004409, "end_time": "2024-09-05T18:30:34.421709", "exception": false, "start_time": "2024-09-05T18:30:34.417300", "status": "completed" }, "tags": [] }, "source": [ "# Neural network module" ] }, { "cell_type": "markdown", "id": "e51bfef8", "metadata": { "papermill": { "duration": 0.01941, "end_time": "2024-09-05T18:30:34.446073", "exception": false, "start_time": "2024-09-05T18:30:34.426663", "status": "completed" }, "tags": [] }, "source": [ "Here we construct the neural network module. The `Module` class defines an abstract class that maintains a list of the parameters used in forward pass implemented in `__call__`. The decorator `@final` is to prevent any inheriting class from overriding the methods as doing so would result in a warning (or an error with a type checker)." ] }, { "cell_type": "code", "execution_count": 1, "id": "ae2cc403", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:34.461626Z", "iopub.status.busy": "2024-09-05T18:30:34.461229Z", "iopub.status.idle": "2024-09-05T18:30:34.492764Z", "shell.execute_reply": "2024-09-05T18:30:34.492267Z" }, "papermill": { "duration": 0.045357, "end_time": "2024-09-05T18:30:34.495126", "exception": false, "start_time": "2024-09-05T18:30:34.449769", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "code", "execution_count": 2, "id": "4a432dbd", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:34.504787Z", "iopub.status.busy": "2024-09-05T18:30:34.504091Z", "iopub.status.idle": "2024-09-05T18:30:34.604889Z", "shell.execute_reply": "2024-09-05T18:30:34.604458Z" }, "papermill": { "duration": 0.108303, "end_time": "2024-09-05T18:30:34.606828", "exception": false, "start_time": "2024-09-05T18:30:34.498525", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
from abc import ABC, abstractmethod\n",
       "from typing import final\n",
       "\n",
       "class Module(ABC):\n",
       "    def __init__(self):\n",
       "        self._parameters = []\n",
       "\n",
       "    @final\n",
       "    def parameters(self) -> list:\n",
       "        return self._parameters\n",
       "\n",
       "    @abstractmethod\n",
       "    def __call__(self, x: list):\n",
       "        pass\n",
       "\n",
       "    @final\n",
       "    def zero_grad(self):\n",
       "        for p in self.parameters():\n",
       "            p.grad = 0\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{from} \\PY{n+nn}{abc} \\PY{k+kn}{import} \\PY{n}{ABC}\\PY{p}{,} \\PY{n}{abstractmethod}\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{(}\\PY{n}{ABC}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parameters} \\PY{o}{=} \\PY{p}{[}\\PY{p}{]}\n", "\n", " \\PY{n+nd}{@final}\n", " \\PY{k}{def} \\PY{n+nf}{parameters}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{list}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parameters}\n", "\n", " \\PY{n+nd}{@abstractmethod}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}call\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{:} \\PY{n+nb}{list}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{pass}\n", "\n", " \\PY{n+nd}{@final}\n", " \\PY{k}{def} \\PY{n+nf}{zero\\PYZus{}grad}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{parameters}\\PY{p}{(}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{p}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{=} \\PY{l+m+mi}{0}\n", "\\end{Verbatim}\n" ], "text/plain": [ "from abc import ABC, abstractmethod\n", "\n", "class Module(ABC):\n", " def __init__(self):\n", " self._parameters = []\n", "\n", " @final\n", " def parameters(self) -> list:\n", " return self._parameters\n", "\n", " @abstractmethod\n", " def __call__(self, x: list):\n", " pass\n", "\n", " @final\n", " def zero_grad(self):\n", " for p in self.parameters():\n", " p.grad = 0" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "from abc import ABC, abstractmethod\n", "\n", "class Module(ABC):\n", " def __init__(self):\n", " self._parameters = []\n", "\n", " @final\n", " def parameters(self) -> list:\n", " return self._parameters\n", "\n", " @abstractmethod\n", " def __call__(self, x: list):\n", " pass\n", "\n", " @final\n", " def zero_grad(self):\n", " for p in self.parameters():\n", " p.grad = 0" ] }, { "cell_type": "markdown", "id": "55502473", "metadata": { "papermill": { "duration": 0.001561, "end_time": "2024-09-05T18:30:34.610418", "exception": false, "start_time": "2024-09-05T18:30:34.608857", "status": "completed" }, "tags": [] }, "source": [ "The `_parameters` attribute is defined so that the parameter list is not constructed at each call of the `parameters()` method. Implementing layers from neurons:" ] }, { "cell_type": "code", "execution_count": 3, "id": "09d0d8b1", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:34.614300Z", "iopub.status.busy": "2024-09-05T18:30:34.614086Z", "iopub.status.idle": "2024-09-05T18:30:34.632626Z", "shell.execute_reply": "2024-09-05T18:30:34.632250Z" }, "papermill": { "duration": 0.022001, "end_time": "2024-09-05T18:30:34.633861", "exception": false, "start_time": "2024-09-05T18:30:34.611860", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
class Neuron(Module):\n",
       "    def __init__(self, n_in, activation=None):\n",
       "        self.n_in = n_in\n",
       "        self.act = activation\n",
       "\n",
       "        self.w = [Node(random.random()) for _ in range(n_in)]\n",
       "        self.b = Node(0.0)\n",
       "        self._parameters = self.w + [self.b]\n",
       "\n",
       "    def __call__(self, x: list):\n",
       "        assert len(x) == self.n_in\n",
       "        out = sum((x[j] * self.w[j] for j in range(self.n_in)), start=self.b)\n",
       "        if self.act is not None:\n",
       "            if self.act == "tanh":\n",
       "                out = out.tanh()\n",
       "            elif self.act == "relu":\n",
       "                out = out.relu()\n",
       "            else:\n",
       "                raise NotImplementedError("Activation not supported.")\n",
       "        return out\n",
       "\n",
       "    def __repr__(self):\n",
       "        return f"{self.act if self.act is not None else 'linear'}({len(self.w)})"\n",
       "\n",
       "\n",
       "class Layer(Module):\n",
       "    def __init__(self, n_in, n_out, *args):\n",
       "        self.neurons = [Neuron(n_in, *args) for _ in range(n_out)]\n",
       "        self._parameters = [p for n in self.neurons for p in n.parameters()]\n",
       "\n",
       "    def __call__(self, x: list):\n",
       "        out = [n(x) for n in self.neurons]\n",
       "        return out[0] if len(out) == 1 else out\n",
       "\n",
       "    def __repr__(self):\n",
       "        return f"Layer[{', '.join(str(n) for n in self.neurons)}]"\n",
       "\n",
       "\n",
       "class MLP(Module):\n",
       "    def __init__(self, n_in, n_outs, activation=None):\n",
       "        sizes = [n_in] + n_outs\n",
       "        self.layers = []\n",
       "        for i in range(len(n_outs)):\n",
       "            act = activation if i < len(n_outs) - 1 else None\n",
       "            layer = Layer(sizes[i], sizes[i + 1], act)\n",
       "            self.layers.append(layer)\n",
       "\n",
       "        self._parameters = [p for layer in self.layers for p in layer.parameters()]\n",
       "\n",
       "    def __call__(self, x):\n",
       "        for layer in self.layers:\n",
       "            x = layer(x)\n",
       "        return x\n",
       "\n",
       "    def __repr__(self):\n",
       "        return f"MLP[{', '.join(str(layer) for layer in self.layers)}]"\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k}{class} \\PY{n+nc}{Neuron}\\PY{p}{(}\\PY{n}{Module}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{n\\PYZus{}in}\\PY{p}{,} \\PY{n}{activation}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n\\PYZus{}in} \\PY{o}{=} \\PY{n}{n\\PYZus{}in}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act} \\PY{o}{=} \\PY{n}{activation}\n", "\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{w} \\PY{o}{=} \\PY{p}{[}\\PY{n}{Node}\\PY{p}{(}\\PY{n}{random}\\PY{o}{.}\\PY{n}{random}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)} \\PY{k}{for} \\PY{n}{\\PYZus{}} \\PY{o+ow}{in} \\PY{n+nb}{range}\\PY{p}{(}\\PY{n}{n\\PYZus{}in}\\PY{p}{)}\\PY{p}{]}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{b} \\PY{o}{=} \\PY{n}{Node}\\PY{p}{(}\\PY{l+m+mf}{0.0}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parameters} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{w} \\PY{o}{+} \\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{b}\\PY{p}{]}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}call\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{:} \\PY{n+nb}{list}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{assert} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)} \\PY{o}{==} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n\\PYZus{}in}\n", " \\PY{n}{out} \\PY{o}{=} \\PY{n+nb}{sum}\\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{[}\\PY{n}{j}\\PY{p}{]} \\PY{o}{*} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{w}\\PY{p}{[}\\PY{n}{j}\\PY{p}{]} \\PY{k}{for} \\PY{n}{j} \\PY{o+ow}{in} \\PY{n+nb}{range}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{n\\PYZus{}in}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{start}\\PY{o}{=}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{b}\\PY{p}{)}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act} \\PY{o+ow}{is} \\PY{o+ow}{not} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act} \\PY{o}{==} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tanh}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:}\n", " \\PY{n}{out} \\PY{o}{=} \\PY{n}{out}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{elif} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act} \\PY{o}{==} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{relu}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:}\n", " \\PY{n}{out} \\PY{o}{=} \\PY{n}{out}\\PY{o}{.}\\PY{n}{relu}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{else}\\PY{p}{:}\n", " \\PY{k}{raise} \\PY{n+ne}{NotImplementedError}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{Activation not supported.}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n}{out}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act}\\PY{+w}{ }\\PY{k}{if}\\PY{+w}{ }\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{act}\\PY{+w}{ }\\PY{o+ow}{is}\\PY{+w}{ }\\PY{o+ow}{not}\\PY{+w}{ }\\PY{k+kc}{None}\\PY{+w}{ }\\PY{k}{else}\\PY{+w}{ }\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{linear}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{(}\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{w}\\PY{p}{)}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{)}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Layer}\\PY{p}{(}\\PY{n}{Module}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} 
\\PY{n}{n\\PYZus{}in}\\PY{p}{,} \\PY{n}{n\\PYZus{}out}\\PY{p}{,} \\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{neurons} \\PY{o}{=} \\PY{p}{[}\\PY{n}{Neuron}\\PY{p}{(}\\PY{n}{n\\PYZus{}in}\\PY{p}{,} \\PY{o}{*}\\PY{n}{args}\\PY{p}{)} \\PY{k}{for} \\PY{n}{\\PYZus{}} \\PY{o+ow}{in} \\PY{n+nb}{range}\\PY{p}{(}\\PY{n}{n\\PYZus{}out}\\PY{p}{)}\\PY{p}{]}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parameters} \\PY{o}{=} \\PY{p}{[}\\PY{n}{p} \\PY{k}{for} \\PY{n}{n} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{neurons} \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n}{n}\\PY{o}{.}\\PY{n}{parameters}\\PY{p}{(}\\PY{p}{)}\\PY{p}{]}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}call\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{:} \\PY{n+nb}{list}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{out} \\PY{o}{=} \\PY{p}{[}\\PY{n}{n}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)} \\PY{k}{for} \\PY{n}{n} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{neurons}\\PY{p}{]}\n", " \\PY{k}{return} \\PY{n}{out}\\PY{p}{[}\\PY{l+m+mi}{0}\\PY{p}{]} \\PY{k}{if} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{out}\\PY{p}{)} \\PY{o}{==} \\PY{l+m+mi}{1} \\PY{k}{else} \\PY{n}{out}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{Layer[}\\PY{l+s+si}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{, }\\PY{l+s+s1}{\\PYZsq{}}\\PY{o}{.}\\PY{n}{join}\\PY{p}{(}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{n}\\PY{p}{)}\\PY{+w}{ }\\PY{k}{for}\\PY{+w}{ }\\PY{n}{n}\\PY{+w}{ }\\PY{o+ow}{in}\\PY{+w}{ }\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{neurons}\\PY{p}{)}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{]}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{MLP}\\PY{p}{(}\\PY{n}{Module}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{n\\PYZus{}in}\\PY{p}{,} \\PY{n}{n\\PYZus{}outs}\\PY{p}{,} \\PY{n}{activation}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{sizes} \\PY{o}{=} \\PY{p}{[}\\PY{n}{n\\PYZus{}in}\\PY{p}{]} \\PY{o}{+} \\PY{n}{n\\PYZus{}outs}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{layers} \\PY{o}{=} \\PY{p}{[}\\PY{p}{]}\n", " \\PY{k}{for} \\PY{n}{i} \\PY{o+ow}{in} \\PY{n+nb}{range}\\PY{p}{(}\\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{n\\PYZus{}outs}\\PY{p}{)}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{act} \\PY{o}{=} \\PY{n}{activation} \\PY{k}{if} \\PY{n}{i} \\PY{o}{\\PYZlt{}} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{n\\PYZus{}outs}\\PY{p}{)} \\PY{o}{\\PYZhy{}} \\PY{l+m+mi}{1} \\PY{k}{else} \\PY{k+kc}{None}\n", " \\PY{n}{layer} \\PY{o}{=} \\PY{n}{Layer}\\PY{p}{(}\\PY{n}{sizes}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{,} \\PY{n}{sizes}\\PY{p}{[}\\PY{n}{i} \\PY{o}{+} \\PY{l+m+mi}{1}\\PY{p}{]}\\PY{p}{,} \\PY{n}{act}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{layers}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{layer}\\PY{p}{)}\n", "\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}parameters} \\PY{o}{=} \\PY{p}{[}\\PY{n}{p} \\PY{k}{for} \\PY{n}{layer} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{layers} \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n}{layer}\\PY{o}{.}\\PY{n}{parameters}\\PY{p}{(}\\PY{p}{)}\\PY{p}{]}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}call\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{layer} \\PY{o+ow}{in} 
\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{layers}\\PY{p}{:}\n", " \\PY{n}{x} \\PY{o}{=} \\PY{n}{layer}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n}{x}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}repr\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{MLP[}\\PY{l+s+si}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{, }\\PY{l+s+s1}{\\PYZsq{}}\\PY{o}{.}\\PY{n}{join}\\PY{p}{(}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{layer}\\PY{p}{)}\\PY{+w}{ }\\PY{k}{for}\\PY{+w}{ }\\PY{n}{layer}\\PY{+w}{ }\\PY{o+ow}{in}\\PY{+w}{ }\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{layers}\\PY{p}{)}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{]}\\PY{l+s+s2}{\\PYZdq{}}\n", "\\end{Verbatim}\n" ], "text/plain": [ "\n", "class Neuron(Module):\n", " def __init__(self, n_in, activation=None):\n", " self.n_in = n_in\n", " self.act = activation\n", "\n", " self.w = [Node(random.random()) for _ in range(n_in)]\n", " self.b = Node(0.0)\n", " self._parameters = self.w + [self.b]\n", "\n", " def __call__(self, x: list):\n", " assert len(x) == self.n_in\n", " out = sum((x[j] * self.w[j] for j in range(self.n_in)), start=self.b)\n", " if self.act is not None:\n", " if self.act == \"tanh\":\n", " out = out.tanh()\n", " elif self.act == \"relu\":\n", " out = out.relu()\n", " else:\n", " raise NotImplementedError(\"Activation not supported.\")\n", " return out\n", "\n", " def __repr__(self):\n", " return f\"{self.act if self.act is not None else 'linear'}({len(self.w)})\"\n", "\n", "\n", "class Layer(Module):\n", " def __init__(self, n_in, n_out, *args):\n", " self.neurons = [Neuron(n_in, *args) for _ in range(n_out)]\n", " self._parameters = [p for n in self.neurons for p in n.parameters()]\n", "\n", " def __call__(self, x: list):\n", " out = [n(x) for n in self.neurons]\n", " return out[0] if len(out) == 1 else out\n", "\n", " def __repr__(self):\n", " return f\"Layer[{', '.join(str(n) for n in self.neurons)}]\"\n", "\n", "\n", "class MLP(Module):\n", " def __init__(self, n_in, n_outs, activation=None):\n", " sizes = [n_in] + n_outs\n", " self.layers = []\n", " for i in range(len(n_outs)):\n", " act = activation if i < len(n_outs) - 1 else None\n", " layer = Layer(sizes[i], sizes[i + 1], act)\n", " self.layers.append(layer)\n", "\n", " self._parameters = [p for layer in self.layers for p in layer.parameters()]\n", "\n", " def __call__(self, x):\n", " for layer in self.layers:\n", " x = layer(x)\n", " return x\n", "\n", " def __repr__(self):\n", " return f\"MLP[{', '.join(str(layer) for layer in self.layers)}]\"" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "\n", "class Neuron(Module):\n", " def __init__(self, n_in, activation=None):\n", " self.n_in = n_in\n", " self.act = activation\n", "\n", " self.w = [Node(random.random()) for _ in range(n_in)]\n", " self.b = Node(0.0)\n", " self._parameters = self.w + [self.b]\n", "\n", " def __call__(self, x: list):\n", " assert len(x) == self.n_in\n", " out = sum((x[j] * self.w[j] for j in range(self.n_in)), start=self.b)\n", " if self.act is not None:\n", " if self.act == \"tanh\":\n", " out = out.tanh()\n", " elif self.act == \"relu\":\n", " out = out.relu()\n", " else:\n", " raise NotImplementedError(\"Activation not supported.\")\n", " return out\n", "\n", " def __repr__(self):\n", " return f\"{self.act if self.act is not None else 'linear'}({len(self.w)})\"\n", "\n", "\n", "class Layer(Module):\n", " def __init__(self, n_in, n_out, *args):\n", 
" self.neurons = [Neuron(n_in, *args) for _ in range(n_out)]\n", " self._parameters = [p for n in self.neurons for p in n.parameters()]\n", "\n", " def __call__(self, x: list):\n", " out = [n(x) for n in self.neurons]\n", " return out[0] if len(out) == 1 else out\n", "\n", " def __repr__(self):\n", " return f\"Layer[{', '.join(str(n) for n in self.neurons)}]\"\n", "\n", "\n", "class MLP(Module):\n", " def __init__(self, n_in, n_outs, activation=None):\n", " sizes = [n_in] + n_outs\n", " self.layers = []\n", " for i in range(len(n_outs)):\n", " act = activation if i < len(n_outs) - 1 else None\n", " layer = Layer(sizes[i], sizes[i + 1], act)\n", " self.layers.append(layer)\n", "\n", " self._parameters = [p for layer in self.layers for p in layer.parameters()]\n", "\n", " def __call__(self, x):\n", " for layer in self.layers:\n", " x = layer(x)\n", " return x\n", "\n", " def __repr__(self):\n", " return f\"MLP[{', '.join(str(layer) for layer in self.layers)}]\"" ] }, { "cell_type": "markdown", "id": "afd9d099", "metadata": { "papermill": { "duration": 0.002642, "end_time": "2024-09-05T18:30:34.639401", "exception": false, "start_time": "2024-09-05T18:30:34.636759", "status": "completed" }, "tags": [] }, "source": [ "Testing model init and model call. Note that the final node has no activation:" ] }, { "cell_type": "code", "execution_count": 4, "id": "b8e5e034", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:34.644690Z", "iopub.status.busy": "2024-09-05T18:30:34.644477Z", "iopub.status.idle": "2024-09-05T18:30:34.647599Z", "shell.execute_reply": "2024-09-05T18:30:34.647190Z" }, "papermill": { "duration": 0.007022, "end_time": "2024-09-05T18:30:34.648825", "exception": false, "start_time": "2024-09-05T18:30:34.641803", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP[Layer[relu(1), relu(1)], Layer[relu(2), relu(2), relu(2), relu(2)], Layer[linear(4)]]\n", "0.3551933990615996\n" ] } ], "source": [ "model = MLP(n_in=1, n_outs=[2, 4, 1], activation=\"relu\")\n", "x = Node(1.0)\n", "pred = model([x])\n", "pred.backward()\n", "\n", "print(model)\n", "print(pred.data)" ] }, { "cell_type": "code", "execution_count": 5, "id": "72723e9c", "metadata": { "execution": { "iopub.execute_input": "2024-09-05T18:30:34.653550Z", "iopub.status.busy": "2024-09-05T18:30:34.653257Z", "iopub.status.idle": "2024-09-05T18:30:34.785118Z", "shell.execute_reply": "2024-09-05T18:30:34.784615Z" }, "papermill": { "duration": 0.13583, "end_time": "2024-09-05T18:30:34.786706", "exception": false, "start_time": "2024-09-05T18:30:34.650876", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "4418517024\n", "\n", "data=0.000\n", "\n", "grad=0.5399\n", "\n", "\n", "\n", "4418511936+\n", "\n", "+\n", "\n", "\n", "\n", "4418517024->4418511936+\n", "\n", "\n", "\n", "\n", "\n", "4418510880\n", "\n", "data=0.488\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418509056*\n", "\n", "*\n", "\n", "\n", "\n", "4418510880->4418509056*\n", "\n", "\n", "\n", "\n", "\n", "4418510880relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418510880relu->4418510880\n", "\n", "\n", "\n", "\n", "\n", "4418508864\n", "\n", "data=0.015\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418508576+\n", "\n", "+\n", "\n", "\n", "\n", "4418508864->4418508576+\n", "\n", "\n", "\n", "\n", "\n", "4418508864*\n", "\n", "*\n", "\n", "\n", "\n", "4418508864*->4418508864\n", "\n", "\n", "\n", "\n", "\n", 
"4418517072\n", "\n", "data=0.639\n", "\n", "grad=0.5399\n", "\n", "\n", "\n", "4418512032*\n", "\n", "*\n", "\n", "\n", "\n", "4418517072->4418512032*\n", "\n", "\n", "\n", "\n", "\n", "4418510928\n", "\n", "data=0.471\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418510832+\n", "\n", "+\n", "\n", "\n", "\n", "4418510928->4418510832+\n", "\n", "\n", "\n", "\n", "\n", "4418510928*\n", "\n", "*\n", "\n", "\n", "\n", "4418510928*->4418510928\n", "\n", "\n", "\n", "\n", "\n", "4418508960\n", "\n", "data=0.286\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418508960->4418508576+\n", "\n", "\n", "\n", "\n", "\n", "4418508960+\n", "\n", "+\n", "\n", "\n", "\n", "4418508960+->4418508960\n", "\n", "\n", "\n", "\n", "\n", "4418513088\n", "\n", "data=0.030\n", "\n", "grad=0.0050\n", "\n", "\n", "\n", "4418509680*\n", "\n", "*\n", "\n", "\n", "\n", "4418513088->4418509680*\n", "\n", "\n", "\n", "\n", "\n", "4418509056\n", "\n", "data=0.247\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418509056->4418508960+\n", "\n", "\n", "\n", "\n", "\n", "4418509056*->4418509056\n", "\n", "\n", "\n", "\n", "\n", "4418513184\n", "\n", "data=0.000\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418509776+\n", "\n", "+\n", "\n", "\n", "\n", "4418513184->4418509776+\n", "\n", "\n", "\n", "\n", "\n", "4418511168\n", "\n", "data=0.181\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", "4418511408relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418511168->4418511408relu\n", "\n", "\n", "\n", "\n", "\n", "4418511168+\n", "\n", "+\n", "\n", "\n", "\n", "4418511168+->4418511168\n", "\n", "\n", "\n", "\n", "\n", "4418517312\n", "\n", "data=0.275\n", "\n", "grad=0.1398\n", "\n", "\n", "\n", "4418511456*\n", "\n", "*\n", "\n", "\n", "\n", "4418517312->4418511456*\n", "\n", "\n", "\n", "\n", "\n", "4418513232\n", "\n", "data=0.422\n", "\n", "grad=0.1271\n", "\n", "\n", "\n", "4418509872*\n", "\n", "*\n", "\n", "\n", "\n", "4418513232->4418509872*\n", "\n", "\n", "\n", "\n", "\n", "4418517360\n", "\n", "data=1.000\n", "\n", "grad=0.3552\n", "\n", "\n", "\n", "4418511744*\n", "\n", "*\n", "\n", "\n", "\n", "4418517360->4418511744*\n", "\n", "\n", "\n", "\n", "\n", "4418517360->4418512032*\n", "\n", "\n", "\n", "\n", "\n", "4418516976\n", "\n", "data=0.000\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418516976->4418510832+\n", "\n", "\n", "\n", "\n", "\n", "4418511264\n", "\n", "data=0.006\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", "4418511264->4418511168+\n", "\n", "\n", "\n", "\n", "\n", "4418511264*\n", "\n", "*\n", "\n", "\n", "\n", "4418511264*->4418511264\n", "\n", "\n", "\n", "\n", "\n", "4418509248\n", "\n", "data=0.040\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418509248->4418508960+\n", "\n", "\n", "\n", "\n", "\n", "4418509248+\n", "\n", "+\n", "\n", "\n", "\n", "4418509248+->4418509248\n", "\n", "\n", "\n", "\n", "\n", "4418515392\n", "\n", "data=0.736\n", "\n", "grad=0.3231\n", "\n", "\n", "\n", "4418515392->4418510928*\n", "\n", "\n", "\n", "\n", "\n", "4418511360\n", "\n", "data=0.176\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", "4418511360->4418511168+\n", "\n", "\n", "\n", "\n", "\n", "4418511360+\n", "\n", "+\n", "\n", "\n", "\n", "4418511360+->4418511360\n", "\n", "\n", "\n", "\n", "\n", "4418509344\n", "\n", "data=0.040\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418509344->4418509248+\n", "\n", "\n", "\n", "\n", "\n", "4418509344*\n", "\n", "*\n", "\n", "\n", "\n", "4418509344*->4418509344\n", "\n", "\n", "\n", "\n", "\n", "4418511408\n", "\n", "data=0.181\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", 
"4418511408->4418509344*\n", "\n", "\n", "\n", "\n", "\n", "4418511408relu->4418511408\n", "\n", "\n", "\n", "\n", "\n", "4418513472\n", "\n", "data=0.087\n", "\n", "grad=0.0007\n", "\n", "\n", "\n", "4418509920*\n", "\n", "*\n", "\n", "\n", "\n", "4418513472->4418509920*\n", "\n", "\n", "\n", "\n", "\n", "4418511456\n", "\n", "data=0.176\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", "4418511456->4418511360+\n", "\n", "\n", "\n", "\n", "\n", "4418511456*->4418511456\n", "\n", "\n", "\n", "\n", "\n", "4418507424\n", "\n", "data=0.223\n", "\n", "grad=0.0055\n", "\n", "\n", "\n", "4418507424->4418511264*\n", "\n", "\n", "\n", "\n", "\n", "4418513568\n", "\n", "data=0.000\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", "4418510304+\n", "\n", "+\n", "\n", "\n", "\n", "4418513568->4418510304+\n", "\n", "\n", "\n", "\n", "\n", "4418513616\n", "\n", "data=0.892\n", "\n", "grad=0.0170\n", "\n", "\n", "\n", "4418510496*\n", "\n", "*\n", "\n", "\n", "\n", "4418513616->4418510496*\n", "\n", "\n", "\n", "\n", "\n", "4418509584\n", "\n", "data=0.271\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418509824relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418509584->4418509824relu\n", "\n", "\n", "\n", "\n", "\n", "4418509584+\n", "\n", "+\n", "\n", "\n", "\n", "4418509584+->4418509584\n", "\n", "\n", "\n", "\n", "\n", "4418511648\n", "\n", "data=0.025\n", "\n", "grad=0.3990\n", "\n", "\n", "\n", "4418511696relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418511648->4418511696relu\n", "\n", "\n", "\n", "\n", "\n", "4418511648+\n", "\n", "+\n", "\n", "\n", "\n", "4418511648+->4418511648\n", "\n", "\n", "\n", "\n", "\n", "4418511696\n", "\n", "data=0.025\n", "\n", "grad=0.3990\n", "\n", "\n", "\n", "4418511696->4418511264*\n", "\n", "\n", "\n", "\n", "\n", "4418511696->4418509680*\n", "\n", "\n", "\n", "\n", "\n", "4418511696->4418509920*\n", "\n", "\n", "\n", "\n", "\n", "4418510736*\n", "\n", "*\n", "\n", "\n", "\n", "4418511696->4418510736*\n", "\n", "\n", "\n", "\n", "\n", "4418511696relu->4418511696\n", "\n", "\n", "\n", "\n", "\n", "4418509680\n", "\n", "data=0.001\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418509680->4418509584+\n", "\n", "\n", "\n", "\n", "\n", "4418509680*->4418509680\n", "\n", "\n", "\n", "\n", "\n", "4418511744\n", "\n", "data=0.025\n", "\n", "grad=0.3990\n", "\n", "\n", "\n", "4418511744->4418511648+\n", "\n", "\n", "\n", "\n", "\n", "4418511744*->4418511744\n", "\n", "\n", "\n", "\n", "\n", "4418513808\n", "\n", "data=0.000\n", "\n", "grad=0.2186\n", "\n", "\n", "\n", "4418513808->4418511360+\n", "\n", "\n", "\n", "\n", "\n", "4418509776\n", "\n", "data=0.270\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418509776->4418509584+\n", "\n", "\n", "\n", "\n", "\n", "4418509776+->4418509776\n", "\n", "\n", "\n", "\n", "\n", "4418509824\n", "\n", "data=0.271\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418508768*\n", "\n", "*\n", "\n", "\n", "\n", "4418509824->4418508768*\n", "\n", "\n", "\n", "\n", "\n", "4418509824relu->4418509824\n", "\n", "\n", "\n", "\n", "\n", "4418509872\n", "\n", "data=0.270\n", "\n", "grad=0.1988\n", "\n", "\n", "\n", "4418509872->4418509776+\n", "\n", "\n", "\n", "\n", "\n", "4418509872*->4418509872\n", "\n", "\n", "\n", "\n", "\n", "4418511936\n", "\n", "data=0.639\n", "\n", "grad=0.5399\n", "\n", "\n", "\n", "4418511984relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418511936->4418511984relu\n", "\n", "\n", "\n", "\n", "\n", "4418511936+->4418511936\n", "\n", "\n", "\n", "\n", "\n", "4418509920\n", "\n", "data=0.002\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", 
"4418510112+\n", "\n", "+\n", "\n", "\n", "\n", "4418509920->4418510112+\n", "\n", "\n", "\n", "\n", "\n", "4418509920*->4418509920\n", "\n", "\n", "\n", "\n", "\n", "4418511984\n", "\n", "data=0.639\n", "\n", "grad=0.5399\n", "\n", "\n", "\n", "4418511984->4418510928*\n", "\n", "\n", "\n", "\n", "\n", "4418511984->4418511456*\n", "\n", "\n", "\n", "\n", "\n", "4418511984->4418509872*\n", "\n", "\n", "\n", "\n", "\n", "4418511984->4418510496*\n", "\n", "\n", "\n", "\n", "\n", "4418511984relu->4418511984\n", "\n", "\n", "\n", "\n", "\n", "4418512032\n", "\n", "data=0.639\n", "\n", "grad=0.5399\n", "\n", "\n", "\n", "4418512032->4418511936+\n", "\n", "\n", "\n", "\n", "\n", "4418512032*->4418512032\n", "\n", "\n", "\n", "\n", "\n", "4418516880\n", "\n", "data=0.677\n", "\n", "grad=0.0126\n", "\n", "\n", "\n", "4418516880->4418510736*\n", "\n", "\n", "\n", "\n", "\n", "4418510112\n", "\n", "data=0.573\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", "4418510352relu\n", "\n", "relu\n", "\n", "\n", "\n", "4418510112->4418510352relu\n", "\n", "\n", "\n", "\n", "\n", "4418510112+->4418510112\n", "\n", "\n", "\n", "\n", "\n", "4418514240\n", "\n", "data=0.000\n", "\n", "grad=0.3990\n", "\n", "\n", "\n", "4418514240->4418511648+\n", "\n", "\n", "\n", "\n", "\n", "4418514336\n", "\n", "data=0.025\n", "\n", "grad=0.3990\n", "\n", "\n", "\n", "4418514336->4418511744*\n", "\n", "\n", "\n", "\n", "\n", "4418510304\n", "\n", "data=0.570\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", "4418510304->4418510112+\n", "\n", "\n", "\n", "\n", "\n", "4418510304+->4418510304\n", "\n", "\n", "\n", "\n", "\n", "4418510352\n", "\n", "data=0.573\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", "4418510352->4418508864*\n", "\n", "\n", "\n", "\n", "\n", "4418510352relu->4418510352\n", "\n", "\n", "\n", "\n", "\n", "4418512416\n", "\n", "data=0.199\n", "\n", "grad=0.2705\n", "\n", "\n", "\n", "4418512416->4418508768*\n", "\n", "\n", "\n", "\n", "\n", "4418512512\n", "\n", "data=0.027\n", "\n", "grad=0.5727\n", "\n", "\n", "\n", "4418512512->4418508864*\n", "\n", "\n", "\n", "\n", "\n", "4418510496\n", "\n", "data=0.570\n", "\n", "grad=0.0265\n", "\n", "\n", "\n", "4418510496->4418510304+\n", "\n", "\n", "\n", "\n", "\n", "4418510496*->4418510496\n", "\n", "\n", "\n", "\n", "\n", "4418512608\n", "\n", "data=0.505\n", "\n", "grad=0.4878\n", "\n", "\n", "\n", "4418512608->4418509056*\n", "\n", "\n", "\n", "\n", "\n", "4418508576\n", "\n", "data=0.301\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418508672+\n", "\n", "+\n", "\n", "\n", "\n", "4418508576->4418508672+\n", "\n", "\n", "\n", "\n", "\n", "4418508576+->4418508576\n", "\n", "\n", "\n", "\n", "\n", "4418510640\n", "\n", "data=0.488\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418510640->4418510880relu\n", "\n", "\n", "\n", "\n", "\n", "4418510640+\n", "\n", "+\n", "\n", "\n", "\n", "4418510640+->4418510640\n", "\n", "\n", "\n", "\n", "\n", "4418512704\n", "\n", "data=0.000\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418512704->4418509248+\n", "\n", "\n", "\n", "\n", "\n", "4418512752\n", "\n", "data=0.219\n", "\n", "grad=0.1814\n", "\n", "\n", "\n", "4418512752->4418509344*\n", "\n", "\n", "\n", "\n", "\n", "4418508672\n", "\n", "data=0.355\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418508672+->4418508672\n", "\n", "\n", "\n", "\n", "\n", "4418510736\n", "\n", "data=0.017\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418510736->4418510640+\n", "\n", "\n", "\n", "\n", "\n", "4418510736*->4418510736\n", "\n", "\n", "\n", "\n", "\n", "4418508768\n", "\n", 
"data=0.054\n", "\n", "grad=1.0000\n", "\n", "\n", "\n", "4418508768->4418508672+\n", "\n", "\n", "\n", "\n", "\n", "4418508768*->4418508768\n", "\n", "\n", "\n", "\n", "\n", "4418510832\n", "\n", "data=0.471\n", "\n", "grad=0.5054\n", "\n", "\n", "\n", "4418510832->4418510640+\n", "\n", "\n", "\n", "\n", "\n", "4418510832+->4418510832\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "draw_graph(pred)" ] }, { "cell_type": "code", "execution_count": null, "id": "4400d266", "metadata": { "papermill": { "duration": 0.003607, "end_time": "2024-09-05T18:30:34.794489", "exception": false, "start_time": "2024-09-05T18:30:34.790882", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 1.381813, "end_time": "2024-09-05T18:30:34.914611", "environment_variables": {}, "exception": null, "input_path": "00b-neural-net-module.ipynb", "output_path": "00b-neural-net-module.ipynb", "parameters": {}, "start_time": "2024-09-05T18:30:33.532798", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }