{ "cells": [ { "cell_type": "markdown", "id": "7a4912ac", "metadata": { "papermill": { "duration": 0.009569, "end_time": "2025-01-13T14:53:19.267785", "exception": false, "start_time": "2025-01-13T14:53:19.258216", "status": "completed" }, "tags": [] }, "source": [ "# RNN model training" ] }, { "cell_type": "markdown", "id": "78c572f5", "metadata": { "papermill": { "duration": 0.018123, "end_time": "2025-01-13T14:53:19.294596", "exception": false, "start_time": "2025-01-13T14:53:19.276473", "status": "completed" }, "tags": [] }, "source": [ "Each mini-batch of $B$ sequences of length $T$ yields $B \\times T$ prediction instances for training. \n", "The model predicts the next token at every time step, and the hidden state is updated after each step before the next prediction is made.\n", "For $T = 30$ this gives logits for $t = 1, \\ldots, 30$, where the prediction at step $t$ depends only on the first $t$ inputs. In effect, the model is trained and evaluated on inputs of every length from $1$ to $T$ at once. Since the state starts at zero, predictions early in a sequence are made with little context, so it may be necessary to warm up the state on some input before relying on its predictions." ] }, { "cell_type": "code", "execution_count": 1, "id": "6ef4fed3", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:19.301186Z", "iopub.status.busy": "2025-01-13T14:53:19.300939Z", "iopub.status.idle": "2025-01-13T14:53:20.275472Z", "shell.execute_reply": "2025-01-13T14:53:20.275161Z" }, "papermill": { "duration": 0.979657, "end_time": "2025-01-13T14:53:20.277299", "exception": false, "start_time": "2025-01-13T14:53:19.297642", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa74fbd2", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.282804Z", "iopub.status.busy": "2025-01-13T14:53:20.282595Z", "iopub.status.idle": "2025-01-13T14:53:20.308962Z", "shell.execute_reply": "2025-01-13T14:53:20.308627Z" }, "papermill": { "duration": 0.032474, "end_time": "2025-01-13T14:53:20.312456", "exception": false, "start_time": "2025-01-13T14:53:20.279982", 
"status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
def collate_fn(batch):\n",
       "    \"\"\"Transforming the data to sequence-first format.\"\"\"\n",
       "    x, y = zip(*batch)\n",
       "    x = torch.stack(x, 1)      # (T, B, vocab_size)\n",
       "    y = torch.stack(y, 1)      # (T, B)\n",
       "    return x, y\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k}{def} \\PY{n+nf}{collate\\PYZus{}fn}\\PY{p}{(}\\PY{n}{batch}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Transforming the data to sequence\\PYZhy{}first format.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n}{x}\\PY{p}{,} \\PY{n}{y} \\PY{o}{=} \\PY{n+nb}{zip}\\PY{p}{(}\\PY{o}{*}\\PY{n}{batch}\\PY{p}{)}\n", " \\PY{n}{x} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{stack}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} (T, B, vocab\\PYZus{}size)}\n", " \\PY{n}{y} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{stack}\\PY{p}{(}\\PY{n}{y}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} (T, B)}\n", " \\PY{k}{return} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\n", "\\end{Verbatim}\n" ], "text/plain": [ "def collate_fn(batch):\n", " \"\"\"Transforming the data to sequence-first format.\"\"\"\n", " x, y = zip(*batch)\n", " x = torch.stack(x, 1) # (T, B, vocab_size)\n", " y = torch.stack(y, 1) # (T, B)\n", " return x, y" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "def collate_fn(batch):\n", " \"\"\"Transforming the data to sequence-first format.\"\"\"\n", " x, y = zip(*batch)\n", " x = torch.stack(x, 1) # (T, B, vocab_size)\n", " y = torch.stack(y, 1) # (T, B)\n", " return x, y" ] }, { "cell_type": "markdown", "id": "5d647b2f", "metadata": { "papermill": { "duration": 0.001993, "end_time": "2025-01-13T14:53:20.316688", "exception": false, "start_time": "2025-01-13T14:53:20.314695", "status": "completed" }, "tags": [] }, "source": [ "Training on sequences from *The Time Machine*:" ] }, { "cell_type": "code", "execution_count": 3, "id": "ffb3a1a5", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.322069Z", "iopub.status.busy": "2025-01-13T14:53:20.321894Z", "iopub.status.idle": "2025-01-13T14:53:20.401168Z", "shell.execute_reply": "2025-01-13T14:53:20.400839Z" }, 
"papermill": { "duration": 0.083533, "end_time": "2025-01-13T14:53:20.402583", "exception": false, "start_time": "2025-01-13T14:53:20.319050", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "preds per epoch\n", "train: 4.182e+06\n", "valid: 1.048e+06\n" ] } ], "source": [ "from torch.utils.data import random_split\n", "\n", "data, tokenizer = TimeMachine().build()\n", "T = 30\n", "BATCH_SIZE = 128\n", "VOCAB_SIZE = tokenizer.vocab_size\n", "\n", "dataset = SequenceDataset(data, seq_len=T, vocab_size=VOCAB_SIZE)\n", "train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])\n", "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)\n", "valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn) # also sampled\n", "\n", "print(\"preds per epoch\")\n", "print(\"train:\", f\"{len(train_loader) * BATCH_SIZE * T: .3e}\")\n", "print(\"valid:\", f\"{len(valid_loader) * BATCH_SIZE * T: .3e}\")" ] }, { "cell_type": "markdown", "id": "0e81c75a", "metadata": { "papermill": { "duration": 0.002823, "end_time": "2025-01-13T14:53:20.409188", "exception": false, "start_time": "2025-01-13T14:53:20.406365", "status": "completed" }, "tags": [] }, "source": [ "When training RNNs, it is common to use **gradient clipping**. RNNs are deep along the time axis: the state update function $f$ is applied once per sequence element, so backpropagation through time multiplies $O(T)$ Jacobian matrices together. Depending on their norms, such products can grow or shrink exponentially in $T$. The gradients then explode or vanish, which causes numerical instability. A direct remedy for exploding gradients is simply to clip them. 
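The instability is easy to reproduce. The following is a minimal sketch, not part of the training pipeline. It assumes, purely for illustration, that every time step contributes the same Jacobian gamma * Q, where Q is a random orthogonal matrix and gamma is a scalar; the hidden size d = 16 is an arbitrary choice. Because Q preserves norms, the product of T such matrices has spectral norm exactly gamma to the power T. With gamma = 0.9 the norm shrinks to about 0.04 after 30 steps, with gamma = 1.1 it grows to about 17.4, and only gamma = 1 keeps it stable. Real Jacobians change from step to step, but the same exponential dependence on T is what makes RNN gradients explode or vanish.

```python
import torch

torch.manual_seed(0)
T = 30  # sequence length, as in the training setup
d = 16  # hypothetical hidden size, chosen only for illustration

# The QR decomposition of a random Gaussian matrix gives an orthogonal Q (Q.T @ Q = I).
Q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))


def product_norm(gamma: float, steps: int) -> float:
    # Multiply `steps` copies of gamma * Q, mimicking the chain of
    # per-step Jacobians in backpropagation through time.
    J = torch.eye(d, dtype=torch.float64)
    for _ in range(steps):
        J = (gamma * Q) @ J
    # Spectral norm (largest singular value) of the accumulated product.
    return torch.linalg.matrix_norm(J, ord=2).item()


for gamma in (0.9, 1.0, 1.1):
    print(gamma, product_norm(gamma, T))
```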
Gradient clipping projects the gradient onto the ball of radius $\\xi > 0$ centered at the origin:\n", "\n", "$$\n", "\\boldsymbol{\\mathsf{g}} \\leftarrow \\min \\left(1, \\frac{\\xi}{\\| \\boldsymbol{\\mathsf{g}} \\|} \\right) \\boldsymbol{\\mathsf{g}} = \\min \\left(\\| \\boldsymbol{\\mathsf{g}} \\|, \\xi \\right) \\frac{\\boldsymbol{\\mathsf{g}}}{\\| \\boldsymbol{\\mathsf{g}} \\|}.\n", "$$\n", "\n", "Here $\\boldsymbol{\\mathsf{g}}$ is the gradient of all trainable parameters concatenated into a single vector. Clipping preserves its direction and changes only its norm: \n", "if $\\| \\boldsymbol{\\mathsf{g}} \\| \\leq \\xi$, the $\\min$ selects $1$ and the gradient is unchanged; if \n", "$\\| \\boldsymbol{\\mathsf{g}} \\| > \\xi$, the $\\min$ selects the ratio $\\xi / \\| \\boldsymbol{\\mathsf{g}} \\| < 1$, so the gradient is rescaled to have norm exactly $\\xi$." ] }, { "cell_type": "code", "execution_count": 4, "id": "7782992f", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.440944Z", "iopub.status.busy": "2025-01-13T14:53:20.440546Z", "iopub.status.idle": "2025-01-13T14:53:20.454603Z", "shell.execute_reply": "2025-01-13T14:53:20.454190Z" }, "papermill": { "duration": 0.029689, "end_time": "2025-01-13T14:53:20.455998", "exception": false, "start_time": "2025-01-13T14:53:20.426309", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import torch.nn.functional as F\n",
       "\n",
       "def clip_grad_norm(model, max_norm: float):\n",
       "    \"\"\"Calculate norm on concatenated params. Modify params in-place.\"\"\"\n",
       "    params = [p for p in model.parameters() if p.requires_grad]\n",
       "    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n",
       "    if norm > max_norm:\n",
       "        for p in params:\n",
       "            p.grad[:] *= max_norm / norm   # [:] = shallow copy, in-place\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional} \\PY{k}{as} \\PY{n+nn}{F}\n", "\n", "\\PY{k}{def} \\PY{n+nf}{clip\\PYZus{}grad\\PYZus{}norm}\\PY{p}{(}\\PY{n}{model}\\PY{p}{,} \\PY{n}{max\\PYZus{}norm}\\PY{p}{:} \\PY{n+nb}{float}\\PY{p}{)}\\PY{p}{:}\n", "\\PY{+w}{ }\\PY{l+s+sd}{\\PYZdq{}\\PYZdq{}\\PYZdq{}Calculate norm on concatenated params. Modify params in\\PYZhy{}place.\\PYZdq{}\\PYZdq{}\\PYZdq{}}\n", " \\PY{n}{params} \\PY{o}{=} \\PY{p}{[}\\PY{n}{p} \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n}{model}\\PY{o}{.}\\PY{n}{parameters}\\PY{p}{(}\\PY{p}{)} \\PY{k}{if} \\PY{n}{p}\\PY{o}{.}\\PY{n}{requires\\PYZus{}grad}\\PY{p}{]}\n", " \\PY{n}{norm} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sqrt}\\PY{p}{(}\\PY{n+nb}{sum}\\PY{p}{(}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{sum}\\PY{p}{(}\\PY{p}{(}\\PY{n}{p}\\PY{o}{.}\\PY{n}{grad} \\PY{o}{*}\\PY{o}{*} \\PY{l+m+mi}{2}\\PY{p}{)}\\PY{p}{)} \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n}{params}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{k}{if} \\PY{n}{norm} \\PY{o}{\\PYZgt{}} \\PY{n}{max\\PYZus{}norm}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{p} \\PY{o+ow}{in} \\PY{n}{params}\\PY{p}{:}\n", " \\PY{n}{p}\\PY{o}{.}\\PY{n}{grad}\\PY{p}{[}\\PY{p}{:}\\PY{p}{]} \\PY{o}{*}\\PY{o}{=} \\PY{n}{max\\PYZus{}norm} \\PY{o}{/} \\PY{n}{norm} \\PY{c+c1}{\\PYZsh{} [:] = shallow copy, in\\PYZhy{}place}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import torch.nn.functional as F\n", "\n", "def clip_grad_norm(model, max_norm: float):\n", " \"\"\"Calculate norm on concatenated params. 
 Modify params in-place.\"\"\"\n", " params = [p for p in model.parameters() if p.requires_grad]\n", " norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n", " if norm > max_norm:\n", " for p in params:\n", " p.grad[:] *= max_norm / norm # [:] = shallow copy, in-place" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "def clip_grad_norm(model, max_norm: float):\n", "    \"\"\"Rescale gradients in-place so their global L2 norm is at most max_norm.\n", "\n", "    The norm is computed over all gradients concatenated into one vector.\n", "    \"\"\"\n", "    params = [p for p in model.parameters() if p.requires_grad and p.grad is not None]\n", "    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))\n", "    if norm > max_norm:\n", "        for p in params:\n", "            p.grad.mul_(max_norm / norm)  # in-place, so optim.step() sees the clipped gradient" ] }, { "cell_type": "markdown", "id": "dc1b01ce", "metadata": { "papermill": { "duration": 0.002165, "end_time": "2025-01-13T14:53:20.460455", "exception": false, "start_time": "2025-01-13T14:53:20.458290", "status": "completed" }, "tags": [] }, "source": [ "For sequence outputs, PyTorch's `F.cross_entropy` expects logits of shape `(B, C, T)` and class-index targets of shape `(B, T)`, where `C` is the number of classes (here the vocabulary size). The model returns logits of shape `(T, B, C)` and the targets are `(T, B)`, so both are permuted first. With the default `reduction=\"mean\"`, the loss is averaged over all $B \\times T$ predictions:" ] }, { "cell_type": "code", "execution_count": 5, "id": "69613453", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.466416Z", "iopub.status.busy": "2025-01-13T14:53:20.465957Z", "iopub.status.idle": "2025-01-13T14:53:20.475270Z", "shell.execute_reply": "2025-01-13T14:53:20.474912Z" }, "papermill": { "duration": 0.013457, "end_time": "2025-01-13T14:53:20.476532", "exception": false, "start_time": "2025-01-13T14:53:20.463075", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
def train_step(model, optim, x, y, max_norm) -> float:\n",
       "    target = y.transpose(0, 1)\n",
       "    output = model(x).permute(1, 2, 0)\n",
       "    loss = F.cross_entropy(output, target)\n",
       "    loss.backward()\n",
       "    \n",
       "    clip_grad_norm(model, max_norm=max_norm)\n",
       "    optim.step()\n",
       "    optim.zero_grad()\n",
       "    return loss.item()\n",
       "\n",
       "\n",
       "@torch.no_grad()\n",
       "def valid_step(model, x, y) -> float:\n",
       "    target = y.transpose(0, 1)\n",
       "    output = model(x).permute(1, 2, 0)\n",
       "    loss = F.cross_entropy(output, target)\n",
       "    return loss.item()\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k}{def} \\PY{n+nf}{train\\PYZus{}step}\\PY{p}{(}\\PY{n}{model}\\PY{p}{,} \\PY{n}{optim}\\PY{p}{,} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{,} \\PY{n}{max\\PYZus{}norm}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{float}\\PY{p}{:}\n", " \\PY{n}{target} \\PY{o}{=} \\PY{n}{y}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{)}\n", " \\PY{n}{output} \\PY{o}{=} \\PY{n}{model}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\n", " \\PY{n}{loss} \\PY{o}{=} \\PY{n}{F}\\PY{o}{.}\\PY{n}{cross\\PYZus{}entropy}\\PY{p}{(}\\PY{n}{output}\\PY{p}{,} \\PY{n}{target}\\PY{p}{)}\n", " \\PY{n}{loss}\\PY{o}{.}\\PY{n}{backward}\\PY{p}{(}\\PY{p}{)}\n", " \n", " \\PY{n}{clip\\PYZus{}grad\\PYZus{}norm}\\PY{p}{(}\\PY{n}{model}\\PY{p}{,} \\PY{n}{max\\PYZus{}norm}\\PY{o}{=}\\PY{n}{max\\PYZus{}norm}\\PY{p}{)}\n", " \\PY{n}{optim}\\PY{o}{.}\\PY{n}{step}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{optim}\\PY{o}{.}\\PY{n}{zero\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n}{loss}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\n", "\n", "\n", "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", "\\PY{k}{def} \\PY{n+nf}{valid\\PYZus{}step}\\PY{p}{(}\\PY{n}{model}\\PY{p}{,} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{float}\\PY{p}{:}\n", " \\PY{n}{target} \\PY{o}{=} \\PY{n}{y}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{l+m+mi}{0}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{)}\n", " \\PY{n}{output} \\PY{o}{=} \\PY{n}{model}\\PY{p}{(}\\PY{n}{x}\\PY{p}{)}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\n", " \\PY{n}{loss} \\PY{o}{=} \\PY{n}{F}\\PY{o}{.}\\PY{n}{cross\\PYZus{}entropy}\\PY{p}{(}\\PY{n}{output}\\PY{p}{,} \\PY{n}{target}\\PY{p}{)}\n", " \\PY{k}{return} 
\\PY{n}{loss}\\PY{o}{.}\\PY{n}{item}\\PY{p}{(}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "def train_step(model, optim, x, y, max_norm) -> float:\n", " target = y.transpose(0, 1)\n", " output = model(x).permute(1, 2, 0)\n", " loss = F.cross_entropy(output, target)\n", " loss.backward()\n", " \n", " clip_grad_norm(model, max_norm=max_norm)\n", " optim.step()\n", " optim.zero_grad()\n", " return loss.item()\n", "\n", "\n", "@torch.no_grad()\n", "def valid_step(model, x, y) -> float:\n", " target = y.transpose(0, 1)\n", " output = model(x).permute(1, 2, 0)\n", " loss = F.cross_entropy(output, target)\n", " return loss.item()" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "def train_step(model, optim, x, y, max_norm) -> float:\n", " target = y.transpose(0, 1)\n", " output = model(x).permute(1, 2, 0)\n", " loss = F.cross_entropy(output, target)\n", " loss.backward()\n", " \n", " clip_grad_norm(model, max_norm=max_norm)\n", " optim.step()\n", " optim.zero_grad()\n", " return loss.item()\n", "\n", "\n", "@torch.no_grad()\n", "def valid_step(model, x, y) -> float:\n", " target = y.transpose(0, 1)\n", " output = model(x).permute(1, 2, 0)\n", " loss = F.cross_entropy(output, target)\n", " return loss.item()" ] }, { "cell_type": "markdown", "id": "f103e37c", "metadata": { "papermill": { "duration": 0.002661, "end_time": "2025-01-13T14:53:20.482493", "exception": false, "start_time": "2025-01-13T14:53:20.479832", "status": "completed" }, "tags": [] }, "source": [ "Training the model:" ] }, { "cell_type": "code", "execution_count": 6, "id": "e3e7222c", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:20.488003Z", "iopub.status.busy": "2025-01-13T14:53:20.487849Z", "iopub.status.idle": "2025-01-13T14:54:00.288564Z", "shell.execute_reply": "2025-01-13T14:54:00.287748Z" }, "papermill": { "duration": 39.807202, "end_time": "2025-01-13T14:54:00.292128", "exception": false, "start_time": 
"2025-01-13T14:53:20.484926", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "612f499b3f004fc2bf3792055987c5c6", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2025-01-13T22:54:00.844243\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.9.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
"\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib_inline import backend_inline\n", "backend_inline.set_matplotlib_formats(\"svg\")\n", "\n", "plt.figure(figsize=(10, 4))\n", "plt.plot(train_losses, label=\"train\")\n", "plt.plot(np.array(range(1, len(valid_losses) + 1)) * 5, valid_losses, label=\"valid\")\n", "plt.grid(linestyle=\"dotted\", alpha=0.6)\n", "plt.ylabel(\"loss\")\n", "plt.xlabel(\"step\")\n", "plt.legend();" ] }, { "cell_type": "markdown", "id": "6eb6921e", "metadata": { "papermill": { "duration": 0.004769, "end_time": "2025-01-13T14:54:00.900893", "exception": false, "start_time": "2025-01-13T14:54:00.896124", "status": "completed" }, "tags": [] }, "source": [ "Note that gradient clipping only addresses exploding gradients. Vanishing gradients are left untouched, which limits how far back in the sequence the model can effectively learn dependencies." ] }, { "cell_type": "code", "execution_count": 8, "id": "1cb1bfaf", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:54:00.922026Z", "iopub.status.busy": "2025-01-13T14:54:00.921562Z", "iopub.status.idle": "2025-01-13T14:54:01.067035Z", "shell.execute_reply": "2025-01-13T14:54:01.066418Z" }, "papermill": { "duration": 0.154078, "end_time": "2025-01-13T14:54:01.068387", "exception": false, "start_time": "2025-01-13T14:54:00.914309", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "!mkdir -p artifacts/\n", "PATH = \"./artifacts/rnn_lm.pkl\"\n", "torch.save(model.state_dict(), PATH)" ] }, { "cell_type": "code", "execution_count": null, "id": "f8f8c703", "metadata": { "papermill": { "duration": 0.006384, "end_time": "2025-01-13T14:54:01.081017", "exception": false, "start_time": "2025-01-13T14:54:01.074633", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": 
"python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { "default_parameters": {}, "duration": 45.292755, "end_time": "2025-01-13T14:54:03.716759", "environment_variables": {}, "exception": null, "input_path": "05c-training.ipynb", "output_path": "05c-training.ipynb", "parameters": {}, "start_time": "2025-01-13T14:53:18.424004", "version": "2.6.0" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "2b3aa100843b4ed08114e3180ba5d392": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e50679b1ff1c4f70966b5b6e3fce5ea5", "max": 5.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_ff895d1825c44657bff5139ae018dd14", "tabbable": null, "tooltip": null, "value": 5.0 } }, "364c484c23364e539319df386c7c1eb0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null } }, "3719a9292cb54875b8162e7e9a63dc92": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null } }, "44f7c88953f64f03bb3fb9387283733e": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "612f499b3f004fc2bf3792055987c5c6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_81b34dfad91442b599c912041c76276f", "IPY_MODEL_2b3aa100843b4ed08114e3180ba5d392", "IPY_MODEL_acaadf0d94454cb39bf3faaa9e93e567" ], "layout": "IPY_MODEL_44f7c88953f64f03bb3fb9387283733e", "tabbable": null, 
"tooltip": null } }, "81b34dfad91442b599c912041c76276f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_fea9d5ed0de241c09c1d62891fbde742", "placeholder": "​", "style": "IPY_MODEL_364c484c23364e539319df386c7c1eb0", "tabbable": null, "tooltip": null, "value": "100%" } }, "acaadf0d94454cb39bf3faaa9e93e567": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_cf6b62a434474eebbbea8068206a42b4", "placeholder": "​", "style": "IPY_MODEL_3719a9292cb54875b8162e7e9a63dc92", "tabbable": null, "tooltip": null, "value": " 5/5 [00:39<00:00,  7.29s/it]" } }, "cf6b62a434474eebbbea8068206a42b4": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": 
null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e50679b1ff1c4f70966b5b6e3fce5ea5": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "fea9d5ed0de241c09c1d62891fbde742": { "model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "ff895d1825c44657bff5139ae018dd14": { "model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }