{ "cells": [ { "cell_type": "markdown", "id": "8402960d", "metadata": { "papermill": { "duration": 0.005615, "end_time": "2024-11-27T10:40:32.191101", "exception": false, "start_time": "2024-11-27T10:40:32.185486", "status": "completed" }, "tags": [] }, "source": [ "(dl/03-cnn/03bb-trainer)=\n", "# Trainer engine" ] }, { "cell_type": "markdown", "id": "77bcf0d7", "metadata": { "papermill": { "duration": 0.003115, "end_time": "2024-11-27T10:40:32.198379", "exception": false, "start_time": "2024-11-27T10:40:32.195264", "status": "completed" }, "tags": [] }, "source": [ "To separate concerns during model training, we define a **trainer engine**. For example, this defines an `eval_context` to automatically set the model to eval mode at entry, and back to the default train mode at exit. This is useful for layers such as BN and Dropout which have different behaviors at train and test times. LR schedulers and callbacks are also implemented. Currently, these are called at the end of each training step (it is easy to extend this class to implement epoch end callbacks)." ] }, { "cell_type": "code", "execution_count": 1, "id": "adfa8b33", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:32.207240Z", "iopub.status.busy": "2024-11-27T10:40:32.206844Z", "iopub.status.idle": "2024-11-27T10:40:34.803454Z", "shell.execute_reply": "2024-11-27T10:40:34.803026Z" }, "papermill": { "duration": 2.604281, "end_time": "2024-11-27T10:40:34.805836", "exception": false, "start_time": "2024-11-27T10:40:32.201555", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "code", "execution_count": 2, "id": "42787bd2", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:34.809985Z", "iopub.status.busy": "2024-11-27T10:40:34.809777Z", "iopub.status.idle": "2024-11-27T10:40:34.898587Z", "shell.execute_reply": "2024-11-27T10:40:34.898029Z" }, "papermill": { "duration": 0.092665, "end_time": "2024-11-27T10:40:34.900531", "exception": false, "start_time": "2024-11-27T10:40:34.807866", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
from tqdm.notebook import tqdm\n",
       "from contextlib import contextmanager\n",
       "from torch.utils.data import DataLoader\n",
       "\n",
       "\n",
       "@contextmanager\n",
       "def eval_context(model):\n",
       "    """Temporarily set to eval mode inside context."""\n",
       "    is_train = model.training\n",
       "    model.eval()\n",
       "    try:\n",
       "        yield\n",
       "    finally:\n",
       "        model.train(is_train)\n",
       "\n",
       "\n",
       "class Trainer:\n",
       "    def __init__(self,\n",
       "        model, optim, loss_fn, scheduler=None, callbacks=[],\n",
       "        device=DEVICE, verbose=True\n",
       "    ):\n",
       "        self.model = model.to(device)\n",
       "        self.optim = optim\n",
       "        self.device = device\n",
       "        self.loss_fn = loss_fn\n",
       "        self.train_log = {"loss": [], "accs": [], "loss_avg": [], "accs_avg": []}\n",
       "        self.valid_log = {"loss": [], "accs": []}\n",
       "        self.verbose = verbose\n",
       "        self.scheduler = scheduler\n",
       "        self.callbacks = callbacks\n",
       "    \n",
       "    def __call__(self, x):\n",
       "        return self.model(x.to(self.device))\n",
       "\n",
       "    def forward(self, batch):\n",
       "        x, y = batch\n",
       "        x = x.to(self.device)\n",
       "        y = y.to(self.device)\n",
       "        return self.model(x), y\n",
       "\n",
       "    def train_step(self, batch):\n",
       "        preds, y = self.forward(batch)\n",
       "        accs = (preds.argmax(dim=1) == y).float().mean()\n",
       "        loss = self.loss_fn(preds, y)\n",
       "        loss.backward()\n",
       "        self.optim.step()\n",
       "        self.optim.zero_grad()\n",
       "        return {"loss": loss, "accs": accs}\n",
       "\n",
       "    @torch.inference_mode()\n",
       "    def valid_step(self, batch):\n",
       "        preds, y = self.forward(batch)\n",
       "        accs = (preds.argmax(dim=1) == y).float().sum()\n",
       "        loss = self.loss_fn(preds, y, reduction="sum")\n",
       "        return {"loss": loss, "accs": accs}\n",
       "    \n",
       "    def run(self, epochs, train_loader, valid_loader):\n",
       "        for e in tqdm(range(epochs)):\n",
       "            for batch in train_loader:\n",
       "                # optim and lr step\n",
       "                output = self.train_step(batch)\n",
       "                if self.scheduler:\n",
       "                    self.scheduler.step()\n",
       "\n",
       "                # step callbacks\n",
       "                for callback in self.callbacks:\n",
       "                    callback()\n",
       "\n",
       "                # logs @ train step\n",
       "                steps_per_epoch = len(train_loader)\n",
       "                w = int(0.05 * steps_per_epoch)\n",
       "                self.train_log["loss"].append(output["loss"].item())\n",
       "                self.train_log["accs"].append(output["accs"].item())\n",
       "                self.train_log["loss_avg"].append(np.mean(self.train_log["loss"][-w:]))\n",
       "                self.train_log["accs_avg"].append(np.mean(self.train_log["accs"][-w:]))\n",
       "\n",
       "            # logs @ epoch\n",
       "            output = self.evaluate(valid_loader)\n",
       "            self.valid_log["loss"].append(output["loss"])\n",
       "            self.valid_log["accs"].append(output["accs"])\n",
       "            if self.verbose:\n",
       "                print(f"[Epoch: {e+1:>0{int(len(str(epochs)))}d}/{epochs}]    loss: {self.train_log['loss_avg'][-1]:.4f}  acc: {self.train_log['accs_avg'][-1]:.4f}    val_loss: {self.valid_log['loss'][-1]:.4f}  val_acc: {self.valid_log['accs'][-1]:.4f}")\n",
       "\n",
       "    def evaluate(self, data_loader):\n",
       "        with eval_context(self.model):\n",
       "            valid_loss = 0.0\n",
       "            valid_accs = 0.0\n",
       "            for batch in data_loader:\n",
       "                output = self.valid_step(batch)\n",
       "                valid_loss += output["loss"].item()\n",
       "                valid_accs += output["accs"].item()\n",
       "\n",
       "        return {\n",
       "            "loss": valid_loss / len(data_loader.dataset),\n",
       "            "accs": valid_accs / len(data_loader.dataset)\n",
       "        }\n",
       "\n",
       "    @torch.inference_mode()\n",
       "    def predict(self, x: torch.Tensor):\n",
       "        with eval_context(self.model):\n",
       "            return self(x)\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{from} \\PY{n+nn}{tqdm}\\PY{n+nn}{.}\\PY{n+nn}{notebook} \\PY{k+kn}{import} \\PY{n}{tqdm}\n", "\\PY{k+kn}{from} \\PY{n+nn}{contextlib} \\PY{k+kn}{import} \\PY{n}{contextmanager}\n", "\\PY{k+kn}{from} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{utils}\\PY{n+nn}{.}\\PY{n+nn}{data} \\PY{k+kn}{import} \\PY{n}{DataLoader}\n", "\n", "\n", "\\PY{n+nd}{@contextmanager}\n", "\\PY{k}{def} \\PY{n+nf}{eval\\PYZus{}context}\\PY{p}{(}\\PY{n}{model}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Temporarily set to eval mode inside context.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n}{is\\PYZus{}train} \\PY{o}{=} \\PY{n}{model}\\PY{o}{.}\\PY{n}{training}\n", " \\PY{n}{model}\\PY{o}{.}\\PY{n}{eval}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{try}\\PY{p}{:}\n", " \\PY{k}{yield}\n", " \\PY{k}{finally}\\PY{p}{:}\n", " \\PY{n}{model}\\PY{o}{.}\\PY{n}{train}\\PY{p}{(}\\PY{n}{is\\PYZus{}train}\\PY{p}{)}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Trainer}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,}\n", " \\PY{n}{model}\\PY{p}{,} \\PY{n}{optim}\\PY{p}{,} \\PY{n}{loss\\PYZus{}fn}\\PY{p}{,} \\PY{n}{scheduler}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{,} \\PY{n}{callbacks}\\PY{o}{=}\\PY{p}{[}\\PY{p}{]}\\PY{p}{,}\n", " \\PY{n}{device}\\PY{o}{=}\\PY{n}{DEVICE}\\PY{p}{,} \\PY{n}{verbose}\\PY{o}{=}\\PY{k+kc}{True}\n", " \\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{model} \\PY{o}{=} \\PY{n}{model}\\PY{o}{.}\\PY{n}{to}\\PY{p}{(}\\PY{n}{device}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{optim} \\PY{o}{=} \\PY{n}{optim}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{device} \\PY{o}{=} \\PY{n}{device}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{loss\\PYZus{}fn} \\PY{o}{=} \\PY{n}{loss\\PYZus{}fn}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log} \\PY{o}{=} \\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss\\PYZus{}avg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs\\PYZus{}avg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{\\PYZcb{}}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}log} \\PY{o}{=} \\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{p}{[}\\PY{p}{]}\\PY{p}{\\PYZcb{}}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{verbose} \\PY{o}{=} \\PY{n}{verbose}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{scheduler} \\PY{o}{=} \\PY{n}{scheduler}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{callbacks} \\PY{o}{=} \\PY{n}{callbacks}\n", " \n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}call\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{model}\\PY{p}{(}\\PY{n}{x}\\PY{o}{.}\\PY{n}{to}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{forward}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{batch}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{x}\\PY{p}{,} \\PY{n}{y} \\PY{o}{=} \\PY{n}{batch}\n", " \\PY{n}{x} \\PY{o}{=} 
\\PY{n}{x}\\PY{o}{.}\\PY{n}{to}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\n", " \\PY{n}{y} \\PY{o}{=} \\PY{n}{y}\\PY{o}{.}\\PY{n}{to}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{device}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{model}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)}\\PY{p}{,} \\PY{n}{y}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{train\\PYZus{}step}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{batch}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{preds}\\PY{p}{,} \\PY{n}{y} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{forward}\\PY{p}{(}\\PY{n}{batch}\\PY{p}{)}\n", " \\PY{n}{accs} \\PY{o}{=} \\PY{p}{(}\\PY{n}{preds}\\PY{o}{.}\\PY{n}{argmax}\\PY{p}{(}\\PY{n}{dim}\\PY{o}{=}\\PY{l+m+mi}{1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{y}\\PY{p}{)}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{mean}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{loss} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{loss\\PYZus{}fn}\\PY{p}{(}\\PY{n}{preds}\\PY{p}{,} \\PY{n}{y}\\PY{p}{)}\n", " \\PY{n}{loss}\\PY{o}{.}\\PY{n}{backward}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{optim}\\PY{o}{.}\\PY{n}{step}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{optim}\\PY{o}{.}\\PY{n}{zero\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{loss}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{accs}\\PY{p}{\\PYZcb{}}\n", "\n", " \\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{inference\\PYZus{}mode}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{def} \\PY{n+nf}{valid\\PYZus{}step}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{batch}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{preds}\\PY{p}{,} \\PY{n}{y} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{forward}\\PY{p}{(}\\PY{n}{batch}\\PY{p}{)}\n", " \\PY{n}{accs} \\PY{o}{=} \\PY{p}{(}\\PY{n}{preds}\\PY{o}{.}\\PY{n}{argmax}\\PY{p}{(}\\PY{n}{dim}\\PY{o}{=}\\PY{l+m+mi}{1}\\PY{p}{)} \\PY{o}{==} \\PY{n}{y}\\PY{p}{)}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{sum}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{loss} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{loss\\PYZus{}fn}\\PY{p}{(}\\PY{n}{preds}\\PY{p}{,} \\PY{n}{y}\\PY{p}{,} \\PY{n}{reduction}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{sum}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{loss}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{accs}\\PY{p}{\\PYZcb{}}\n", " \n", " \\PY{k}{def} \\PY{n+nf}{run}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{epochs}\\PY{p}{,} \\PY{n}{train\\PYZus{}loader}\\PY{p}{,} \\PY{n}{valid\\PYZus{}loader}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{e} \\PY{o+ow}{in} \\PY{n}{tqdm}\\PY{p}{(}\\PY{n+nb}{range}\\PY{p}{(}\\PY{n}{epochs}\\PY{p}{)}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{batch} \\PY{o+ow}{in} \\PY{n}{train\\PYZus{}loader}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} optim and lr step}\n", " \\PY{n}{output} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}step}\\PY{p}{(}\\PY{n}{batch}\\PY{p}{)}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{scheduler}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{scheduler}\\PY{o}{.}\\PY{n}{step}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{c+c1}{\\PYZsh{} step callbacks}\n", " \\PY{k}{for} \\PY{n}{callback} \\PY{o+ow}{in} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{callbacks}\\PY{p}{:}\n", " 
\\PY{n}{callback}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{c+c1}{\\PYZsh{} logs @ train step}\n", " \\PY{n}{steps\\PYZus{}per\\PYZus{}epoch} \\PY{o}{=} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{train\\PYZus{}loader}\\PY{p}{)}\n", " \\PY{n}{w} \\PY{o}{=} \\PY{n+nb}{int}\\PY{p}{(}\\PY{l+m+mf}{0.05} \\PY{o}{*} \\PY{n}{steps\\PYZus{}per\\PYZus{}epoch}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss\\PYZus{}avg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{np}\\PY{o}{.}\\PY{n}{mean}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{n}{w}\\PY{p}{:}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs\\PYZus{}avg}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{np}\\PY{o}{.}\\PY{n}{mean}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{n}{w}\\PY{p}{:}\\PY{p}{]}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{c+c1}{\\PYZsh{} logs @ epoch}\n", " \\PY{n}{output} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{evaluate}\\PY{p}{(}\\PY{n}{valid\\PYZus{}loader}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{append}\\PY{p}{(}\\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{k}{if} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{verbose}\\PY{p}{:}\n", " \\PY{n+nb}{print}\\PY{p}{(}\\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{[Epoch: }\\PY{l+s+si}{\\PYZob{}}\\PY{n}{e}\\PY{o}{+}\\PY{l+m+mi}{1}\\PY{l+s+si}{:}\\PY{l+s+s2}{\\PYZgt{}0}\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb}{int}\\PY{p}{(}\\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb}{str}\\PY{p}{(}\\PY{n}{epochs}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{d}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{/}\\PY{l+s+si}{\\PYZob{}}\\PY{n}{epochs}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{] loss: }\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{loss\\PYZus{}avg}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{l+s+si}{:}\\PY{l+s+s2}{.4f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{ acc: 
}\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{train\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{accs\\PYZus{}avg}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{l+s+si}{:}\\PY{l+s+s2}{.4f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{ val\\PYZus{}loss: }\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{loss}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{l+s+si}{:}\\PY{l+s+s2}{.4f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{ val\\PYZus{}acc: }\\PY{l+s+si}{\\PYZob{}}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}log}\\PY{p}{[}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{accs}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{]}\\PY{p}{[}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{l+s+si}{:}\\PY{l+s+s2}{.4f}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{evaluate}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{data\\PYZus{}loader}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{eval\\PYZus{}context}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{model}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{valid\\PYZus{}loss} \\PY{o}{=} \\PY{l+m+mf}{0.0}\n", " \\PY{n}{valid\\PYZus{}accs} \\PY{o}{=} \\PY{l+m+mf}{0.0}\n", " \\PY{k}{for} \\PY{n}{batch} \\PY{o+ow}{in} \\PY{n}{data\\PYZus{}loader}\\PY{p}{:}\n", " \\PY{n}{output} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{valid\\PYZus{}step}\\PY{p}{(}\\PY{n}{batch}\\PY{p}{)}\n", " \\PY{n}{valid\\PYZus{}loss} \\PY{o}{+}\\PY{o}{=} \\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{valid\\PYZus{}accs} \\PY{o}{+}\\PY{o}{=} \\PY{n}{output}\\PY{p}{[}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{k}{return} \\PY{p}{\\PYZob{}}\n", " \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{loss}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{valid\\PYZus{}loss} \\PY{o}{/} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{data\\PYZus{}loader}\\PY{o}{.}\\PY{n}{dataset}\\PY{p}{)}\\PY{p}{,}\n", " \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{accs}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{n}{valid\\PYZus{}accs} \\PY{o}{/} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{data\\PYZus{}loader}\\PY{o}{.}\\PY{n}{dataset}\\PY{p}{)}\n", " \\PY{p}{\\PYZcb{}}\n", "\n", " \\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{inference\\PYZus{}mode}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{def} \\PY{n+nf}{predict}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{:} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{Tensor}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{eval\\PYZus{}context}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{model}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "from tqdm.notebook import tqdm\n", "from contextlib import contextmanager\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "@contextmanager\n", "def eval_context(model):\n", " \"\"\"Temporarily set to eval mode inside context.\"\"\"\n", " is_train = model.training\n", " model.eval()\n", " try:\n", " yield\n", " finally:\n", " model.train(is_train)\n", "\n", "\n", "class Trainer:\n", " def __init__(self,\n", " model, optim, loss_fn, scheduler=None, callbacks=[],\n", " device=DEVICE, verbose=True\n", " ):\n", " self.model = model.to(device)\n", " self.optim = optim\n", " self.device = device\n", " self.loss_fn = loss_fn\n", " 
self.train_log = {\"loss\": [], \"accs\": [], \"loss_avg\": [], \"accs_avg\": []}\n", " self.valid_log = {\"loss\": [], \"accs\": []}\n", " self.verbose = verbose\n", " self.scheduler = scheduler\n", " self.callbacks = callbacks\n", " \n", " def __call__(self, x):\n", " return self.model(x.to(self.device))\n", "\n", " def forward(self, batch):\n", " x, y = batch\n", " x = x.to(self.device)\n", " y = y.to(self.device)\n", " return self.model(x), y\n", "\n", " def train_step(self, batch):\n", " preds, y = self.forward(batch)\n", " accs = (preds.argmax(dim=1) == y).float().mean()\n", " loss = self.loss_fn(preds, y)\n", " loss.backward()\n", " self.optim.step()\n", " self.optim.zero_grad()\n", " return {\"loss\": loss, \"accs\": accs}\n", "\n", " @torch.inference_mode()\n", " def valid_step(self, batch):\n", " preds, y = self.forward(batch)\n", " accs = (preds.argmax(dim=1) == y).float().sum()\n", " loss = self.loss_fn(preds, y, reduction=\"sum\")\n", " return {\"loss\": loss, \"accs\": accs}\n", " \n", " def run(self, epochs, train_loader, valid_loader):\n", " for e in tqdm(range(epochs)):\n", " for batch in train_loader:\n", " # optim and lr step\n", " output = self.train_step(batch)\n", " if self.scheduler:\n", " self.scheduler.step()\n", "\n", " # step callbacks\n", " for callback in self.callbacks:\n", " callback()\n", "\n", " # logs @ train step\n", " steps_per_epoch = len(train_loader)\n", " w = int(0.05 * steps_per_epoch)\n", " self.train_log[\"loss\"].append(output[\"loss\"].item())\n", " self.train_log[\"accs\"].append(output[\"accs\"].item())\n", " self.train_log[\"loss_avg\"].append(np.mean(self.train_log[\"loss\"][-w:]))\n", " self.train_log[\"accs_avg\"].append(np.mean(self.train_log[\"accs\"][-w:]))\n", "\n", " # logs @ epoch\n", " output = self.evaluate(valid_loader)\n", " self.valid_log[\"loss\"].append(output[\"loss\"])\n", " self.valid_log[\"accs\"].append(output[\"accs\"])\n", " if self.verbose:\n", " print(f\"[Epoch: {e+1:>0{int(len(str(epochs)))}d}/{epochs}] loss: {self.train_log['loss_avg'][-1]:.4f} acc: {self.train_log['accs_avg'][-1]:.4f} val_loss: {self.valid_log['loss'][-1]:.4f} val_acc: {self.valid_log['accs'][-1]:.4f}\")\n", "\n", " def evaluate(self, data_loader):\n", " with eval_context(self.model):\n", " valid_loss = 0.0\n", " valid_accs = 0.0\n", " for batch in data_loader:\n", " output = self.valid_step(batch)\n", " valid_loss += output[\"loss\"].item()\n", " valid_accs += output[\"accs\"].item()\n", "\n", " return {\n", " \"loss\": valid_loss / len(data_loader.dataset),\n", " \"accs\": valid_accs / len(data_loader.dataset)\n", " }\n", "\n", " @torch.inference_mode()\n", " def predict(self, x: torch.Tensor):\n", " with eval_context(self.model):\n", " return self(x)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "from tqdm.notebook import tqdm\n", "from contextlib import contextmanager\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "@contextmanager\n", "def eval_context(model):\n", " \"\"\"Temporarily set to eval mode inside context.\"\"\"\n", " is_train = model.training\n", " model.eval()\n", " try:\n", " yield\n", " finally:\n", " model.train(is_train)\n", "\n", "\n", "class Trainer:\n", " def __init__(self,\n", " model, optim, loss_fn, scheduler=None, callbacks=[],\n", " device=DEVICE, verbose=True\n", " ):\n", " self.model = model.to(device)\n", " self.optim = optim\n", " self.device = device\n", " self.loss_fn = loss_fn\n", " self.train_log = {\"loss\": [], \"accs\": [], \"loss_avg\": 
[], \"accs_avg\": []}\n", " self.valid_log = {\"loss\": [], \"accs\": []}\n", " self.verbose = verbose\n", " self.scheduler = scheduler\n", " self.callbacks = callbacks\n", " \n", " def __call__(self, x):\n", " return self.model(x.to(self.device))\n", "\n", " def forward(self, batch):\n", " x, y = batch\n", " x = x.to(self.device)\n", " y = y.to(self.device)\n", " return self.model(x), y\n", "\n", " def train_step(self, batch):\n", " preds, y = self.forward(batch)\n", " accs = (preds.argmax(dim=1) == y).float().mean()\n", " loss = self.loss_fn(preds, y)\n", " loss.backward()\n", " self.optim.step()\n", " self.optim.zero_grad()\n", " return {\"loss\": loss, \"accs\": accs}\n", "\n", " @torch.inference_mode()\n", " def valid_step(self, batch):\n", " preds, y = self.forward(batch)\n", " accs = (preds.argmax(dim=1) == y).float().sum()\n", " loss = self.loss_fn(preds, y, reduction=\"sum\")\n", " return {\"loss\": loss, \"accs\": accs}\n", " \n", " def run(self, epochs, train_loader, valid_loader):\n", " for e in tqdm(range(epochs)):\n", " for batch in train_loader:\n", " # optim and lr step\n", " output = self.train_step(batch)\n", " if self.scheduler:\n", " self.scheduler.step()\n", "\n", " # step callbacks\n", " for callback in self.callbacks:\n", " callback()\n", "\n", " # logs @ train step\n", " steps_per_epoch = len(train_loader)\n", " w = int(0.05 * steps_per_epoch)\n", " self.train_log[\"loss\"].append(output[\"loss\"].item())\n", " self.train_log[\"accs\"].append(output[\"accs\"].item())\n", " self.train_log[\"loss_avg\"].append(np.mean(self.train_log[\"loss\"][-w:]))\n", " self.train_log[\"accs_avg\"].append(np.mean(self.train_log[\"accs\"][-w:]))\n", "\n", " # logs @ epoch\n", " output = self.evaluate(valid_loader)\n", " self.valid_log[\"loss\"].append(output[\"loss\"])\n", " self.valid_log[\"accs\"].append(output[\"accs\"])\n", " if self.verbose:\n", " print(f\"[Epoch: {e+1:>0{int(len(str(epochs)))}d}/{epochs}] loss: {self.train_log['loss_avg'][-1]:.4f} acc: {self.train_log['accs_avg'][-1]:.4f} val_loss: {self.valid_log['loss'][-1]:.4f} val_acc: {self.valid_log['accs'][-1]:.4f}\")\n", "\n", " def evaluate(self, data_loader):\n", " with eval_context(self.model):\n", " valid_loss = 0.0\n", " valid_accs = 0.0\n", " for batch in data_loader:\n", " output = self.valid_step(batch)\n", " valid_loss += output[\"loss\"].item()\n", " valid_accs += output[\"accs\"].item()\n", "\n", " return {\n", " \"loss\": valid_loss / len(data_loader.dataset),\n", " \"accs\": valid_accs / len(data_loader.dataset)\n", " }\n", "\n", " @torch.inference_mode()\n", " def predict(self, x: torch.Tensor):\n", " with eval_context(self.model):\n", " return self(x)" ] }, { "cell_type": "markdown", "id": "b4e8ab7f", "metadata": { "papermill": { "duration": 0.003343, "end_time": "2024-11-27T10:40:34.908349", "exception": false, "start_time": "2024-11-27T10:40:34.905006", "status": "completed" }, "tags": [] }, "source": [ "The `predict` method is suited for inference over *one* transformed mini-batch. A model call over a large input tensor may cause memory error. The model does not generate a computational graph to conserve memory and calls the model with layers in eval mode. For large batches, one should use `batch_predict` which is the same but takes in a data loader with transforms." 
] }, { "cell_type": "code", "execution_count": 3, "id": "d3190b4a", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:34.919851Z", "iopub.status.busy": "2024-11-27T10:40:34.919574Z", "iopub.status.idle": "2024-11-27T10:40:35.189593Z", "shell.execute_reply": "2024-11-27T10:40:35.188920Z" }, "papermill": { "duration": 0.279978, "end_time": "2024-11-27T10:40:35.192325", "exception": false, "start_time": "2024-11-27T10:40:34.912347", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "__call__ 0.000\n", "predict 0.500\n" ] } ], "source": [ "model = nn.Sequential(nn.Linear(3, 10), nn.Dropout(1.0))\n", "trainer = Trainer(model, optim=None, scheduler=None, loss_fn=None)\n", "\n", "# inference mode using eval_context\n", "x = torch.ones(size=(1, 3), requires_grad=True)\n", "print(f\"__call__ {(trainer(x) > 0).float().mean():.3f}\")\n", "print(f\"predict {(trainer.predict(x) > 0).float().mean():.3f}\")" ] }, { "cell_type": "markdown", "id": "d3fcdbe6", "metadata": { "papermill": { "duration": 0.001613, "end_time": "2024-11-27T10:40:35.195583", "exception": false, "start_time": "2024-11-27T10:40:35.193970", "status": "completed" }, "tags": [] }, "source": [ "Checking computational graph generation:" ] }, { "cell_type": "code", "execution_count": 4, "id": "e6a75784", "metadata": { "execution": { "iopub.execute_input": "2024-11-27T10:40:35.199028Z", "iopub.status.busy": "2024-11-27T10:40:35.198828Z", "iopub.status.idle": "2024-11-27T10:40:35.202238Z", "shell.execute_reply": "2024-11-27T10:40:35.201959Z" }, "papermill": { "duration": 0.006207, "end_time": "2024-11-27T10:40:35.203157", "exception": false, "start_time": "2024-11-27T10:40:35.196950", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "__call__ True\n", "predict False\n" ] } ], "source": [ "y = trainer(x)\n", "z = trainer.predict(x)\n", "print(\"__call__ \", y.requires_grad)\n", "print(\"predict \", z.requires_grad)" ] }, { "cell_type": "code", "execution_count": null, "id": "7b70d00e", "metadata": { "papermill": { "duration": 0.001154, "end_time": "2024-11-27T10:40:35.205670", "exception": false, "start_time": "2024-11-27T10:40:35.204516", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 4.995913, "end_time": "2024-11-27T10:40:35.930013", "environment_variables": {}, "exception": null, "input_path": "03bb-trainer.ipynb", "output_path": "03bb-trainer.ipynb", "parameters": {}, "start_time": "2024-11-27T10:40:30.934100", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }