{ "cells": [ { "cell_type": "markdown", "id": "1d5e868e", "metadata": { "papermill": { "duration": 0.016671, "end_time": "2024-11-27T10:41:49.427398", "exception": false, "start_time": "2024-11-27T10:41:49.410727", "status": "completed" }, "tags": [] }, "source": [ "# Data augmentation" ] }, { "cell_type": "code", "execution_count": 1, "id": "74d13f43", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:49.466555Z", "iopub.status.busy": "2024-11-27T10:41:49.466085Z", "iopub.status.idle": "2024-11-27T10:41:51.447521Z", "shell.execute_reply": "2024-11-27T10:41:51.447134Z" }, "papermill": { "duration": 1.993487, "end_time": "2024-11-27T10:41:51.454917", "exception": false, "start_time": "2024-11-27T10:41:49.461430", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "markdown", "id": "b7a9fa84", "metadata": { "papermill": { "duration": 0.004024, "end_time": "2024-11-27T10:41:51.575401", "exception": false, "start_time": "2024-11-27T10:41:51.571377", "status": "completed" }, "tags": [] }, "source": [ "MNIST is too nice to be representative of real-world datasets. Below we continue with a more realistic Kaggle dataset, [Histopathologic Cancer Detection](https://www.kaggle.com/competitions/histopathologic-cancer-detection/data). 
The task is to identify metastatic cancer in small image patches taken from larger digital pathology scans.\n", "Download the dataset such that the folder structure looks as follows:" ] }, { "cell_type": "markdown", "id": "25b73ee8", "metadata": { "papermill": { "duration": 0.003766, "end_time": "2024-11-27T10:41:51.609831", "exception": false, "start_time": "2024-11-27T10:41:51.606065", "status": "completed" }, "tags": [] }, "source": [ "```\n", "./data/histopathologic-cancer-detection\n", "├── test\n", "├── train\n", "└── train_labels.csv\n", "```" ] }, { "cell_type": "markdown", "id": "00777bf6", "metadata": { "papermill": { "duration": 0.001424, "end_time": "2024-11-27T10:41:51.613252", "exception": false, "start_time": "2024-11-27T10:41:51.611828", "status": "completed" }, "tags": [] }, "source": [ "Taking a look at the first few images:" ] }, { "cell_type": "code", "execution_count": 2, "id": "485b1c3b", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:51.625135Z", "iopub.status.busy": "2024-11-27T10:41:51.623091Z", "iopub.status.idle": "2024-11-27T10:41:52.013777Z", "shell.execute_reply": "2024-11-27T10:41:52.013346Z" }, "papermill": { "duration": 0.399907, "end_time": "2024-11-27T10:41:52.015099", "exception": false, "start_time": "2024-11-27T10:41:51.615192", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import cv2\n",
       "\n",
       "IMG_DATASET_DIR = DATASET_DIR / &quot;histopathologic-cancer-detection&quot;\n",
       "data = pd.read_csv(IMG_DATASET_DIR / &quot;train_labels.csv&quot;)\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{cv2}\n", "\n", "\\PY{n}{IMG\\PYZus{}DATASET\\PYZus{}DIR} \\PY{o}{=} \\PY{n}{DATASET\\PYZus{}DIR} \\PY{o}{/} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{histopathologic\\PYZhy{}cancer\\PYZhy{}detection}\\PY{l+s+s2}{\\PYZdq{}}\n", "\\PY{n}{data} \\PY{o}{=} \\PY{n}{pd}\\PY{o}{.}\\PY{n}{read\\PYZus{}csv}\\PY{p}{(}\\PY{n}{IMG\\PYZus{}DATASET\\PYZus{}DIR} \\PY{o}{/} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{train\\PYZus{}labels.csv}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import cv2\n", "\n", "IMG_DATASET_DIR = DATASET_DIR / \"histopathologic-cancer-detection\"\n", "data = pd.read_csv(IMG_DATASET_DIR / \"train_labels.csv\")" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import cv2\n", "\n", "IMG_DATASET_DIR = DATASET_DIR / \"histopathologic-cancer-detection\"\n", "data = pd.read_csv(IMG_DATASET_DIR / \"train_labels.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "7805b2ed", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:52.019918Z", "iopub.status.busy": "2024-11-27T10:41:52.019738Z", "iopub.status.idle": "2024-11-27T10:41:52.577676Z", "shell.execute_reply": "2024-11-27T10:41:52.574672Z" }, "papermill": { "duration": 0.563358, "end_time": "2024-11-27T10:41:52.580735", "exception": false, "start_time": "2024-11-27T10:41:52.017377", "status": "completed" }, "tags": [ "hide-input" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-11-27T18:41:52.428360\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(3, 5, figsize=(6, 4.5))\n", "for k in range(15):\n", " i, j = divmod(k, 5)\n", " fname = str(IMG_DATASET_DIR / \"train\" / f\"{data.id[k]}.tif\")\n", " # OpenCV loads images in BGR order; convert to RGB for matplotlib\n", " ax[i, j].imshow(cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB))\n", " ax[i, j].set_title(data.label[k], size=10)\n", " ax[i, j].axis(\"off\")\n", "fig.tight_layout()" ] }, { "cell_type": "markdown", "id": "d08ae705", "metadata": { "papermill": { "duration": 0.005613, "end_time": "2024-11-27T10:41:52.599400", "exception": false, "start_time": "2024-11-27T10:41:52.593787", "status": "completed" }, "tags": [] }, "source": [ "A positive label indicates that the center 32 × 32 region of a patch contains at least one pixel of tumor tissue. Tumor tissue in the outer region of the patch does not influence the label.\n", "This outer region is provided so that fully convolutional models without zero-padding behave consistently when applied to a whole-slide image." ] }, { "cell_type": "markdown", "id": "96ebba59", "metadata": { "papermill": { "duration": 0.007525, "end_time": "2024-11-27T10:41:52.612364", "exception": false, "start_time": "2024-11-27T10:41:52.604839", "status": "completed" }, "tags": [] }, "source": [ "## Stochastic transforms\n", "\n", "Data augmentation exposes the model to transformed or perturbed versions of the original images. More precisely, each data point\n", "$(\\boldsymbol{\\mathsf{x}}, y)$ in a mini-batch is replaced by $(T(\\boldsymbol{\\mathsf{x}}), y)$ during training,\n", "where $T$ is a stochastic, label-preserving transformation that is resampled every time the data point is drawn.\n",
"At inference, the random components of $T$ are dropped, and an input $\\boldsymbol{\\mathsf{x}}$ is passed through a fixed, deterministic transformation instead. This serves as a practical stand-in for $\\mathbb{E}[T(\\boldsymbol{\\mathsf{x}})]$ rather than the expectation itself." ] }, { "cell_type": "code", "execution_count": 4, "id": "4b0fc95a", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:52.626763Z", "iopub.status.busy": "2024-11-27T10:41:52.626344Z", "iopub.status.idle": "2024-11-27T10:41:52.632660Z", "shell.execute_reply": "2024-11-27T10:41:52.631982Z" }, "papermill": { "duration": 0.015748, "end_time": "2024-11-27T10:41:52.633963", "exception": false, "start_time": "2024-11-27T10:41:52.618215", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
transform_train = transforms.Compose([\n",
       "    transforms.ToTensor(),\n",
       "    transforms.RandomHorizontalFlip(),\n",
       "    transforms.RandomVerticalFlip(),\n",
       "    transforms.RandomRotation(20),\n",
       "    transforms.CenterCrop([49, 49]),\n",
       "])\n",
       "\n",
       "transform_infer = transforms.Compose([\n",
       "    transforms.ToTensor(),\n",
       "    transforms.CenterCrop([49, 49]),\n",
       "])\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n}{transform\\PYZus{}train} \\PY{o}{=} \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{Compose}\\PY{p}{(}\\PY{p}{[}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{ToTensor}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{RandomHorizontalFlip}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{RandomVerticalFlip}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{RandomRotation}\\PY{p}{(}\\PY{l+m+mi}{20}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{CenterCrop}\\PY{p}{(}\\PY{p}{[}\\PY{l+m+mi}{49}\\PY{p}{,} \\PY{l+m+mi}{49}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n", "\\PY{p}{]}\\PY{p}{)}\n", "\n", "\\PY{n}{transform\\PYZus{}infer} \\PY{o}{=} \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{Compose}\\PY{p}{(}\\PY{p}{[}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{ToTensor}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{n}{transforms}\\PY{o}{.}\\PY{n}{CenterCrop}\\PY{p}{(}\\PY{p}{[}\\PY{l+m+mi}{49}\\PY{p}{,} \\PY{l+m+mi}{49}\\PY{p}{]}\\PY{p}{)}\\PY{p}{,}\n", "\\PY{p}{]}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "transform_train = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomVerticalFlip(),\n", " transforms.RandomRotation(20),\n", " transforms.CenterCrop([49, 49]),\n", "])\n", "\n", "transform_infer = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.CenterCrop([49, 49]),\n", "])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "transform_train = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.RandomVerticalFlip(),\n", " transforms.RandomRotation(20),\n", " transforms.CenterCrop([49, 49]),\n", "])\n", "\n", "transform_infer = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.CenterCrop([49, 49]),\n", "])" ] }, { "cell_type": 
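"markdown", "id": "7e2c41a9", "metadata": {}, "source": [ "A quick geometric check of `transform_train` (a sketch, assuming the 96 × 96 patches of the Kaggle release): `RandomRotation` keeps the image size and, by default, fills the pixels that rotate in from outside the patch with zeros. The pixels of a centered crop of side $c$ are sampled from source positions that form the same square rotated by $\\theta$ about the center, and their largest horizontal or vertical distance from the center is $\\frac{c}{2}(|\\cos\\theta| + |\\sin\\theta|)$. The crop is therefore free of fill whenever this is at most $\\frac{s}{2}$ for a patch of side $s$. Because $|\\cos\\theta| + |\\sin\\theta| \\le \\sqrt{2}$, any $c \\le s/\\sqrt{2} \\approx 67.9$ works for every angle, which covers the 49 × 49 crop. The helper names below are hypothetical and only illustrate the arithmetic.

```python
import math

PATCH = 96  # assumed side length of a Kaggle patch, in pixels
CROP = 49   # side length kept by transforms.CenterCrop([49, 49])


def max_half_extent(crop, degrees):
    # Largest |x| or |y| (measured from the patch center) reached by the corners
    # of a centered crop x crop square after rotating it by `degrees`.
    t = math.radians(degrees)
    return crop / 2 * (abs(math.cos(t)) + abs(math.sin(t)))


def crop_avoids_fill(crop, patch, degrees):
    # No rotation fill ends up in the crop if its rotated corners stay in the patch.
    return max_half_extent(crop, degrees) <= patch / 2


# The worst case over all angles is 45 degrees, where |cos| + |sin| = sqrt(2).
worst = max(max_half_extent(CROP, d) for d in range(0, 91))
print(round(worst, 2))                    # 34.65, below PATCH / 2 = 48
print(crop_avoids_fill(CROP, PATCH, 20))  # True for the ±20° rotation
print(crop_avoids_fill(CROP, PATCH, 45))  # True: 49 is below 96 / sqrt(2)
```

This ignores pixel discretization at the crop boundary, so it is an approximate argument rather than a pixel-exact guarantee; the margin of roughly 13 pixels leaves ample room for such effects." ] }, { "cell_type": 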
"markdown", "id": "e4a46fe6", "metadata": { "papermill": { "duration": 0.004062, "end_time": "2024-11-27T10:41:52.643826", "exception": false, "start_time": "2024-11-27T10:41:52.639764", "status": "completed" }, "tags": [] }, "source": [ "Since, by design, only the central 32 × 32 region determines the label, we center crop each 96 × 96 patch to 49 × 49, keeping the labeled region plus a small margin of context. Furthermore, flipping a tissue sample horizontally or vertically, or rotating it (by up to $\\pm 20^{\\circ}$, as set above), does not change whether tumor tissue is present, so these transforms preserve the label." ] }, { "cell_type": "code", "execution_count": 5, "id": "0b95d89c", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:52.653497Z", "iopub.status.busy": "2024-11-27T10:41:52.653297Z", "iopub.status.idle": "2024-11-27T10:41:53.502889Z", "shell.execute_reply": "2024-11-27T10:41:53.502390Z" }, "papermill": { "duration": 0.856296, "end_time": "2024-11-27T10:41:53.504790", "exception": false, "start_time": "2024-11-27T10:41:52.648494", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
from torch.utils.data import DataLoader, Dataset, Subset\n",
       "\n",
       "class HistopathologicDataset(Dataset):\n",
       "    def __init__(self, data, train=True, transform=None):\n",
       "        split = &quot;train&quot; if train else &quot;test&quot;\n",
       "        self.fnames = [str(IMG_DATASET_DIR / split / f&quot;{fn}.tif&quot;) for fn in data.id]\n",
       "        self.labels = data.label.tolist()\n",
       "        self.transform = transform\n",
       "    \n",
       "    def __len__(self):\n",
       "        return len(self.fnames)\n",
       "    \n",
       "    def __getitem__(self, index):\n",
       "        img = cv2.imread(self.fnames[index])\n",
       "        if self.transform:\n",
       "            img = self.transform(img)\n",
       "        \n",
       "        return img, self.labels[index]\n",
       "\n",
       "\n",
       "data = data.sample(frac=1.0)\n",
       "split = int(0.80 * len(data))\n",
       "ds_train = HistopathologicDataset(data[:split], train=True, transform=transform_train)\n",
       "ds_valid = HistopathologicDataset(data[split:], train=True, transform=transform_infer)\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{from} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{utils}\\PY{n+nn}{.}\\PY{n+nn}{data} \\PY{k+kn}{import} \\PY{n}{DataLoader}\\PY{p}{,} \\PY{n}{Dataset}\\PY{p}{,} \\PY{n}{Subset}\n", "\n", "\\PY{k}{class} \\PY{n+nc}{HistopathologicDataset}\\PY{p}{(}\\PY{n}{Dataset}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{data}\\PY{p}{,} \\PY{n}{train}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{,} \\PY{n}{transform}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{split} \\PY{o}{=} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{train}\\PY{l+s+s2}{\\PYZdq{}} \\PY{k}{if} \\PY{n}{train} \\PY{k}{else} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{test}\\PY{l+s+s2}{\\PYZdq{}}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{fnames} \\PY{o}{=} \\PY{p}{[}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{IMG\\PYZus{}DATASET\\PYZus{}DIR} \\PY{o}{/} \\PY{n}{split} \\PY{o}{/} \\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+si}{\\PYZob{}}\\PY{n}{fn}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{.tif}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{k}{for} \\PY{n}{fn} \\PY{o+ow}{in} \\PY{n}{data}\\PY{o}{.}\\PY{n}{id}\\PY{p}{]}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{labels} \\PY{o}{=} \\PY{n}{data}\\PY{o}{.}\\PY{n}{label}\\PY{o}{.}\\PY{n}{tolist}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{transform} \\PY{o}{=} \\PY{n}{transform}\n", " \n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}len\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{fnames}\\PY{p}{)}\n", " \n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}getitem\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{index}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{img} \\PY{o}{=} 
\\PY{n}{cv2}\\PY{o}{.}\\PY{n}{imread}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{fnames}\\PY{p}{[}\\PY{n}{index}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{transform}\\PY{p}{:}\n", " \\PY{n}{img} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{transform}\\PY{p}{(}\\PY{n}{img}\\PY{p}{)}\n", " \n", " \\PY{k}{return} \\PY{n}{img}\\PY{p}{,} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{labels}\\PY{p}{[}\\PY{n}{index}\\PY{p}{]}\n", "\n", "\n", "\\PY{n}{data} \\PY{o}{=} \\PY{n}{data}\\PY{o}{.}\\PY{n}{sample}\\PY{p}{(}\\PY{n}{frac}\\PY{o}{=}\\PY{l+m+mf}{1.0}\\PY{p}{)}\n", "\\PY{n}{split} \\PY{o}{=} \\PY{n+nb}{int}\\PY{p}{(}\\PY{l+m+mf}{0.80} \\PY{o}{*} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{data}\\PY{p}{)}\\PY{p}{)}\n", "\\PY{n}{ds\\PYZus{}train} \\PY{o}{=} \\PY{n}{HistopathologicDataset}\\PY{p}{(}\\PY{n}{data}\\PY{p}{[}\\PY{p}{:}\\PY{n}{split}\\PY{p}{]}\\PY{p}{,} \\PY{n}{train}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{,} \\PY{n}{transform}\\PY{o}{=}\\PY{n}{transform\\PYZus{}train}\\PY{p}{)}\n", "\\PY{n}{ds\\PYZus{}valid} \\PY{o}{=} \\PY{n}{HistopathologicDataset}\\PY{p}{(}\\PY{n}{data}\\PY{p}{[}\\PY{n}{split}\\PY{p}{:}\\PY{p}{]}\\PY{p}{,} \\PY{n}{train}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{,} \\PY{n}{transform}\\PY{o}{=}\\PY{n}{transform\\PYZus{}infer}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "from torch.utils.data import DataLoader, Dataset, Subset\n", "\n", "class HistopathologicDataset(Dataset):\n", " def __init__(self, data, train=True, transform=None):\n", " split = \"train\" if train else \"test\"\n", " self.fnames = [str(IMG_DATASET_DIR / split / f\"{fn}.tif\") for fn in data.id]\n", " self.labels = data.label.tolist()\n", " self.transform = transform\n", " \n", " def __len__(self):\n", " return len(self.fnames)\n", " \n", " def __getitem__(self, index):\n", " img = cv2.imread(self.fnames[index])\n", " if self.transform:\n", " img = self.transform(img)\n", " \n", " return img, self.labels[index]\n", "\n", "\n", "data = data.sample(frac=1.0)\n", 
"split = int(0.80 * len(data))\n", "ds_train = HistopathologicDataset(data[:split], train=True, transform=transform_train)\n", "ds_valid = HistopathologicDataset(data[split:], train=True, transform=transform_infer)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "from torch.utils.data import DataLoader, Dataset, Subset\n", "\n", "class HistopathologicDataset(Dataset):\n", " def __init__(self, data, train=True, transform=None):\n", " split = \"train\" if train else \"test\"\n", " self.fnames = [str(IMG_DATASET_DIR / split / f\"{fn}.tif\") for fn in data.id]\n", " self.labels = data.label.tolist()\n", " self.transform = transform\n", " \n", " def __len__(self):\n", " return len(self.fnames)\n", " \n", " def __getitem__(self, index):\n", " img = cv2.imread(self.fnames[index])\n", " if self.transform:\n", " img = self.transform(img)\n", " \n", " return img, self.labels[index]\n", "\n", "\n", "data = data.sample(frac=1.0)\n", "split = int(0.80 * len(data))\n", "ds_train = HistopathologicDataset(data[:split], train=True, transform=transform_train)\n", "ds_valid = HistopathologicDataset(data[split:], train=True, transform=transform_infer)" ] }, { "cell_type": "markdown", "id": "f0f29ffe", "metadata": { "papermill": { "duration": 0.005598, "end_time": "2024-11-27T10:41:53.514887", "exception": false, "start_time": "2024-11-27T10:41:53.509289", "status": "completed" }, "tags": [] }, "source": [ "The classes are somewhat imbalanced, though not severely. Fraction of positive labels in the training and validation splits:" ] }, { "cell_type": "code", "execution_count": 6, "id": "4ec00bfb", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:53.526626Z", "iopub.status.busy": "2024-11-27T10:41:53.526131Z", "iopub.status.idle": "2024-11-27T10:41:53.541586Z", "shell.execute_reply": "2024-11-27T10:41:53.541275Z" }, "papermill": { "duration": 0.023177, "end_time": "2024-11-27T10:41:53.542604", "exception": false, "start_time": "2024-11-27T10:41:53.519427", "status": "completed" }, "tags": [] }, 
"outputs": [ { "data": { "text/plain": [ "(0.4050562436086808, 0.4049312578116123)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# percentage of positive class\n", "data[:split].label.mean(), data[split:].label.mean()" ] }, { "cell_type": "markdown", "id": "64e88d3b", "metadata": { "papermill": { "duration": 0.00456, "end_time": "2024-11-27T10:41:53.551452", "exception": false, "start_time": "2024-11-27T10:41:53.546892", "status": "completed" }, "tags": [] }, "source": [ "Simulating images across epochs:" ] }, { "cell_type": "code", "execution_count": 7, "id": "07f84c52", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:53.561169Z", "iopub.status.busy": "2024-11-27T10:41:53.560998Z", "iopub.status.idle": "2024-11-27T10:41:53.564683Z", "shell.execute_reply": "2024-11-27T10:41:53.564420Z" }, "papermill": { "duration": 0.009819, "end_time": "2024-11-27T10:41:53.565591", "exception": false, "start_time": "2024-11-27T10:41:53.555772", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "simul_train = DataLoader(Subset(ds_train, torch.arange(3)), batch_size=3, shuffle=True)\n", "simul_valid = DataLoader(Subset(ds_valid, torch.arange(1)), batch_size=1, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 8, "id": "0666a796", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:53.574186Z", "iopub.status.busy": "2024-11-27T10:41:53.574053Z", "iopub.status.idle": "2024-11-27T10:41:53.804198Z", "shell.execute_reply": "2024-11-27T10:41:53.802914Z" }, "papermill": { "duration": 0.235869, "end_time": "2024-11-27T10:41:53.805480", "exception": false, "start_time": "2024-11-27T10:41:53.569611", "status": "completed" }, "tags": [ "hide-input" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-11-27T18:41:53.747516\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " 
\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(3, 4)\n", "for e in range(3):\n", " img_train, tgt_train = next(iter(simul_train))\n", " for i in range(3):\n", " if i == 0:\n", " ax[e, i].set_ylabel(f\"Epoch: {e}\")\n", " \n", " img, tgt = img_train[i], tgt_train[i]\n", " # images were loaded by OpenCV in BGR order; flip the channel axis to get RGB\n", " ax[e, i].imshow(img.flip(0).permute(1, 2, 0).detach())\n", " ax[e, i].set_xlabel(tgt.item())\n", " ax[e, i].set_xticks([])\n", " ax[e, i].set_yticks([])\n", " ax[0, i].set_title(f\"instance: {i}\")\n", "\n", " img_valid, tgt_valid = next(iter(simul_valid))\n", " ax[e, 3].set_xlabel(tgt_valid[0].item())\n", " ax[e, 3].imshow(img_valid[0].flip(0).permute(1, 2, 0).detach())\n", " ax[e, 3].set_xticks([])\n", " ax[e, 3].set_yticks([])\n", "\n", "ax[0, 3].set_title(\"valid\")\n", "fig.tight_layout()" ] }, { "cell_type": "markdown", "id": "a128334a", "metadata": { "papermill": { "duration": 0.005, "end_time": "2024-11-27T10:41:53.816178", "exception": false, "start_time": "2024-11-27T10:41:53.811178", "status": "completed" }, "tags": [] }, "source": [ "**Figure.** Training instances are transformed stochastically, so they look different at each epoch, whereas the validation instance goes through the same fixed transformation (center crop only) every time.\n", "Note that the labels are unaffected, both semantically (the transforms do not change whether tumor tissue is present) and in the code (only the image is passed through `transform`)." 
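] }, { "cell_type": "markdown", "id": "b41d9c07", "metadata": {}, "source": [ "The difference between averaging inputs and averaging predictions can be made concrete with a small NumPy sketch. It assumes a flip-only $T$ (the two random flips of `transform_train`, each applied with the torchvision default probability 0.5, ignoring rotation and cropping), so $T$ has four equally likely outcomes; the functions and the toy `predict` below are hypothetical stand-ins, not part of this notebook. Averaging the flipped copies themselves yields a symmetrized image rather than the original patch, while test-time augmentation (TTA), a common alternative, averages the model outputs over the same four variants.

```python
import numpy as np


def flip_variants(x):
    # The four equally likely outcomes of a random horizontal flip combined
    # with a random vertical flip, for an image of shape (H, W) or (H, W, C).
    return [x, x[:, ::-1], x[::-1, :], x[::-1, ::-1]]


def expected_input(x):
    # E[T(x)] under the flip-only transform: the mean of the four variants.
    return np.mean(flip_variants(x), axis=0)


def tta_predict(predict, x):
    # Test-time augmentation: average predict(T(x)) over the same variants.
    return np.mean([predict(v) for v in flip_variants(x)], axis=0)


def predict(v):
    # Toy nonlinear scorer (hypothetical): sigmoid of the top-left pixel.
    return 1.0 / (1.0 + np.exp(-(v[0, 0] - 5.0)))


x = np.arange(16, dtype=float).reshape(4, 4)
print(expected_input(x))  # every entry is 7.5: the ramp structure is averaged away
print(tta_predict(predict, x))
print(predict(expected_input(x)))
```

For a nonlinear model, the prediction on the averaged input generally differs from the averaged prediction, which is one reason the random parts of $T$ are usually just dropped at inference (or replaced by TTA) instead of feeding $\\mathbb{E}[T(\\boldsymbol{\\mathsf{x}})]$ to the model."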
] }, { "cell_type": "code", "execution_count": null, "id": "b603f534", "metadata": { "papermill": { "duration": 0.009863, "end_time": "2024-11-27T10:41:53.833131", "exception": false, "start_time": "2024-11-27T10:41:53.823268", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 5.909805, "end_time": "2024-11-27T10:41:54.465047", "environment_variables": {}, "exception": null, "input_path": "03c-data-augmentation.ipynb", "output_path": "03c-data-augmentation.ipynb", "parameters": {}, "start_time": "2024-11-27T10:41:48.555242", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }