{ "cells": [ { "cell_type": "markdown", "id": "1d5e868e", "metadata": { "papermill": { "duration": 0.016671, "end_time": "2024-11-27T10:41:49.427398", "exception": false, "start_time": "2024-11-27T10:41:49.410727", "status": "completed" }, "tags": [] }, "source": [ "# Data augmentation" ] }, { "cell_type": "code", "execution_count": 1, "id": "74d13f43", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:49.466555Z", "iopub.status.busy": "2024-11-27T10:41:49.466085Z", "iopub.status.idle": "2024-11-27T10:41:51.447521Z", "shell.execute_reply": "2024-11-27T10:41:51.447134Z" }, "papermill": { "duration": 1.993487, "end_time": "2024-11-27T10:41:51.454917", "exception": false, "start_time": "2024-11-27T10:41:49.461430", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "markdown", "id": "b7a9fa84", "metadata": { "papermill": { "duration": 0.004024, "end_time": "2024-11-27T10:41:51.575401", "exception": false, "start_time": "2024-11-27T10:41:51.571377", "status": "completed" }, "tags": [] }, "source": [ "MNIST is too nice to be representative of real-world datasets. Below we continue with a more realistic Kaggle dataset, [Histopathologic Cancer Detection](https://www.kaggle.com/competitions/histopathologic-cancer-detection/data). 
The task is to detect metastatic cancer in patches of images from digital pathology scans.\n", "Download the dataset such that the folder structure looks as follows:" ] }, { "cell_type": "markdown", "id": "25b73ee8", "metadata": { "papermill": { "duration": 0.003766, "end_time": "2024-11-27T10:41:51.609831", "exception": false, "start_time": "2024-11-27T10:41:51.606065", "status": "completed" }, "tags": [] }, "source": [ "```\n", "./data/histopathologic-cancer-detection\n", "├── test\n", "├── train\n", "└── train_labels.csv\n", "```" ] }, { "cell_type": "markdown", "id": "00777bf6", "metadata": { "papermill": { "duration": 0.001424, "end_time": "2024-11-27T10:41:51.613252", "exception": false, "start_time": "2024-11-27T10:41:51.611828", "status": "completed" }, "tags": [] }, "source": [ "Taking a look at the first few images:" ] }, { "cell_type": "code", "execution_count": 2, "id": "485b1c3b", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:41:51.625135Z", "iopub.status.busy": "2024-11-27T10:41:51.623091Z", "iopub.status.idle": "2024-11-27T10:41:52.013777Z", "shell.execute_reply": "2024-11-27T10:41:52.013346Z" }, "papermill": { "duration": 0.399907, "end_time": "2024-11-27T10:41:52.015099", "exception": false, "start_time": "2024-11-27T10:41:51.615192", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import cv2  # OpenCV decodes the .tif image patches

# Root of the Kaggle download (folder layout shown in the markdown above).
IMG_DATASET_DIR = DATASET_DIR.joinpath("histopathologic-cancer-detection")
# Label table; its `id` and `label` columns are what the dataset code consumes.
data = pd.read_csv(IMG_DATASET_DIR.joinpath("train_labels.csv"))
"
# Side length of the square center crop. Shared by both pipelines so that
# training and inference feed the model identically sized inputs.
CROP_SIZE = 49
# Random-rotation augmentation: angles are drawn from
# (-MAX_ROTATION_DEG, +MAX_ROTATION_DEG) degrees.
MAX_ROTATION_DEG = 20

# Training pipeline: random flips and a small random rotation, then a center
# crop. Cropping *after* rotating presumably keeps the fill pixels introduced
# at the rotated image's corners out of the crop.
# NOTE(review): with the values above this holds for patches of at least
# ~63 px per side; re-check if CROP_SIZE or MAX_ROTATION_DEG change.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(MAX_ROTATION_DEG),
    transforms.CenterCrop([CROP_SIZE, CROP_SIZE]),
])

# Inference pipeline: deterministic, applies only the same center crop.
transform_infer = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop([CROP_SIZE, CROP_SIZE]),
])
"
from torch.utils.data import DataLoader, Dataset, Subset\n",
"\n",
class HistopathologicDataset(Dataset):
    """Map-style dataset of histopathology image patches and their labels.

    Args:
        data: DataFrame with an ``id`` column (image file stem, read as
            ``<id>.tif``) and a ``label`` column.
        train: If True, images are read from ``IMG_DATASET_DIR / "train"``,
            otherwise from ``IMG_DATASET_DIR / "test"``.
        transform: Optional callable applied to each decoded RGB image array.
    """

    def __init__(self, data, train=True, transform=None):
        subdir = "train" if train else "test"
        self.fnames = [str(IMG_DATASET_DIR / subdir / f"{fn}.tif") for fn in data.id]
        self.labels = data.label.tolist()
        self.transform = transform

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the patch at ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        fname = self.fnames[index]
        img = cv2.imread(fname)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the transform.
        if img is None:
            raise FileNotFoundError(f"could not read image: {fname}")
        # OpenCV decodes color images in BGR order; convert to the RGB order
        # torchvision/matplotlib expect. cvtColor also returns a contiguous
        # array, which ToTensor requires (a [..., ::-1] view would not be).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index]
"\n",
"\n",
# Fixed seed so the train/validation split is reproducible across runs.
SPLIT_SEED = 42
# Fraction of the labeled patches used for training; the rest is validation.
TRAIN_FRAC = 0.80

# Shuffle before splitting. `sort_index()` first restores the CSV row order
# (read_csv gave a default 0..n-1 index), so re-running this cell reproduces
# the same permutation instead of reshuffling an already shuffled frame.
data = data.sort_index().sample(frac=1.0, random_state=SPLIT_SEED)
split = int(TRAIN_FRAC * len(data))

# Both subsets read from the `train` folder (train=True). Only the training
# subset is augmented; validation uses the deterministic inference transform.
ds_train = HistopathologicDataset(data.iloc[:split], train=True, transform=transform_train)
ds_valid = HistopathologicDataset(data.iloc[split:], train=True, transform=transform_infer)
assert len(ds_train) + len(ds_valid) == len(data)
"