{ "cells": [ { "cell_type": "markdown", "id": "2c4f3e6d", "metadata": { "papermill": { "duration": 0.011205, "end_time": "2025-01-13T14:53:14.094763", "exception": false, "start_time": "2025-01-13T14:53:14.083558", "status": "completed" }, "tags": [] }, "source": [ "# RNN language model" ] }, { "cell_type": "code", "execution_count": 1, "id": "120fc610", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.110871Z", "iopub.status.busy": "2025-01-13T14:53:14.110393Z", "iopub.status.idle": "2025-01-13T14:53:14.796397Z", "shell.execute_reply": "2025-01-13T14:53:14.796084Z" }, "papermill": { "duration": 0.695097, "end_time": "2025-01-13T14:53:14.797931", "exception": false, "start_time": "2025-01-13T14:53:14.102834", "status": "completed" }, "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "from chapter import *" ] }, { "cell_type": "markdown", "id": "1dd5f01c", "metadata": { "papermill": { "duration": 0.00127, "end_time": "2025-01-13T14:53:14.801008", "exception": false, "start_time": "2025-01-13T14:53:14.799738", "status": "completed" }, "tags": [] }, "source": [ "Our goal in this section is to train a character-level RNN language model to predict the next token at *each* step with varying-length context. Hence, during training, our model predicts on each time-step ({numref}`04-char-rnn`). The language model below is simply an RNN cell with an attached **logits layer** applied at each step." ] }, { "cell_type": "markdown", "id": "a05a2576", "metadata": { "papermill": { "duration": 0.001177, "end_time": "2025-01-13T14:53:14.803410", "exception": false, "start_time": "2025-01-13T14:53:14.802233", "status": "completed" }, "tags": [] }, "source": [ "\n", "```{figure} ../../../img/nn/04-char-rnn.svg\n", "---\n", "width: 550px\n", "name: 04-char-rnn\n", "align: center\n", "---\n", "Character-level RNN language model for predicting the next character at each step. 
[Source](https://www.d2l.ai/chapter_recurrent-neural-networks/rnn.html)\n", "```" ] }, { "cell_type": "markdown", "id": "02139545", "metadata": { "papermill": { "duration": 0.001159, "end_time": "2025-01-13T14:53:14.805748", "exception": false, "start_time": "2025-01-13T14:53:14.804589", "status": "completed" }, "tags": [] }, "source": [ "To implement a language model, we simply attach a linear layer on top of the RNN unit to compute logits. \n", "The linear layer performs matrix multiplication on the rightmost dimension of `outs`, which contains the value of the state vector at each time step. Thus, as shown in {numref}`04-char-rnn`, we have $T$ predictions with increasing context size[^1] $1, 2, \\ldots, T.$\n", "\n", "[^1]: Consequently, the model gets corrected at each time step, with variable-length dependency, during the backward pass. " ] }, { "cell_type": "code", "execution_count": 2, "id": "33053e9e", "metadata": { "execution": { "iopub.execute_input": "2025-01-13T14:53:14.808762Z", "iopub.status.busy": "2025-01-13T14:53:14.808610Z", "iopub.status.idle": "2025-01-13T14:53:14.833112Z", "shell.execute_reply": "2025-01-13T14:53:14.832846Z" }, "papermill": { "duration": 0.027072, "end_time": "2025-01-13T14:53:14.833972", "exception": false, "start_time": "2025-01-13T14:53:14.806900", "status": "completed" }, "tags": [ "remove-input" ] }, "outputs": [ { "data": { "text/html": [ "
import torch\n",
"import torch.nn as nn\n",
"from typing import Type\n",
"from functools import partial\n",
"\n",
"\n",
"class RNNLanguageModel(nn.Module):\n",
" def __init__(self, \n",
" cell: Type[RNNBase],\n",
" inputs_dim: int,\n",
" hidden_dim: int,\n",
" vocab_size: int,\n",
" **kwargs\n",
" ):\n",
" super().__init__()\n",
" self.cell = cell(inputs_dim, hidden_dim, **kwargs)\n",
" self.linear = nn.Linear(hidden_dim, vocab_size)\n",
"\n",
" def forward(self, x, state=None, return_state=False):\n",
" outs, state = self.cell(x, state)\n",
" outs = self.linear(outs) # (T, B, H) -> (T, B, C)\n",
" return outs if not return_state else (outs, state)\n",
"\n",
"\n",
"LanguageModel = lambda cell: partial(RNNLanguageModel, cell)\n",
"
import torch.nn.functional as F\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class SequenceDataset(Dataset):\n",
" def __init__(self, data: torch.Tensor, seq_len: int, vocab_size: int):\n",
" super().__init__()\n",
" self.data = data\n",
" self.seq_len = seq_len\n",
" self.vocab_size = vocab_size\n",
"\n",
" def __getitem__(self, i):\n",
" c = self.data[i: i + self.seq_len + 1]\n",
" x, y = c[:-1], c[1:]\n",
" x = F.one_hot(x, num_classes=self.vocab_size).float()\n",
" return x, y\n",
" \n",
" def __len__(self):\n",
" return len(self.data) - self.seq_len\n",
"
import re\n",
"import os\n",
"import torch\n",
"import requests\n",
"from collections import Counter\n",
"from typing import Union, Optional, TypeVar, List\n",
"\n",
"from pathlib import Path\n",
"\n",
"DATA_DIR = Path("./data")\n",
"DATA_DIR.mkdir(exist_ok=True)\n",
"\n",
"\n",
"T = TypeVar("T")\n",
"ScalarOrList = Union[T, List[T]]\n",
"\n",
"\n",
"class Vocab:\n",
" def __init__(self, \n",
" text: str, \n",
" min_freq: int = 0, \n",
" reserved_tokens: Optional[List[str]] = None,\n",
" preprocess: bool = True\n",
" ):\n",
" text = self.preprocess(text) if preprocess else text\n",
" tokens = list(text)\n",
" counter = Counter(tokens)\n",
" reserved_tokens = reserved_tokens or []\n",
" self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)\n",
" self.itos = [self.unk_token] + reserved_tokens + [tok for tok, f in filter(lambda tokf: tokf[1] >= min_freq, self.token_freqs)]\n",
" self.stoi = {tok: idx for idx, tok in enumerate(self.itos)}\n",
"\n",
" def __len__(self):\n",
" return len(self.itos)\n",
" \n",
" def __getitem__(self, tokens: ScalarOrList[str]) -> ScalarOrList[int]:\n",
" if isinstance(tokens, str):\n",
" return self.stoi.get(tokens, self.unk)\n",
" else:\n",
" return [self.__getitem__(tok) for tok in tokens]\n",
"\n",
" def to_tokens(self, indices: ScalarOrList[int]) -> ScalarOrList[str]:\n",
" if isinstance(indices, int):\n",
" return self.itos[indices]\n",
" else:\n",
" return [self.itos[int(index)] for index in indices]\n",
" \n",
" def preprocess(self, text: str):\n",
" return re.sub("[^A-Za-z]+", " ", text).lower().strip()\n",
"\n",
" @property\n",
" def unk_token(self) -> str:\n",
" return "▮"\n",
"\n",
" @property\n",
" def unk(self) -> int:\n",
" return self.stoi[self.unk_token]\n",
"\n",
" @property\n",
" def tokens(self) -> List[int]:\n",
" return self.itos\n",
"\n",
"\n",
"class Tokenizer:\n",
" def __init__(self, vocab: Vocab):\n",
" self.vocab = vocab\n",
"\n",
" def tokenize(self, text: str) -> List[str]:\n",
" UNK = self.vocab.unk_token\n",
" tokens = self.vocab.stoi.keys()\n",
" return [c if c in tokens else UNK for c in list(text)]\n",
"\n",
" def encode(self, text: str) -> torch.Tensor:\n",
" x = self.vocab[self.tokenize(text)]\n",
" return torch.tensor(x, dtype=torch.int64)\n",
"\n",
" def decode(self, indices: Union[ScalarOrList[int], torch.Tensor]) -> str:\n",
" return "".join(self.vocab.to_tokens(indices))\n",
"\n",
" @property\n",
" def vocab_size(self) -> int:\n",
" return len(self.vocab)\n",
"\n",
"\n",
"class TimeMachine:\n",
" def __init__(self, download=False, path=None):\n",
" DEFAULT_PATH = str((DATA_DIR / "time_machine.txt").absolute())\n",
" self.filepath = path or DEFAULT_PATH\n",
" if download or not os.path.exists(self.filepath):\n",
" self._download()\n",
" \n",
" def _download(self):\n",
" url = "https://www.gutenberg.org/cache/epub/35/pg35.txt"\n",
" print(f"Downloading text from {url} ...", end=" ")\n",
" response = requests.get(url, stream=True)\n",
" response.raise_for_status()\n",
" print("OK!")\n",
" with open(self.filepath, "wb") as output:\n",
" output.write(response.content)\n",
" \n",
" def _load_text(self):\n",
" with open(self.filepath, "r") as f:\n",
" text = f.read()\n",
" s = "*** START OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"\n",
" e = "*** END OF THE PROJECT GUTENBERG EBOOK THE TIME MACHINE ***"\n",
" return text[text.find(s) + len(s): text.find(e)]\n",
" \n",
" def build(self, vocab: Optional[Vocab] = None):\n",
" self.text = self._load_text()\n",
" vocab = vocab or Vocab(self.text)\n",
" tokenizer = Tokenizer(vocab)\n",
" encoded_text = tokenizer.encode(vocab.preprocess(self.text))\n",
" return encoded_text, tokenizer\n",
"