{ "cells": [ { "cell_type": "markdown", "id": "2c4f3e6d", "metadata": { "papermill": { "duration": 0.011205, "end_time": "2025-01-13T14:53:14.094763", "exception": false, "start_time": "2025-01-13T14:53:14.083558", "status": "completed" }, "tags": [] }, "source": [ "# RNN language model" ] }, { "cell_type": "code", "execution_count": 1, "id": "120fc610", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.110871Z", "iopub.status.busy": "2025-01-13T14:53:14.110393Z", "iopub.status.idle": "2025-01-13T14:53:14.796397Z", "shell.execute_reply": "2025-01-13T14:53:14.796084Z" }, "papermill": { "duration": 0.695097, "end_time": "2025-01-13T14:53:14.797931", "exception": false, "start_time": "2025-01-13T14:53:14.102834", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "markdown", "id": "1dd5f01c", "metadata": { "papermill": { "duration": 0.00127, "end_time": "2025-01-13T14:53:14.801008", "exception": false, "start_time": "2025-01-13T14:53:14.799738", "status": "completed" }, "tags": [] }, "source": [ "Our goal in this section is to train a character-level RNN language model to predict the next token at *each* step with varying-length context. Hence, during training, our model predicts on each time-step ({numref}`04-char-rnn`). The language model below is simply an RNN cell with an attached **logits layer** applied at each step." ] }, { "cell_type": "markdown", "id": "a05a2576", "metadata": { "papermill": { "duration": 0.001177, "end_time": "2025-01-13T14:53:14.803410", "exception": false, "start_time": "2025-01-13T14:53:14.802233", "status": "completed" }, "tags": [] }, "source": [ "\n", "```{figure} ../../../img/nn/04-char-rnn.svg\n", "---\n", "width: 550px\n", "name: 04-char-rnn\n", "align: center\n", "---\n", "Character-level RNN language model for predicting the next character at each step. 
[Source](https://www.d2l.ai/chapter_recurrent-neural-networks/rnn.html)\n", "```" ] }, { "cell_type": "markdown", "id": "02139545", "metadata": { "papermill": { "duration": 0.001159, "end_time": "2025-01-13T14:53:14.805748", "exception": false, "start_time": "2025-01-13T14:53:14.804589", "status": "completed" }, "tags": [] }, "source": [ "To implement a language model, we simply attach a linear layer to the RNN unit to compute logits. \n", "The linear layer performs matrix multiplication on the rightmost dimension of `outs`, which contains the state vector at each time step. Thus, as shown in {numref}`04-char-rnn`, we have $T$ predictions with increasing context sizes[^1] $1, 2, \\ldots, T.$\n", "\n", "[^1]: Consequently, during the backward pass, the model is corrected at each time step, with variable-length dependencies. " ] }, { "cell_type": "code", "execution_count": 2, "id": "33053e9e", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.808762Z", "iopub.status.busy": "2025-01-13T14:53:14.808610Z", "iopub.status.idle": "2025-01-13T14:53:14.833112Z", "shell.execute_reply": "2025-01-13T14:53:14.832846Z" }, "papermill": { "duration": 0.027072, "end_time": "2025-01-13T14:53:14.833972", "exception": false, "start_time": "2025-01-13T14:53:14.806900", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import torch\n",
       "import torch.nn as nn\n",
       "from typing import Type\n",
       "from functools import partial\n",
       "\n",
       "\n",
       "class RNNLanguageModel(nn.Module):\n",
       "    def __init__(self, \n",
       "        cell: Type[RNNBase],\n",
       "        inputs_dim: int,\n",
       "        hidden_dim: int,\n",
       "        vocab_size: int,\n",
       "        **kwargs\n",
       "    ):\n",
       "        super().__init__()\n",
       "        self.cell = cell(inputs_dim, hidden_dim, **kwargs)\n",
       "        self.linear = nn.Linear(hidden_dim, vocab_size)\n",
       "\n",
       "    def forward(self, x, state=None, return_state=False):\n",
       "        outs, state = self.cell(x, state)\n",
       "        outs = self.linear(outs)    # (T, B, H) -> (T, B, C)\n",
       "        return outs if not return_state else (outs, state)\n",
       "\n",
       "\n",
       "LanguageModel = lambda cell: partial(RNNLanguageModel, cell)\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn} \\PY{k}{as} \\PY{n+nn}{nn}\n", "\\PY{k+kn}{from} \\PY{n+nn}{typing} \\PY{k+kn}{import} \\PY{n}{Type}\n", "\\PY{k+kn}{from} \\PY{n+nn}{functools} \\PY{k+kn}{import} \\PY{n}{partial}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{RNNLanguageModel}\\PY{p}{(}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{Module}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \n", " \\PY{n}{cell}\\PY{p}{:} \\PY{n}{Type}\\PY{p}{[}\\PY{n}{RNNBase}\\PY{p}{]}\\PY{p}{,}\n", " \\PY{n}{inputs\\PYZus{}dim}\\PY{p}{:} \\PY{n+nb}{int}\\PY{p}{,}\n", " \\PY{n}{hidden\\PYZus{}dim}\\PY{p}{:} \\PY{n+nb}{int}\\PY{p}{,}\n", " \\PY{n}{vocab\\PYZus{}size}\\PY{p}{:} \\PY{n+nb}{int}\\PY{p}{,}\n", " \\PY{o}{*}\\PY{o}{*}\\PY{n}{kwargs}\n", " \\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{cell} \\PY{o}{=} \\PY{n}{cell}\\PY{p}{(}\\PY{n}{inputs\\PYZus{}dim}\\PY{p}{,} \\PY{n}{hidden\\PYZus{}dim}\\PY{p}{,} \\PY{o}{*}\\PY{o}{*}\\PY{n}{kwargs}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{linear} \\PY{o}{=} \\PY{n}{nn}\\PY{o}{.}\\PY{n}{Linear}\\PY{p}{(}\\PY{n}{hidden\\PYZus{}dim}\\PY{p}{,} \\PY{n}{vocab\\PYZus{}size}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{forward}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{x}\\PY{p}{,} \\PY{n}{state}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{,} \\PY{n}{return\\PYZus{}state}\\PY{o}{=}\\PY{k+kc}{False}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{outs}\\PY{p}{,} \\PY{n}{state} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{cell}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{state}\\PY{p}{)}\n", " \\PY{n}{outs} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{outs}\\PY{p}{)} 
\\PY{c+c1}{\\PYZsh{} (T, B, H) \\PYZhy{}\\PYZgt{} (T, B, C)}\n", " \\PY{k}{return} \\PY{n}{outs} \\PY{k}{if} \\PY{o+ow}{not} \\PY{n}{return\\PYZus{}state} \\PY{k}{else} \\PY{p}{(}\\PY{n}{outs}\\PY{p}{,} \\PY{n}{state}\\PY{p}{)}\n", "\n", "\n", "\\PY{n}{LanguageModel} \\PY{o}{=} \\PY{k}{lambda} \\PY{n}{cell}\\PY{p}{:} \\PY{n}{partial}\\PY{p}{(}\\PY{n}{RNNLanguageModel}\\PY{p}{,} \\PY{n}{cell}\\PY{p}{)}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import torch\n", "import torch.nn as nn\n", "from typing import Type\n", "from functools import partial\n", "\n", "\n", "class RNNLanguageModel(nn.Module):\n", " def __init__(self, \n", " cell: Type[RNNBase],\n", " inputs_dim: int,\n", " hidden_dim: int,\n", " vocab_size: int,\n", " **kwargs\n", " ):\n", " super().__init__()\n", " self.cell = cell(inputs_dim, hidden_dim, **kwargs)\n", " self.linear = nn.Linear(hidden_dim, vocab_size)\n", "\n", " def forward(self, x, state=None, return_state=False):\n", " outs, state = self.cell(x, state)\n", " outs = self.linear(outs) # (T, B, H) -> (T, B, C)\n", " return outs if not return_state else (outs, state)\n", "\n", "\n", "LanguageModel = lambda cell: partial(RNNLanguageModel, cell)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import torch\n", "import torch.nn as nn\n", "from typing import Type\n", "from functools import partial\n", "\n", "\n", "class RNNLanguageModel(nn.Module):\n", " def __init__(self, \n", " cell: Type[RNNBase],\n", " inputs_dim: int,\n", " hidden_dim: int,\n", " vocab_size: int,\n", " **kwargs\n", " ):\n", " super().__init__()\n", " self.cell = cell(inputs_dim, hidden_dim, **kwargs)\n", " self.linear = nn.Linear(hidden_dim, vocab_size)\n", "\n", " def forward(self, x, state=None, return_state=False):\n", " outs, state = self.cell(x, state)\n", " outs = self.linear(outs) # (T, B, H) -> (T, B, C)\n", " return outs if not return_state else (outs, state)\n", "\n", "\n", "LanguageModel = lambda cell: 
partial(RNNLanguageModel, cell)" ] }, { "cell_type": "markdown", "id": "ea93725f", "metadata": { "papermill": { "duration": 0.001346, "end_time": "2025-01-13T14:53:14.836845", "exception": false, "start_time": "2025-01-13T14:53:14.835499", "status": "completed" }, "tags": [] }, "source": [ "
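\n",
"Before moving on, here is a rough shape check. It is an illustration only: it substitutes PyTorch's built-in `nn.RNN` (time-first) for the chapter's `RNNBase` cell, which is defined elsewhere, and the dimensions are made up:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"cell = nn.RNN(input_size=28, hidden_size=64)  # stand-in for an RNNBase cell\n",
"linear = nn.Linear(64, 28)                    # logits layer\n",
"\n",
"x = torch.randn(10, 2, 28)   # (T, B, V): 10 steps, batch of 2\n",
"outs, state = cell(x)        # outs: (T, B, H), the state at every step\n",
"logits = linear(outs)        # (T, B, C): next-character logits at each step\n",
"print(logits.shape)          # torch.Size([10, 2, 28])\n",
"```\n",
"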
\n", "\n", "## Character sequences dataset" ] }, { "cell_type": "markdown", "id": "e349b7b9", "metadata": { "papermill": { "duration": 0.001302, "end_time": "2025-01-13T14:53:14.839455", "exception": false, "start_time": "2025-01-13T14:53:14.838153", "status": "completed" }, "tags": [] }, "source": [ "Our dataset consists of $T$ input-output pairs of characters **shifted** one time step:" ] }, { "cell_type": "code", "execution_count": 3, "id": "d8ec7cbf", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.843426Z", "iopub.status.busy": "2025-01-13T14:53:14.843222Z", "iopub.status.idle": "2025-01-13T14:53:14.848792Z", "shell.execute_reply": "2025-01-13T14:53:14.848557Z" }, "papermill": { "duration": 0.008143, "end_time": "2025-01-13T14:53:14.849577", "exception": false, "start_time": "2025-01-13T14:53:14.841434", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import torch.nn.functional as F\n",
       "from torch.utils.data import Dataset, DataLoader\n",
       "\n",
       "class SequenceDataset(Dataset):\n",
       "    def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):\n",
       "        super().__init__()\n",
       "        self.data = data\n",
       "        self.seq_len = seq_len\n",
       "        self.vocab_size = vocab_size\n",
       "\n",
       "    def __getitem__(self, i):\n",
       "        c = self.data[i: i + self.seq_len + 1]\n",
       "        x, y = c[:-1], c[1:]\n",
       "        x = F.one_hot(x, num_classes=self.vocab_size).float()\n",
       "        return x, y\n",
       "    \n",
       "    def __len__(self):\n",
       "        return len(self.data) - self.seq_len\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional} \\PY{k}{as} \\PY{n+nn}{F}\n", "\\PY{k+kn}{from} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{utils}\\PY{n+nn}{.}\\PY{n+nn}{data} \\PY{k+kn}{import} \\PY{n}{Dataset}\\PY{p}{,} \\PY{n}{DataLoader}\n", "\n", "\\PY{k}{class} \\PY{n+nc}{SequenceDataset}\\PY{p}{(}\\PY{n}{Dataset}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{data}\\PY{p}{:} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{Tensor}\\PY{p}{,} \\PY{n}{seq\\PYZus{}len}\\PY{p}{:} \\PY{n+nb}{int}\\PY{p}{,} \\PY{n}{vocab\\PYZus{}size}\\PY{p}{:} \\PY{n+nb}{int}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb}{super}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{data} \\PY{o}{=} \\PY{n}{data}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{seq\\PYZus{}len} \\PY{o}{=} \\PY{n}{seq\\PYZus{}len}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab\\PYZus{}size} \\PY{o}{=} \\PY{n}{vocab\\PYZus{}size}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}getitem\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{i}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{c} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{data}\\PY{p}{[}\\PY{n}{i}\\PY{p}{:} \\PY{n}{i} \\PY{o}{+} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{seq\\PYZus{}len} \\PY{o}{+} \\PY{l+m+mi}{1}\\PY{p}{]}\n", " \\PY{n}{x}\\PY{p}{,} \\PY{n}{y} \\PY{o}{=} \\PY{n}{c}\\PY{p}{[}\\PY{p}{:}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{p}{,} \\PY{n}{c}\\PY{p}{[}\\PY{l+m+mi}{1}\\PY{p}{:}\\PY{p}{]}\n", " \\PY{n}{x} \\PY{o}{=} \\PY{n}{F}\\PY{o}{.}\\PY{n}{one\\PYZus{}hot}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{num\\PYZus{}classes}\\PY{o}{=}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab\\PYZus{}size}\\PY{p}{)}\\PY{o}{.}\\PY{n}{float}\\PY{p}{(}\\PY{p}{)}\n", 
" \\PY{k}{return} \\PY{n}{x}\\PY{p}{,} \\PY{n}{y}\n", " \n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}len\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{data}\\PY{p}{)} \\PY{o}{\\PYZhy{}} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{seq\\PYZus{}len}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import torch.nn.functional as F\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "class SequenceDataset(Dataset):\n", " def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):\n", " super().__init__()\n", " self.data = data\n", " self.seq_len = seq_len\n", " self.vocab_size = vocab_size\n", "\n", " def __getitem__(self, i):\n", " c = self.data[i: i + self.seq_len + 1]\n", " x, y = c[:-1], c[1:]\n", " x = F.one_hot(x, num_classes=self.vocab_size).float()\n", " return x, y\n", " \n", " def __len__(self):\n", " return len(self.data) - self.seq_len" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import torch.nn.functional as F\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "class SequenceDataset(Dataset):\n", " def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):\n", " super().__init__()\n", " self.data = data\n", " self.seq_len = seq_len\n", " self.vocab_size = vocab_size\n", "\n", " def __getitem__(self, i):\n", " c = self.data[i: i + self.seq_len + 1]\n", " x, y = c[:-1], c[1:]\n", " x = F.one_hot(x, num_classes=self.vocab_size).float()\n", " return x, y\n", " \n", " def __len__(self):\n", " return len(self.data) - self.seq_len" ] }, { "cell_type": "markdown", "id": "e6937a18", "metadata": { "papermill": { "duration": 0.001475, "end_time": "2025-01-13T14:53:14.852656", "exception": false, "start_time": "2025-01-13T14:53:14.851181", "status": "completed" }, "tags": [] }, "source": [ "Training on the *Time Machine* text:" ] }, { "cell_type": "code", 
"execution_count": 4, "id": "39b73f77", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.856414Z", "iopub.status.busy": "2025-01-13T14:53:14.856273Z", "iopub.status.idle": "2025-01-13T14:53:14.896689Z", "shell.execute_reply": "2025-01-13T14:53:14.896406Z" }, "papermill": { "duration": 0.04345, "end_time": "2025-01-13T14:53:14.897634", "exception": false, "start_time": "2025-01-13T14:53:14.854184", "status": "completed" }, "tags": [ "hide-output", "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import re\n",
       "import os\n",
       "import torch\n",
       "import requests\n",
       "from collections import Counter\n",
       "from typing import Union, Optional, TypeVar, List\n",
       "\n",
       "from pathlib import Path\n",
       "\n",
       "DATA_DIR = Path("./data")\n",
       "DATA_DIR.mkdir(exist_ok=True)\n",
       "\n",
       "\n",
       "T = TypeVar("T")\n",
       "ScalarOrList = Union[T, List[T]]\n",
       "\n",
       "\n",
       "class Vocab:\n",
       "    def __init__(self, \n",
       "        text: str, \n",
       "        min_freq: int = 0, \n",
       "        reserved_tokens: Optional[List[str]] = None,\n",
       "        preprocess: bool = True\n",
       "    ):\n",
       "        text = self.preprocess(text) if preprocess else text\n",
       "        tokens = list(text)\n",
       "        counter = Counter(tokens)\n",
       "        reserved_tokens = reserved_tokens or []\n",
       "        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)\n",
       "        self.itos = [self.unk_token] + reserved_tokens + [tok for tok, f in filter(lambda tokf: tokf[1] >= min_freq, self.token_freqs)]\n",
       "        self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}\n",
       "\n",
       "    def __len__(self):\n",
       "        return len(self.itos)\n",
       "    \n",
       "    def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:\n",
       "        if isinstance(tokens, str):\n",
       "            return self.stoi.get(tokens, self.unk)\n",
       "        else:\n",
       "            return [self.__getitem__(tok) for tok in tokens]\n",
       "\n",
       "    def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:\n",
       "        if isinstance(indices, int):\n",
       "            return self.itos[indices]\n",
       "        else:\n",
       "            return [self.itos[int(index)] for index in indices]\n",
       "            \n",
       "    def preprocess(self, text: str):\n",
       "        return re.sub("[^A-Za-z]+", " ", text).lower().strip()\n",
       "\n",
       "    @property\n",
       "    def unk_token(self) -> str:\n",
       "        return "▮"\n",
       "\n",
       "    @property\n",
       "    def unk(self) -> int:\n",
       "        return self.stoi[self.unk_token]\n",
       "\n",
       "    @property\n",
       "    def tokens(self) -> List[int]:\n",
       "        return self.itos\n",
       "\n",
       "\n",
       "class Tokenizer:\n",
       "    def __init__(self, vocab: Vocab):\n",
       "        self.vocab = vocab\n",
       "\n",
       "    def tokenize(self, text: str) -> List[str]:\n",
       "        UNK = self.vocab.unk_token\n",
       "        tokens = self.vocab.stoi.keys()\n",
       "        return [c if c in tokens else UNK for c in list(text)]\n",
       "\n",
       "    def encode(self, text: str) -> torch.Tensor:\n",
       "        x = self.vocab[self.tokenize(text)]\n",
       "        return torch.tensor(x, dtype=torch.int64)\n",
       "\n",
       "    def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:\n",
       "        return "".join(self.vocab.to_tokens(indices))\n",
       "\n",
       "    @property\n",
       "    def vocab_size(self) -> int:\n",
       "        return len(self.vocab)\n",
       "\n",
       "\n",
       "class TimeMachine:\n",
       "    def __init__(self, download=False, path=None):\n",
       "        DEFAULT_PATH = str((DATA_DIR / "time_machine.txt").absolute())\n",
       "        self.filepath = path or DEFAULT_PATH\n",
       "        if download or not os.path.exists(self.filepath):\n",
       "            self._download()\n",
       "        \n",
       "    def _download(self):\n",
       "        url = "https://www.gutenberg.org/cache/epub/35/pg35.txt"\n",
       "        print(f"Downloading text from {url} ...", end=" ")\n",
       "        response = requests.get(url, stream=True)\n",
       "        response.raise_for_status()\n",
       "        print("OK!")\n",
       "        with open(self.filepath, "wb") as output:\n",
       "            output.write(response.content)\n",
       "        \n",
       "    def _load_text(self):\n",
       "        with open(self.filepath, "r") as f:\n",
       "            text = f.read()\n",
       "        s = "*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"\n",
       "        e = "*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"\n",
       "        return text[text.find(s) + len(s): text.find(e)]\n",
       "    \n",
       "    def build(self, vocab: Optional[Vocab] = None):\n",
       "        self.text = self._load_text()\n",
       "        vocab = vocab or Vocab(self.text)\n",
       "        tokenizer = Tokenizer(vocab)\n",
       "        encoded_text = tokenizer.encode(vocab.preprocess(self.text))\n",
       "        return encoded_text, tokenizer\n",
       "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{k+kn}{import} \\PY{n+nn}{re}\n", "\\PY{k+kn}{import} \\PY{n+nn}{os}\n", "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", "\\PY{k+kn}{import} \\PY{n+nn}{requests}\n", "\\PY{k+kn}{from} \\PY{n+nn}{collections} \\PY{k+kn}{import} \\PY{n}{Counter}\n", "\\PY{k+kn}{from} \\PY{n+nn}{typing} \\PY{k+kn}{import} \\PY{n}{Union}\\PY{p}{,} \\PY{n}{Optional}\\PY{p}{,} \\PY{n}{TypeVar}\\PY{p}{,} \\PY{n}{List}\n", "\n", "\\PY{k+kn}{from} \\PY{n+nn}{pathlib} \\PY{k+kn}{import} \\PY{n}{Path}\n", "\n", "\\PY{n}{DATA\\PYZus{}DIR} \\PY{o}{=} \\PY{n}{Path}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{./data}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{DATA\\PYZus{}DIR}\\PY{o}{.}\\PY{n}{mkdir}\\PY{p}{(}\\PY{n}{exist\\PYZus{}ok}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{)}\n", "\n", "\n", "\\PY{n}{T} \\PY{o}{=} \\PY{n}{TypeVar}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{T}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", "\\PY{n}{ScalarOrList} \\PY{o}{=} \\PY{n}{Union}\\PY{p}{[}\\PY{n}{T}\\PY{p}{,} \\PY{n}{List}\\PY{p}{[}\\PY{n}{T}\\PY{p}{]}\\PY{p}{]}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Vocab}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \n", " \\PY{n}{text}\\PY{p}{:} \\PY{n+nb}{str}\\PY{p}{,} \n", " \\PY{n}{min\\PYZus{}freq}\\PY{p}{:} \\PY{n+nb}{int} \\PY{o}{=} \\PY{l+m+mi}{0}\\PY{p}{,} \n", " \\PY{n}{reserved\\PYZus{}tokens}\\PY{p}{:} \\PY{n}{Optional}\\PY{p}{[}\\PY{n}{List}\\PY{p}{[}\\PY{n+nb}{str}\\PY{p}{]}\\PY{p}{]} \\PY{o}{=} \\PY{k+kc}{None}\\PY{p}{,}\n", " \\PY{n}{preprocess}\\PY{p}{:} \\PY{n+nb}{bool} \\PY{o}{=} \\PY{k+kc}{True}\n", " \\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{text} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{preprocess}\\PY{p}{(}\\PY{n}{text}\\PY{p}{)} \\PY{k}{if} \\PY{n}{preprocess} \\PY{k}{else} \\PY{n}{text}\n", " \\PY{n}{tokens} \\PY{o}{=} \\PY{n+nb}{list}\\PY{p}{(}\\PY{n}{text}\\PY{p}{)}\n", " \\PY{n}{counter} \\PY{o}{=} 
\\PY{n}{Counter}\\PY{p}{(}\\PY{n}{tokens}\\PY{p}{)}\n", " \\PY{n}{reserved\\PYZus{}tokens} \\PY{o}{=} \\PY{n}{reserved\\PYZus{}tokens} \\PY{o+ow}{or} \\PY{p}{[}\\PY{p}{]}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{token\\PYZus{}freqs} \\PY{o}{=} \\PY{n+nb}{sorted}\\PY{p}{(}\\PY{n}{counter}\\PY{o}{.}\\PY{n}{items}\\PY{p}{(}\\PY{p}{)}\\PY{p}{,} \\PY{n}{key}\\PY{o}{=}\\PY{k}{lambda} \\PY{n}{x}\\PY{p}{:} \\PY{n}{x}\\PY{p}{[}\\PY{l+m+mi}{1}\\PY{p}{]}\\PY{p}{,} \\PY{n}{reverse}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos} \\PY{o}{=} \\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{unk\\PYZus{}token}\\PY{p}{]} \\PY{o}{+} \\PY{n}{reserved\\PYZus{}tokens} \\PY{o}{+} \\PY{p}{[}\\PY{n}{tok} \\PY{k}{for} \\PY{n}{tok}\\PY{p}{,} \\PY{n}{f} \\PY{o+ow}{in} \\PY{n+nb}{filter}\\PY{p}{(}\\PY{k}{lambda} \\PY{n}{tokf}\\PY{p}{:} \\PY{n}{tokf}\\PY{p}{[}\\PY{l+m+mi}{1}\\PY{p}{]} \\PY{o}{\\PYZgt{}}\\PY{o}{=} \\PY{n}{min\\PYZus{}freq}\\PY{p}{,} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{token\\PYZus{}freqs}\\PY{p}{)}\\PY{p}{]}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{stoi} \\PY{o}{=} \\PY{p}{\\PYZob{}}\\PY{n}{tok}\\PY{p}{:} \\PY{n}{idx} \\PY{k}{for} \\PY{n}{idx}\\PY{p}{,} \\PY{n}{tok} \\PY{o+ow}{in} \\PY{n+nb}{enumerate}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\n", "\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}len\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos}\\PY{p}{)}\n", " \n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}getitem\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{tokens}\\PY{p}{:} \\PY{n}{ScalarOrList}\\PY{p}{[}\\PY{n+nb}{str}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n}{ScalarOrList}\\PY{p}{[}\\PY{n+nb}{int}\\PY{p}{]}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n+nb}{isinstance}\\PY{p}{(}\\PY{n}{tokens}\\PY{p}{,} \\PY{n+nb}{str}\\PY{p}{)}\\PY{p}{:}\n", " 
\\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{stoi}\\PY{o}{.}\\PY{n}{get}\\PY{p}{(}\\PY{n}{tokens}\\PY{p}{,} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{unk}\\PY{p}{)}\n", " \\PY{k}{else}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n+nf+fm}{\\PYZus{}\\PYZus{}getitem\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n}{tok}\\PY{p}{)} \\PY{k}{for} \\PY{n}{tok} \\PY{o+ow}{in} \\PY{n}{tokens}\\PY{p}{]}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{to\\PYZus{}tokens}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{indices}\\PY{p}{:} \\PY{n}{ScalarOrList}\\PY{p}{[}\\PY{n+nb}{int}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n}{ScalarOrList}\\PY{p}{[}\\PY{n+nb}{str}\\PY{p}{]}\\PY{p}{:}\n", " \\PY{k}{if} \\PY{n+nb}{isinstance}\\PY{p}{(}\\PY{n}{indices}\\PY{p}{,} \\PY{n+nb}{int}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos}\\PY{p}{[}\\PY{n}{indices}\\PY{p}{]}\n", " \\PY{k}{else}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos}\\PY{p}{[}\\PY{n+nb}{int}\\PY{p}{(}\\PY{n}{index}\\PY{p}{)}\\PY{p}{]} \\PY{k}{for} \\PY{n}{index} \\PY{o+ow}{in} \\PY{n}{indices}\\PY{p}{]}\n", " \n", " \\PY{k}{def} \\PY{n+nf}{preprocess}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{text}\\PY{p}{:} \\PY{n+nb}{str}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n}{re}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{[\\PYZca{}A\\PYZhy{}Za\\PYZhy{}z]+}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{ }\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{text}\\PY{p}{)}\\PY{o}{.}\\PY{n}{lower}\\PY{p}{(}\\PY{p}{)}\\PY{o}{.}\\PY{n}{strip}\\PY{p}{(}\\PY{p}{)}\n", "\n", " \\PY{n+nd}{@property}\n", " \\PY{k}{def} \\PY{n+nf}{unk\\PYZus{}token}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{str}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{▮}\\PY{l+s+s2}{\\PYZdq{}}\n", "\n", " \\PY{n+nd}{@property}\n", " \\PY{k}{def} 
\\PY{n+nf}{unk}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{int}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{stoi}\\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{unk\\PYZus{}token}\\PY{p}{]}\n", "\n", " \\PY{n+nd}{@property}\n", " \\PY{k}{def} \\PY{n+nf}{tokens}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n}{List}\\PY{p}{[}\\PY{n+nb}{int}\\PY{p}{]}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{itos}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{Tokenizer}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{vocab}\\PY{p}{:} \\PY{n}{Vocab}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab} \\PY{o}{=} \\PY{n}{vocab}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{tokenize}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{text}\\PY{p}{:} \\PY{n+nb}{str}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n}{List}\\PY{p}{[}\\PY{n+nb}{str}\\PY{p}{]}\\PY{p}{:}\n", " \\PY{n}{UNK} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab}\\PY{o}{.}\\PY{n}{unk\\PYZus{}token}\n", " \\PY{n}{tokens} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab}\\PY{o}{.}\\PY{n}{stoi}\\PY{o}{.}\\PY{n}{keys}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{p}{[}\\PY{n}{c} \\PY{k}{if} \\PY{n}{c} \\PY{o+ow}{in} \\PY{n}{tokens} \\PY{k}{else} \\PY{n}{UNK} \\PY{k}{for} \\PY{n}{c} \\PY{o+ow}{in} \\PY{n+nb}{list}\\PY{p}{(}\\PY{n}{text}\\PY{p}{)}\\PY{p}{]}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{encode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{text}\\PY{p}{:} \\PY{n+nb}{str}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{Tensor}\\PY{p}{:}\n", " \\PY{n}{x} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab}\\PY{p}{[}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{tokenize}\\PY{p}{(}\\PY{n}{text}\\PY{p}{)}\\PY{p}{]}\n", " \\PY{k}{return} 
\\PY{n}{torch}\\PY{o}{.}\\PY{n}{tensor}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{int64}\\PY{p}{)}\n", "\n", " \\PY{k}{def} \\PY{n+nf}{decode}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{indices}\\PY{p}{:} \\PY{n}{Union}\\PY{p}{[}\\PY{n}{ScalarOrList}\\PY{p}{[}\\PY{n+nb}{int}\\PY{p}{]}\\PY{p}{,} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{Tensor}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{str}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{o}{.}\\PY{n}{join}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab}\\PY{o}{.}\\PY{n}{to\\PYZus{}tokens}\\PY{p}{(}\\PY{n}{indices}\\PY{p}{)}\\PY{p}{)}\n", "\n", " \\PY{n+nd}{@property}\n", " \\PY{k}{def} \\PY{n+nf}{vocab\\PYZus{}size}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{n+nb}{int}\\PY{p}{:}\n", " \\PY{k}{return} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{vocab}\\PY{p}{)}\n", "\n", "\n", "\\PY{k}{class} \\PY{n+nc}{TimeMachine}\\PY{p}{:}\n", " \\PY{k}{def} \\PY{n+nf+fm}{\\PYZus{}\\PYZus{}init\\PYZus{}\\PYZus{}}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{download}\\PY{o}{=}\\PY{k+kc}{False}\\PY{p}{,} \\PY{n}{path}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{DEFAULT\\PYZus{}PATH} \\PY{o}{=} \\PY{n+nb}{str}\\PY{p}{(}\\PY{p}{(}\\PY{n}{DATA\\PYZus{}DIR} \\PY{o}{/} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{time\\PYZus{}machine.txt}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{o}{.}\\PY{n}{absolute}\\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{filepath} \\PY{o}{=} \\PY{n}{path} \\PY{o+ow}{or} \\PY{n}{DEFAULT\\PYZus{}PATH}\n", " \\PY{k}{if} \\PY{n}{download} \\PY{o+ow}{or} \\PY{o+ow}{not} \\PY{n}{os}\\PY{o}{.}\\PY{n}{path}\\PY{o}{.}\\PY{n}{exists}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{filepath}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}download}\\PY{p}{(}\\PY{p}{)}\n", " \n", " \\PY{k}{def} 
\\PY{n+nf}{\\PYZus{}download}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{url} \\PY{o}{=} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{https://www.gutenberg.org/cache/epub/35/pg35.txt}\\PY{l+s+s2}{\\PYZdq{}}\n", " \\PY{n+nb}{print}\\PY{p}{(}\\PY{l+s+sa}{f}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{Downloading text from }\\PY{l+s+si}{\\PYZob{}}\\PY{n}{url}\\PY{l+s+si}{\\PYZcb{}}\\PY{l+s+s2}{ ...}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{n}{end}\\PY{o}{=}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{ }\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", " \\PY{n}{response} \\PY{o}{=} \\PY{n}{requests}\\PY{o}{.}\\PY{n}{get}\\PY{p}{(}\\PY{n}{url}\\PY{p}{,} \\PY{n}{stream}\\PY{o}{=}\\PY{k+kc}{True}\\PY{p}{)}\n", " \\PY{n}{response}\\PY{o}{.}\\PY{n}{raise\\PYZus{}for\\PYZus{}status}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n+nb}{print}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{OK!}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\n", " \\PY{k}{with} \\PY{n+nb}{open}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{filepath}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{wb}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{k}{as} \\PY{n}{output}\\PY{p}{:}\n", " \\PY{n}{output}\\PY{o}{.}\\PY{n}{write}\\PY{p}{(}\\PY{n}{response}\\PY{o}{.}\\PY{n}{content}\\PY{p}{)}\n", " \n", " \\PY{k}{def} \\PY{n+nf}{\\PYZus{}load\\PYZus{}text}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n+nb}{open}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{filepath}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{r}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{k}{as} \\PY{n}{f}\\PY{p}{:}\n", " \\PY{n}{text} \\PY{o}{=} \\PY{n}{f}\\PY{o}{.}\\PY{n}{read}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{s} \\PY{o}{=} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***}\\PY{l+s+s2}{\\PYZdq{}}\n", " \\PY{n}{e} \\PY{o}{=} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***}\\PY{l+s+s2}{\\PYZdq{}}\n", " \\PY{k}{return} 
\\PY{n}{text}\\PY{p}{[}\\PY{n}{text}\\PY{o}{.}\\PY{n}{find}\\PY{p}{(}\\PY{n}{s}\\PY{p}{)} \\PY{o}{+} \\PY{n+nb}{len}\\PY{p}{(}\\PY{n}{s}\\PY{p}{)}\\PY{p}{:} \\PY{n}{text}\\PY{o}{.}\\PY{n}{find}\\PY{p}{(}\\PY{n}{e}\\PY{p}{)}\\PY{p}{]}\n", " \n", " \\PY{k}{def} \\PY{n+nf}{build}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{p}{,} \\PY{n}{vocab}\\PY{p}{:} \\PY{n}{Optional}\\PY{p}{[}\\PY{n}{Vocab}\\PY{p}{]} \\PY{o}{=} \\PY{k+kc}{None}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{text} \\PY{o}{=} \\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{\\PYZus{}load\\PYZus{}text}\\PY{p}{(}\\PY{p}{)}\n", " \\PY{n}{vocab} \\PY{o}{=} \\PY{n}{vocab} \\PY{o+ow}{or} \\PY{n}{Vocab}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{text}\\PY{p}{)}\n", " \\PY{n}{tokenizer} \\PY{o}{=} \\PY{n}{Tokenizer}\\PY{p}{(}\\PY{n}{vocab}\\PY{p}{)}\n", " \\PY{n}{encoded\\PYZus{}text} \\PY{o}{=} \\PY{n}{tokenizer}\\PY{o}{.}\\PY{n}{encode}\\PY{p}{(}\\PY{n}{vocab}\\PY{o}{.}\\PY{n}{preprocess}\\PY{p}{(}\\PY{n+nb+bp}{self}\\PY{o}{.}\\PY{n}{text}\\PY{p}{)}\\PY{p}{)}\n", " \\PY{k}{return} \\PY{n}{encoded\\PYZus{}text}\\PY{p}{,} \\PY{n}{tokenizer}\n", "\\end{Verbatim}\n" ], "text/plain": [ "import re\n", "import os\n", "import torch\n", "import requests\n", "from collections import Counter\n", "from typing import Union, Optional, TypeVar, List\n", "\n", "from pathlib import Path\n", "\n", "DATA_DIR = Path(\"./data\")\n", "DATA_DIR.mkdir(exist_ok=True)\n", "\n", "\n", "T = TypeVar(\"T\")\n", "ScalarOrList = Union[T, List[T]]\n", "\n", "\n", "class Vocab:\n", " def __init__(self, \n", " text: str, \n", " min_freq: int = 0, \n", " reserved_tokens: Optional[List[str]] = None,\n", " preprocess: bool = True\n", " ):\n", " text = self.preprocess(text) if preprocess else text\n", " tokens = list(text)\n", " counter = Counter(tokens)\n", " reserved_tokens = reserved_tokens or []\n", " self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)\n", " self.itos = [self.unk_token] + reserved_tokens + [tok for 
tok, f in filter(lambda tokf: tokf[1] >= min_freq, self.token_freqs)]\n", "        self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}\n", "\n", "    def __len__(self):\n", "        return len(self.itos)\n", "    \n", "    def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:\n", "        if isinstance(tokens, str):\n", "            return self.stoi.get(tokens, self.unk)\n", "        else:\n", "            return [self.__getitem__(tok) for tok in tokens]\n", "\n", "    def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:\n", "        if isinstance(indices, int):\n", "            return self.itos[indices]\n", "        else:\n", "            return [self.itos[int(index)] for index in indices]\n", "    \n", "    def preprocess(self, text: str):\n", "        return re.sub(\"[^A-Za-z]+\", \" \", text).lower().strip()\n", "\n", "    @property\n", "    def unk_token(self) -> str:\n", "        return \"▮\"\n", "\n", "    @property\n", "    def unk(self) -> int:\n", "        return self.stoi[self.unk_token]\n", "\n", "    @property\n", "    def tokens(self) -> List[str]:\n", "        return self.itos\n", "\n", "\n", "class Tokenizer:\n", "    def __init__(self, vocab: Vocab):\n", "        self.vocab = vocab\n", "\n", "    def tokenize(self, text: str) -> List[str]:\n", "        UNK = self.vocab.unk_token\n", "        tokens = self.vocab.stoi.keys()\n", "        return [c if c in tokens else UNK for c in list(text)]\n", "\n", "    def encode(self, text: str) -> torch.Tensor:\n", "        x = self.vocab[self.tokenize(text)]\n", "        return torch.tensor(x, dtype=torch.int64)\n", "\n", "    def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:\n", "        return \"\".join(self.vocab.to_tokens(indices))\n", "\n", "    @property\n", "    def vocab_size(self) -> int:\n", "        return len(self.vocab)\n", "\n", "\n", "class TimeMachine:\n", "    def __init__(self, download=False, path=None):\n", "        DEFAULT_PATH = str((DATA_DIR / \"time_machine.txt\").absolute())\n", "        self.filepath = path or DEFAULT_PATH\n", "        if download or not os.path.exists(self.filepath):\n", "            self._download()\n", "    \n", "    def _download(self):\n", "        url = 
\"https://www.gutenberg.org/cache/epub/35/pg35.txt\"\n", " print(f\"Downloading text from {url} ...\", end=\" \")\n", " response = requests.get(url, stream=True)\n", " response.raise_for_status()\n", " print(\"OK!\")\n", " with open(self.filepath, \"wb\") as output:\n", " output.write(response.content)\n", " \n", " def _load_text(self):\n", " with open(self.filepath, \"r\") as f:\n", " text = f.read()\n", " s = \"*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***\"\n", " e = \"*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***\"\n", " return text[text.find(s) + len(s): text.find(e)]\n", " \n", " def build(self, vocab: Optional[Vocab] = None):\n", " self.text = self._load_text()\n", " vocab = vocab or Vocab(self.text)\n", " tokenizer = Tokenizer(vocab)\n", " encoded_text = tokenizer.encode(vocab.preprocess(self.text))\n", " return encoded_text, tokenizer" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%save\n", "import re\n", "import os\n", "import torch\n", "import requests\n", "from collections import Counter\n", "from typing import Union, Optional, TypeVar, List\n", "\n", "from pathlib import Path\n", "\n", "DATA_DIR = Path(\"./data\")\n", "DATA_DIR.mkdir(exist_ok=True)\n", "\n", "\n", "T = TypeVar(\"T\")\n", "ScalarOrList = Union[T, List[T]]\n", "\n", "\n", "class Vocab:\n", " def __init__(self, \n", " text: str, \n", " min_freq: int = 0, \n", " reserved_tokens: Optional[List[str]] = None,\n", " preprocess: bool = True\n", " ):\n", " text = self.preprocess(text) if preprocess else text\n", " tokens = list(text)\n", " counter = Counter(tokens)\n", " reserved_tokens = reserved_tokens or []\n", " self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)\n", " self.itos = [self.unk_token] + reserved_tokens + [tok for tok, f in filter(lambda tokf: tokf[1] >= min_freq, self.token_freqs)]\n", " self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}\n", "\n", " def 
__len__(self):\n", "        return len(self.itos)\n", "    \n", "    def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:\n", "        if isinstance(tokens, str):\n", "            return self.stoi.get(tokens, self.unk)\n", "        else:\n", "            return [self.__getitem__(tok) for tok in tokens]\n", "\n", "    def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:\n", "        if isinstance(indices, int):\n", "            return self.itos[indices]\n", "        else:\n", "            return [self.itos[int(index)] for index in indices]\n", "    \n", "    def preprocess(self, text: str):\n", "        return re.sub(\"[^A-Za-z]+\", \" \", text).lower().strip()\n", "\n", "    @property\n", "    def unk_token(self) -> str:\n", "        return \"▮\"\n", "\n", "    @property\n", "    def unk(self) -> int:\n", "        return self.stoi[self.unk_token]\n", "\n", "    @property\n", "    def tokens(self) -> List[str]:\n", "        return self.itos\n", "\n", "\n", "class Tokenizer:\n", "    def __init__(self, vocab: Vocab):\n", "        self.vocab = vocab\n", "\n", "    def tokenize(self, text: str) -> List[str]:\n", "        UNK = self.vocab.unk_token\n", "        tokens = self.vocab.stoi.keys()\n", "        return [c if c in tokens else UNK for c in list(text)]\n", "\n", "    def encode(self, text: str) -> torch.Tensor:\n", "        x = self.vocab[self.tokenize(text)]\n", "        return torch.tensor(x, dtype=torch.int64)\n", "\n", "    def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:\n", "        return \"\".join(self.vocab.to_tokens(indices))\n", "\n", "    @property\n", "    def vocab_size(self) -> int:\n", "        return len(self.vocab)\n", "\n", "\n", "class TimeMachine:\n", "    def __init__(self, download=False, path=None):\n", "        DEFAULT_PATH = str((DATA_DIR / \"time_machine.txt\").absolute())\n", "        self.filepath = path or DEFAULT_PATH\n", "        if download or not os.path.exists(self.filepath):\n", "            self._download()\n", "    \n", "    def _download(self):\n", "        url = \"https://www.gutenberg.org/cache/epub/35/pg35.txt\"\n", "        print(f\"Downloading text from {url} ...\", end=\" \")\n", "        response = requests.get(url, 
stream=True)\n", " response.raise_for_status()\n", " print(\"OK!\")\n", " with open(self.filepath, \"wb\") as output:\n", " output.write(response.content)\n", " \n", " def _load_text(self):\n", " with open(self.filepath, \"r\") as f:\n", " text = f.read()\n", " s = \"*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***\"\n", " e = \"*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***\"\n", " return text[text.find(s) + len(s): text.find(e)]\n", " \n", " def build(self, vocab: Optional[Vocab] = None):\n", " self.text = self._load_text()\n", " vocab = vocab or Vocab(self.text)\n", " tokenizer = Tokenizer(vocab)\n", " encoded_text = tokenizer.encode(vocab.preprocess(self.text))\n", " return encoded_text, tokenizer" ] }, { "cell_type": "code", "execution_count": 5, "id": "c7a0c089", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.903039Z", "iopub.status.busy": "2025-01-13T14:53:14.902310Z", "iopub.status.idle": "2025-01-13T14:53:14.990142Z", "shell.execute_reply": "2025-01-13T14:53:14.989424Z" }, "papermill": { "duration": 0.092184, "end_time": "2025-01-13T14:53:14.991921", "exception": false, "start_time": "2025-01-13T14:53:14.899737", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from torch.utils.data import random_split\n", "\n", "def collate_fn(batch):\n", " \"\"\"Transforming the data to sequence-first format.\"\"\"\n", " x, y = zip(*batch)\n", " x = torch.stack(x, 1) # (T, B, vocab_size)\n", " y = torch.stack(y, 1) # (T, B)\n", " return x, y\n", "\n", "\n", "data, tokenizer = TimeMachine().build()\n", "VOCAB_SIZE = tokenizer.vocab_size\n", "dataset = SequenceDataset(data, seq_len=10, vocab_size=VOCAB_SIZE)\n", "train_dataset, valid_dataset = random_split(dataset, [0.80, 0.20])\n", "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)\n", "valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)" ] }, { "cell_type": "markdown", 
"id": "8c3a0756", "metadata": { "papermill": { "duration": 0.002064, "end_time": "2025-01-13T14:53:14.996357", "exception": false, "start_time": "2025-01-13T14:53:14.994293", "status": "completed" }, "tags": [] }, "source": [ "The subsequences in a batch start at shuffled positions in the text, but the character ordering within each subsequence is intact:" ] }, { "cell_type": "code", "execution_count": 6, "id": "117304f1", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:15.003321Z", "iopub.status.busy": "2025-01-13T14:53:15.003089Z", "iopub.status.idle": "2025-01-13T14:53:15.010516Z", "shell.execute_reply": "2025-01-13T14:53:15.010113Z" }, "papermill": { "duration": 0.01181, "end_time": "2025-01-13T14:53:15.011554", "exception": false, "start_time": "2025-01-13T14:53:14.999744", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y --> \n", " --> o\n", "o --> w\n", "w --> n\n", "n --> \n", " --> e\n", "e --> x\n", "x --> p\n", "p --> e\n", "e --> n\n" ] } ], "source": [ "x, y = next(iter(train_loader))\n", "\n", "a, T = 1, dataset.seq_len\n", "x_chars = tokenizer.decode(torch.argmax(x[:, a], dim=1))  # inputs are one-hot\n", "y_chars = tokenizer.decode(y[:, a])\n", "for i in range(T):\n", "    print(f\"{x_chars[i]} --> {y_chars[i]}\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "7f2d90b9", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:15.016373Z", "iopub.status.busy": "2025-01-13T14:53:15.016118Z", "iopub.status.idle": "2025-01-13T14:53:15.019348Z", "shell.execute_reply": "2025-01-13T14:53:15.018972Z" }, "papermill": { "duration": 0.006735, "end_time": "2025-01-13T14:53:15.020451", "exception": false, "start_time": "2025-01-13T14:53:15.013716", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10, 32, 28]) torch.Size([10, 32])\n", "inputs: tensor([ 4,  6, 11,  1, 17,  9,  2,  6,  1,  5])\n", "target: tensor([ 6, 11,  1, 17,  9,  2, 
6, 1, 5, 1])\n" ] } ], "source": [ "print(x.shape, y.shape)\n", "print(\"inputs:\", torch.argmax(x[:, 0], dim=-1))\n", "print(\"target:\", y[:, 0])" ] }, { "cell_type": "markdown", "id": "e3c96a8f", "metadata": { "papermill": { "duration": 0.001948, "end_time": "2025-01-13T14:53:15.024427", "exception": false, "start_time": "2025-01-13T14:53:15.022479", "status": "completed" }, "tags": [] }, "source": [ "PyTorch `F.cross_entropy` expects input `(B, C, T)` and target `(B, T)`:" ] }, { "cell_type": "code", "execution_count": 8, "id": "8f888413", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:15.029569Z", "iopub.status.busy": "2025-01-13T14:53:15.029443Z", "iopub.status.idle": "2025-01-13T14:53:15.036667Z", "shell.execute_reply": "2025-01-13T14:53:15.036336Z" }, "papermill": { "duration": 0.010749, "end_time": "2025-01-13T14:53:15.037758", "exception": false, "start_time": "2025-01-13T14:53:15.027009", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor(3.3744, grad_fn=)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch.nn.functional as F\n", "\n", "x, y = next(iter(train_loader))\n", "model = LanguageModel(RNN)(VOCAB_SIZE, 5, VOCAB_SIZE)\n", "loss = F.cross_entropy(model(x).permute(1, 2, 0), y.transpose(0, 1))\n", "loss" ] }, { "cell_type": "code", "execution_count": null, "id": "889eba7e", "metadata": { "papermill": { "duration": 0.001985, "end_time": "2025-01-13T14:53:15.041970", "exception": false, "start_time": "2025-01-13T14:53:15.039985", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "papermill": { 
"default_parameters": {}, "duration": 2.549128, "end_time": "2025-01-13T14:53:15.784564", "environment_variables": {}, "exception": null, "input_path": "05b-rnn-lm.ipynb", "output_path": "05b-rnn-lm.ipynb", "parameters": {}, "start_time": "2025-01-13T14:53:13.235436", "version": "2.6.0" } }, "nbformat": 4, "nbformat_minor": 5 }