{ "cells": [ { "cell_type": "markdown", "id": "8402960d", "metadata": { "papermill": { "duration": 0.005615, "end_time": "2024-11-27T10:40:32.191101", "exception": false, "start_time": "2024-11-27T10:40:32.185486", "status": "completed" }, "tags": [] }, "source": [ "(dl/03-cnn/03bb-trainer)=\n", "# Trainer engine" ] }, { "cell_type": "markdown", "id": "77bcf0d7", "metadata": { "papermill": { "duration": 0.003115, "end_time": "2024-11-27T10:40:32.198379", "exception": false, "start_time": "2024-11-27T10:40:32.195264", "status": "completed" }, "tags": [] }, "source": [ "To separate concerns during model training, we define a **trainer engine**. For example, this defines an `eval_context` to automatically set the model to eval mode at entry, and back to the default train mode at exit. This is useful for layers such as BN and Dropout which have different behaviors at train and test times. LR schedulers and callbacks are also implemented. Currently, these are called at the end of each training step (it is easy to extend this class to implement epoch end callbacks)." ] }, { "cell_type": "code", "execution_count": 1, "id": "adfa8b33", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:32.207240Z", "iopub.status.busy": "2024-11-27T10:40:32.206844Z", "iopub.status.idle": "2024-11-27T10:40:34.803454Z", "shell.execute_reply": "2024-11-27T10:40:34.803026Z" }, "papermill": { "duration": 2.604281, "end_time": "2024-11-27T10:40:34.805836", "exception": false, "start_time": "2024-11-27T10:40:32.201555", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "code", "execution_count": 2, "id": "42787bd2", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:34.809985Z", "iopub.status.busy": "2024-11-27T10:40:34.809777Z", "iopub.status.idle": "2024-11-27T10:40:34.898587Z", "shell.execute_reply": "2024-11-27T10:40:34.898029Z" }, "papermill": { "duration": 0.092665, "end_time": "2024-11-27T10:40:34.900531", "exception": false, "start_time": "2024-11-27T10:40:34.807866", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
from contextlib import contextmanager\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from tqdm.notebook import tqdm\n",
"\n",
"# DEVICE (and np, torch) are also provided by `from chapter import *` above\n",
"\n",
"\n",
"@contextmanager\n",
"def eval_context(model):\n",
" """Temporarily set to eval mode inside context."""\n",
" is_train = model.training\n",
" model.eval()\n",
" try:\n",
" yield\n",
" finally:\n",
" model.train(is_train)\n",
"\n",
"\n",
"class Trainer:\n",
" def __init__(self,\n",
" model, optim, loss_fn, scheduler=None, callbacks=[],\n",
" device=DEVICE, verbose=True\n",
" ):\n",
" self.model = model.to(device)\n",
" self.optim = optim\n",
" self.device = device\n",
" self.loss_fn = loss_fn\n",
" self.train_log = {"loss": [], "accs": [], "loss_avg": [], "accs_avg": []}\n",
" self.valid_log = {"loss": [], "accs": []}\n",
" self.verbose = verbose\n",
" self.scheduler = scheduler\n",
" self.callbacks = callbacks\n",
" \n",
" def __call__(self, x):\n",
" return self.model(x.to(self.device))\n",
"\n",
" def forward(self, batch):\n",
" x, y = batch\n",
" x = x.to(self.device)\n",
" y = y.to(self.device)\n",
" return self.model(x), y\n",
"\n",
" def train_step(self, batch):\n",
" preds, y = self.forward(batch)\n",
" accs = (preds.argmax(dim=1) == y).float().mean()\n",
" loss = self.loss_fn(preds, y)\n",
" loss.backward()\n",
" self.optim.step()\n",
" self.optim.zero_grad()\n",
" return {"loss": loss, "accs": accs}\n",
"\n",
" @torch.inference_mode()\n",
" def valid_step(self, batch):\n",
" preds, y = self.forward(batch)\n",
" accs = (preds.argmax(dim=1) == y).float().sum()\n",
" loss = self.loss_fn(preds, y, reduction="sum")\n",
" return {"loss": loss, "accs": accs}\n",
" \n",
" def run(self, epochs, train_loader, valid_loader):\n",
" for e in tqdm(range(epochs)):\n",
" for batch in train_loader:\n",
" # optim and lr step\n",
" output = self.train_step(batch)\n",
" if self.scheduler:\n",
" self.scheduler.step()\n",
"\n",
" # step callbacks\n",
" for callback in self.callbacks:\n",
" callback()\n",
"\n",
" # logs @ train step\n",
" steps_per_epoch = len(train_loader)\n",
" w = int(0.05 * steps_per_epoch)\n",
" self.train_log["loss"].append(output["loss"].item())\n",
" self.train_log["accs"].append(output["accs"].item())\n",
" self.train_log["loss_avg"].append(np.mean(self.train_log["loss"][-w:]))\n",
" self.train_log["accs_avg"].append(np.mean(self.train_log["accs"][-w:]))\n",
"\n",
" # logs @ epoch\n",
" output = self.evaluate(valid_loader)\n",
" self.valid_log["loss"].append(output["loss"])\n",
" self.valid_log["accs"].append(output["accs"])\n",
" if self.verbose:\n",
" print(f"[Epoch: {e+1:>0{int(len(str(epochs)))}d}/{epochs}] loss: {self.train_log['loss_avg'][-1]:.4f} acc: {self.train_log['accs_avg'][-1]:.4f} val_loss: {self.valid_log['loss'][-1]:.4f} val_acc: {self.valid_log['accs'][-1]:.4f}")\n",
"\n",
" def evaluate(self, data_loader):\n",
" with eval_context(self.model):\n",
" valid_loss = 0.0\n",
" valid_accs = 0.0\n",
" for batch in data_loader:\n",
" output = self.valid_step(batch)\n",
" valid_loss += output["loss"].item()\n",
" valid_accs += output["accs"].item()\n",
"\n",
" return {\n",
" "loss": valid_loss / len(data_loader.dataset),\n",
" "accs": valid_accs / len(data_loader.dataset)\n",
" }\n",
"\n",
" @torch.inference_mode()\n",
" def predict(self, x: torch.Tensor):\n",
" with eval_context(self.model):\n",
" return self(x)\n",
"