{ "cells": [ { "cell_type": "markdown", "id": "a5a94bf9", "metadata": { "papermill": { "duration": 0.005199, "end_time": "2024-11-27T10:39:55.818412", "exception": false, "start_time": "2024-11-27T10:39:55.813213", "status": "completed" }, "tags": [] }, "source": [ "# Convolution layer" ] }, { "cell_type": "markdown", "id": "4a871151", "metadata": { "papermill": { "duration": 0.017831, "end_time": "2024-11-27T10:39:55.840507", "exception": false, "start_time": "2024-11-27T10:39:55.822676", "status": "completed" }, "tags": [] }, "source": [ "Digital images have [multiple channels](https://en.wikipedia.org/wiki/Channel_(digital_image)). The **convolution layer** extends the convolution operation to handle feature maps with multiple **channels**. The output feature map similarly has channels adding a further semantic dimension to the downstream representation. For an RGB image, a convolution layer learns three 2-dimensional kernels $\\boldsymbol{\\mathsf{K}}_{lc}$ for each output channel, each of which can be thought of as a **feature extractor**. Features across input channels are blended by the kernel:\n", "\n", "$$\n", "\\begin{aligned}\n", "{\\bar{\\boldsymbol{\\mathsf X}}}_{lij}\n", "&= {\\boldsymbol{\\mathsf u}}_{l} + \\sum_{c=0}^{{c}_\\text{in}-1} ({\\boldsymbol{\\mathsf X}}_{[c,\\,:,\\, :]} \\circledast {\\boldsymbol{\\mathsf K}}_{[l,\\,{c},\\, :,\\,:]})_{ij} \\\\\n", "&= {\\boldsymbol{\\mathsf u}}_{l} + \\sum_{c=0}^{{c}_\\text{in}-1}\\sum_{x = 0}^{{k}-1} \\sum_{y=0}^{{k}-1} {\\boldsymbol{\\mathsf X}}_{c,\\, i + x,\\, j + y} \\, {\\boldsymbol{\\mathsf K}}_{lcxy} \\\\\n", "\\end{aligned}\n", "$$\n", "\n", "for $l = 0, \\ldots, {c}_\\text{out}-1$. 
The input and output tensors $\\boldsymbol{\\mathsf{X}}$ and $\\bar{\\boldsymbol{\\mathsf{X}}}$ have the same dimensionality and semantic structure, which makes sense since we want to stack convolutional layers as modules. The kernel $\\boldsymbol{\\mathsf{K}}$ has shape $({c}_\\text{out}, {c}_\\text{in}, {k}, {k})$. The resulting feature maps inherit the spatial ordering of the input along the spatial dimensions. The operation is linear in the input up to the bias term, and each output channel is computed independently of the others. \n", "\n", "**Remark.** This form is called two-dimensional convolution since the kernel scans two dimensions. Similarly, one-dimensional convolutions can be used to process sequential data. In principle, we can add as many spatial dimensions as required." ] }, { "cell_type": "markdown", "id": "31bf34a3", "metadata": { "papermill": { "duration": 0.001955, "end_time": "2024-11-27T10:39:55.843790", "exception": false, "start_time": "2024-11-27T10:39:55.841835", "status": "completed" }, "tags": [] }, "source": [ "## Input and output channels" ] }, { "cell_type": "code", "execution_count": 1, "id": "0f5c9c7a", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:39:55.847048Z", "iopub.status.busy": "2024-11-27T10:39:55.846878Z", "iopub.status.idle": "2024-11-27T10:39:58.248255Z", "shell.execute_reply": "2024-11-27T10:39:58.247887Z" }, "papermill": { "duration": 2.41057, "end_time": "2024-11-27T10:39:58.255509", "exception": false, "start_time": "2024-11-27T10:39:55.844939", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{random}\n", "\\PY{k+kn}{import} \\PY{n+nn}{warnings}\n", "\\PY{k+kn}{from} \\PY{n+nn}{pathlib} \\PY{k+kn}{import} \\PY{n}{Path}\n", "\n", "\\PY{k+kn}{import} \\PY{n+nn}{numpy} \\PY{k}{as} \\PY{n+nn}{np}\n", "\\PY{k+kn}{import} \\PY{n+nn}{pandas} \\PY{k}{as} \\PY{n+nn}{pd}\n", "\\PY{k+kn}{import} \\PY{n+nn}{matplotlib}\\PY{n+nn}{.}\\PY{n+nn}{pyplot} \\PY{k}{as} \\PY{n+nn}{plt}\n", "\\PY{k+kn}{import} \\PY{n+nn}{matplotlib}\n", "\\PY{k+kn}{from} \\PY{n+nn}{matplotlib\\PYZus{}inline} \\PY{k+kn}{import} \\PY{n}{backend\\PYZus{}inline}\n", "\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn} \\PY{k}{as} \\PY{n+nn}{nn}\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional} \\PY{k}{as} \\PY{n+nn}{F}\n", "\n", "\\PY{n}{DATASET\\PYZus{}DIR} \\PY{o}{=} \\PY{n}{Path}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{./data/}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{o}{.}\\PY{n}{resolve}\\PY{p}{(}\\PY{p}{)}\n", "\\PY{n}{DATASET\\PYZus{}DIR}\\PY{o}{.}\\PY{n}{mkdir}\\PY{p}{(}\\PY{n}{exist\\PYZus{}ok}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{)}\n", "\\PY{n}{warnings}\\PY{o}{.}\\PY{n}{simplefilter}\\PY{p}{(}\\PY{n}{action}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{ignore}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{backend\\PYZus{}inline}\\PY{o}{.}\\PY{n}{set\\PYZus{}matplotlib\\PYZus{}formats}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{svg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{matplotlib}\\PY{o}{.}\\PY{n}{rcParams}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{image.interpolation}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]} \\PY{o}{=} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{nearest}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", "\\PY{n}{RANDOM\\PYZus{}SEED} \\PY{o}{=} \\PY{l+m+mi}{0}\n", "\\PY{n}{random}\\PY{o}{.}\\PY{n}{seed}\\PY{p}{(}\\PY{n}{RANDOM\\PYZus{}SEED}\\PY{p}{)}\n", 
"\\PY{n}{np}\\PY{o}{.}\\PY{n}{random}\\PY{o}{.}\\PY{n}{seed}\\PY{p}{(}\\PY{n}{RANDOM\\PYZus{}SEED}\\PY{p}{)}\n", "\\PY{n}{torch}\\PY{o}{.}\\PY{n}{manual\\PYZus{}seed}\\PY{p}{(}\\PY{n}{RANDOM\\PYZus{}SEED}\\PY{p}{)}\n", "\\PY{n}{DEVICE} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{k}{if} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{cuda}\\PY{o}{.}\\PY{n}{is\\PYZus{}available}\\PY{p}{(}\\PY{p}{)} \\PY{k}{else} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{mps}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{k}{if} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{backends}\\PY{o}{.}\\PY{n}{mps}\\PY{o}{.}\\PY{n}{is\\PYZus{}available}\\PY{p}{(}\\PY{p}{)} \\PY{k}{else} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cpu}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import random\n", "import warnings\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from matplotlib_inline import backend_inline\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "DATASET_DIR = Path(\"./data/\").resolve()\n", "DATASET_DIR.mkdir(exist_ok=True)\n", "warnings.simplefilter(action=\"ignore\")\n", "backend_inline.set_matplotlib_formats(\"svg\")\n", "matplotlib.rcParams[\"image.interpolation\"] = \"nearest\"\n", "\n", "RANDOM_SEED = 0\n", "random.seed(RANDOM_SEED)\n", "np.random.seed(RANDOM_SEED)\n", "torch.manual_seed(RANDOM_SEED)\n", "DEVICE = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import random\n", "import warnings\n", "from pathlib import Path\n", "\n", "import numpy as np\n", 
"import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from matplotlib_inline import backend_inline\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "DATASET_DIR = Path(\"./data/\").resolve()\n", "DATASET_DIR.mkdir(exist_ok=True)\n", "warnings.simplefilter(action=\"ignore\")\n", "backend_inline.set_matplotlib_formats(\"svg\")\n", "matplotlib.rcParams[\"image.interpolation\"] = \"nearest\"\n", "\n", "RANDOM_SEED = 0\n", "random.seed(RANDOM_SEED)\n", "np.random.seed(RANDOM_SEED)\n", "torch.manual_seed(RANDOM_SEED)\n", "DEVICE = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "id": "457e8bdb", "metadata": { "papermill": { "duration": 0.001557, "end_time": "2024-11-27T10:39:58.259810", "exception": false, "start_time": "2024-11-27T10:39:58.258253", "status": "completed" }, "tags": [] }, "source": [ "Getting a sample image from [our repo](https://github.com/particle1331/ok-transformer):" ] }, { "cell_type": "code", "execution_count": 2, "id": "20e3d9c1", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:39:58.267237Z", "iopub.status.busy": "2024-11-27T10:39:58.266319Z", "iopub.status.idle": "2024-11-27T10:39:59.107917Z", "shell.execute_reply": "2024-11-27T10:39:59.105231Z" }, "papermill": { "duration": 0.849932, "end_time": "2024-11-27T10:39:59.111807", "exception": false, "start_time": "2024-11-27T10:39:58.261875", "status": "completed" }, "tags": [ "hide-output" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " % Total % Received % Xferd Average Speed Time Time Time Current\r\n", " Dload Upload Total Spent Left Speed\r\n", "\r", " 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0", "\r", " 87 881k 87 768k 0 0 1175k 0 --:--:-- --:--:-- --:--:-- 1174k\r", "100 881k 100 881k 0 0 1331k 0 --:--:-- --:--:-- --:--:-- 1331k\r\n" ] } 
], "source": [ "!curl \"https://raw.githubusercontent.com/particle1331/ok-transformer/master/docs/img/shorty.png\" --output ./data/shorty.png" ] }, { "cell_type": "markdown", "id": "2d0dc6bf", "metadata": { "papermill": { "duration": 0.033951, "end_time": "2024-11-27T10:39:59.150899", "exception": false, "start_time": "2024-11-27T10:39:59.116948", "status": "completed" }, "tags": [] }, "source": [ "Reproducing the convolution operation over input and output channels:" ] }, { "cell_type": "code", "execution_count": 3, "id": "b19cf4f1", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:39:59.177730Z", "iopub.status.busy": "2024-11-27T10:39:59.177471Z", "iopub.status.idle": "2024-11-27T10:40:00.017640Z", "shell.execute_reply": "2024-11-27T10:40:00.017110Z" }, "papermill": { "duration": 0.859441, "end_time": "2024-11-27T10:40:00.021964", "exception": false, "start_time": "2024-11-27T10:39:59.162523", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from torchvision.io import read_image\n", "import torchvision.transforms.functional as fn\n", "\n", "def convolve(X, K):\n", "    \"\"\"Perform 2D convolution over input.\"\"\"\n", "    # Kernel and input spatial sizes.\n", "    h, w = K.shape\n", "    H0, W0 = X.shape\n", "    # Output size with no padding and stride 1.\n", "    H1 = H0 - h + 1\n", "    W1 = W0 - w + 1\n", "\n", "    S = np.zeros(shape=(H1, W1))\n", "    # Slide the kernel over the input; each output entry is the\n", "    # elementwise product of the window and the kernel, summed.\n", "    for i in range(H1):\n", "        for j in range(W1):\n", "            S[i, j] = (X[i:i+h, j:j+w] * K).sum()\n", "\n", "    return torch.tensor(S)" ] }, { "cell_type": "markdown", "id": "c5526695", "metadata": { "papermill": { "duration": 0.001318, "end_time": "2024-11-27T10:40:00.026292", "exception": false, "start_time": "2024-11-27T10:40:00.024974", "status": "completed" }, "tags": [] }, "source": [ "Decomposing the feature maps:" ] }, { "cell_type": "code", "execution_count": 4, "id": "2dbed10b", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:00.030173Z", "iopub.status.busy": "2024-11-27T10:40:00.029834Z", "iopub.status.idle": "2024-11-27T10:40:02.730580Z", "shell.execute_reply": 
"2024-11-27T10:40:02.729978Z" }, "papermill": { "duration": 2.710516, "end_time": "2024-11-27T10:40:02.738188", "exception": false, "start_time": "2024-11-27T10:40:00.027672", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-11-27T18:40:02.295914\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@torch.no_grad()\n", "def conv_components(X, K, u):\n", " cmaps = [\"Reds\", \"Greens\", \"Blues\"]\n", " cmaps_out = [\"spring\", \"summer\", \"autumn\", \"winter\"]\n", " c_in = X.shape[1]\n", " c_out = K.shape[0]\n", " \n", " fig, ax = plt.subplots(c_in + 1, c_out + 1, figsize=(10, 10))\n", "\n", " # Input image\n", " ax[0, 0].imshow(X[0].permute(1, 2, 0))\n", " for c in range(c_in):\n", " ax[c+1, 0].set_title(f\"X(c={c})\", size=10)\n", " ax[c+1, 0].imshow(X[0, c, :, :], cmap=cmaps[c])\n", "\n", " # Iterate over kernel filters\n", " out_components = {}\n", " for k in range(c_out):\n", " for c in range(c_in):\n", " out_components[(c, k)] = convolve(X[0, c, :, :], K[k, c, :, :])\n", " ax[c+1, k+1].imshow(out_components[(c, k)].numpy()) \n", " ax[c+1, k+1].set_title(f\"X(c={c}) \u229b K(c={c}, k={k})\", size=10)\n", "\n", " # Sum convolutions over input channels, then add bias\n", " out_maps = []\n", " for k in range(c_out):\n", " out_maps.append(sum([out_components[(c, k)] for c in range(c_in)]) + u[k])\n", " ax[0, k+1].imshow(out_maps[k].numpy(), cmaps_out[k])\n", " ax[0, k+1].set_title(r\"$\\bar{\\mathrm{X}}$\" + f\"(k={k})\", size=10)\n", "\n", " fig.tight_layout()\n", " return out_maps\n", "\n", "\n", "cat = DATASET_DIR / \"shorty.png\"\n", "X = read_image(str(cat)).unsqueeze(0)[:, :3, :, :]\n", "X = fn.resize(X, size=(128, 128)) / 255.\n", "conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=5)\n", "K, u = conv.weight, conv.bias\n", "components = conv_components(X, K, u);" ] }, { "cell_type": "markdown", "id": "3f1394db", "metadata": { "papermill": { "duration": 0.010359, "end_time": "2024-11-27T10:40:02.756593", "exception": false, "start_time": "2024-11-27T10:40:02.746234", "status": "completed" }, "tags": [] }, "source": [ "**Figure.** Each kernel in entries `i,j > 0` combines column-wise with the inputs to compute `X(c=i) \u229b K(c=i, k=j)`. 
The sum of these terms over `c` form the output map `X\u0305(k=j)` above. \n", "This looks like computation in a fully-connected layer, but with convolutions between matrices instead of products between scalars. CNNs perform combinatorial mixing of hierarchical spatial features with depth.\n", "\n", "Checking if the output of `conv_components` is consistent with `Conv2d` in PyTorch:" ] }, { "cell_type": "code", "execution_count": 5, "id": "c9c15c84", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:02.769458Z", "iopub.status.busy": "2024-11-27T10:40:02.769054Z", "iopub.status.idle": "2024-11-27T10:40:02.840584Z", "shell.execute_reply": "2024-11-27T10:40:02.840083Z" }, "papermill": { "duration": 0.078991, "end_time": "2024-11-27T10:40:02.841779", "exception": false, "start_time": "2024-11-27T10:40:02.762788", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape: torch.Size([1, 3, 128, 128])\n", "Output shape: torch.Size([1, 4, 124, 124])\n", "Kernel shape: torch.Size([4, 3, 5, 5])\n", "Bias shape: torch.Size([4])\n", "\n", "MAE (w/ pytorch) = 2.424037095549667e-08\n" ] } ], "source": [ "S = torch.stack(components).unsqueeze(0)\n", "P = conv(X)\n", "cmaps_out = [\"spring\", \"summer\", \"autumn\", \"winter\"]\n", "\n", "print(\"Input shape: \", X.shape) # (B, c0, H0, W0)\n", "print(\"Output shape:\", S.shape) # (B, c1, H1, W1)\n", "print(\"Kernel shape:\", K.shape) # (c1, c0, h, w)\n", "print(\"Bias shape: \", u.shape) # (c1,)\n", "print()\n", "print(\"MAE (w/ pytorch) =\", (S - P).abs().mean().item())" ] }, { "cell_type": "markdown", "id": "cb74c0e4", "metadata": { "papermill": { "duration": 0.00989, "end_time": "2024-11-27T10:40:02.858319", "exception": false, "start_time": "2024-11-27T10:40:02.848429", "status": "completed" }, "tags": [] }, "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 8.697032, "end_time": "2024-11-27T10:40:03.688077", "environment_variables": {}, "exception": null, "input_path": "03aa-conv-layer.ipynb", "output_path": "03aa-conv-layer.ipynb", "parameters": {}, "start_time": "2024-11-27T10:39:54.991045", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }