Language modeling

Estimating the joint probability \(p(\boldsymbol{\mathsf x}_1, \ldots,\boldsymbol{\mathsf x}_T)\) of a sequence of discrete tokens proves useful for various reasons; this task is called language modeling. For instance, machine translation and automatic speech recognition (ASR) systems generate output sequences by searching for the most probable ones. In particular, models that predict the next element of a sequence are referred to as language models (LMs). Recall that we can write a joint distribution as a chain of conditional distributions:

\[ p(\boldsymbol{\mathsf x}_1, \ldots,\boldsymbol{\mathsf x}_T) = p(\boldsymbol{\mathsf x}_1) \prod_{t = 2}^{T} p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t-1}). \]

Hence, for discrete data, the model must output a distribution \(p(\boldsymbol{\mathsf x}_{t} \mid \boldsymbol{\mathsf x}_{1}, \ldots, \boldsymbol{\mathsf x}_{t-1})\) over tokens at each step, rather than an expected value as in regression. In practice, this requires a finite set of valid tokens, called the vocabulary \(\mathcal{V}\). We can then generate text simply by drawing one token at a time, \(\boldsymbol{\mathsf{x}}_t \sim p(\boldsymbol{\mathsf{x}}_t \mid \boldsymbol{\mathsf{x}}_1, \ldots, \boldsymbol{\mathsf{x}}_{t-1})\). For example, the probability of a four-token sentence factorizes as

\[\begin{split} \begin{aligned} &\;p(\text{deep}, \text{learning}, \text{is}, \text{fun}) \\ =& \;p(\text{deep}) \cdot p(\text{learning} \mid \text{deep}) \cdot p(\text{is} \mid \text{deep}, \text{learning}) \cdot p(\text{fun} \mid \text{deep}, \text{learning}, \text{is}). \end{aligned} \end{split}\]
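Both operations can be sketched in a few lines of Python. This is a minimal illustration, assuming a made-up table `COND` of conditional distributions keyed by the full prefix (the table, its probabilities, and the helper names `joint_log_prob` and `sample` are hypothetical): scoring applies the chain rule, and generation draws one token at a time from the conditional for the current prefix.

```python
import math
import random

# Hypothetical conditional distributions p(next | prefix), keyed by the full
# prefix. The numbers are made up for illustration only.
COND = {
    (): {"deep": 0.6, "fun": 0.4},
    ("deep",): {"learning": 0.9, "fun": 0.1},
    ("deep", "learning"): {"is": 0.8, "fun": 0.2},
    ("deep", "learning", "is"): {"fun": 0.7, "deep": 0.3},
}

def joint_log_prob(tokens):
    """Chain rule: log p(x_1, ..., x_T) = sum_t log p(x_t | x_1, ..., x_{t-1})."""
    return sum(math.log(COND[tuple(tokens[:t])][tok]) for t, tok in enumerate(tokens))

def sample(max_len, rng):
    """Draw one token at a time from p(x_t | x_1, ..., x_{t-1})."""
    tokens = []
    for _ in range(max_len):
        dist = COND.get(tuple(tokens))
        if dist is None:  # the toy table defines no distribution for this prefix
            break
        words = list(dist)
        tokens.append(rng.choices(words, weights=[dist[w] for w in words])[0])
    return tokens

p = math.exp(joint_log_prob(["deep", "learning", "is", "fun"]))
print(round(p, 4))  # 0.3024 = 0.6 * 0.9 * 0.8 * 0.7
print(sample(4, random.Random(0)))  # a random sequence from the toy table
```

Working in log space avoids numerical underflow when many probabilities are multiplied over long sequences.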

The conditional probabilities can be estimated from relative frequencies in a training corpus, optionally with Laplace smoothing. For a bigram:

\[ p(\text{learning} \mid \text{deep}) \approx \frac{\#(\text{deep},\, \text{learning}) + \kappa}{\#(\text{deep}) + \kappa|\mathcal{V}|} \]

where \(\kappa > 0\) is a pseudo-count added to every bigram count and \(|\mathcal{V}|\) is the vocabulary size. The smoothing parameter \(\kappa\) acts as a regularizer: as \(\kappa\) grows large relative to the counts, the estimate approaches the uniform distribution \(1/|\mathcal{V}|\). Moreover, we usually truncate the context to a fixed number of preceding tokens, both as a Markov assumption and because n-grams become increasingly sparse in naturally occurring text as n grows.
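A minimal sketch of the Laplace-smoothed bigram estimate, assuming a tiny made-up corpus (the corpus and the helper name `laplace_bigram` are hypothetical). The denominator counts a token only where it starts a bigram, which is one convention that makes each conditional distribution sum to one; increasing \(\kappa\) pulls the estimate toward \(1/|\mathcal{V}|\).

```python
from collections import Counter

def laplace_bigram(corpus, kappa):
    """Laplace-smoothed bigram estimate
    p(w | prev) ~ (#(prev, w) + kappa) / (#(prev) + kappa * |V|).
    """
    vocab = sorted(set(corpus))
    bigrams = Counter(zip(corpus, corpus[1:]))
    # Count prev only where it starts a bigram, so each row sums to one.
    prev_counts = Counter(corpus[:-1])

    def prob(w, prev):
        return (bigrams[(prev, w)] + kappa) / (prev_counts[prev] + kappa * len(vocab))

    return prob, vocab

corpus = "deep learning is fun deep learning is hard".split()
for kappa in (1.0, 1000.0):
    prob, vocab = laplace_bigram(corpus, kappa)
    print(kappa, round(prob("learning", "deep"), 4))
# 1.0 0.4286     (= 3/7)
# 1000.0 0.2003  (close to 1/|V| = 0.2)
```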


Perplexity

Next, we need a generic metric to measure the quality of a language model. One way is to measure how surprising the text is to the model. A good language model is able to predict, with high accuracy, the tokens that come next. Consider the following continuations of the phrase “It is sunny”, as proposed by three different language models:

1. It is sunny outside
2. It is sunny banana tree
3. It is sunny soiupt;mkj ldfosim

The first continuation is clearly the best: even if it is not necessarily factual or accurate, the model predicts the right kind of word. The second is nonsensical, but at least the model has learned some degree of correlation between words (‘banana’ and ‘tree’). Finally, the last one indicates a poorly trained model.

To evaluate a language model, we can use the cross-entropy of its next-token predictions; minimizing it is equivalent to maximizing the likelihood of the text. We normalize it by the number of predicted tokens. For example, we can evaluate the model on contexts of variable length \(\delta = 1, \ldots, T\) starting from \(\boldsymbol{\mathsf{x}}_{t}\):

\[ \mathcal{L} = -\frac{1}{n}\sum_{t}\sum_{\delta = 1}^{T} \log p(\boldsymbol{\mathsf{x}}_{t + \delta} \mid \boldsymbol{\mathsf{x}}_{t}, \ldots, \boldsymbol{\mathsf{x}}_{t + \delta - 1}) \]

where \(n\) is the number of predictions. For a model that predicts all tokens uniformly at random, \(\mathcal{L} = \log |\mathcal{V}|\), where \(\mathcal{V}\) is the vocabulary. This is a useful baseline. A similarly simple baseline predicts each token with its prior (unigram) probability, estimated from the token counts in the training data.
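The double sum in \(\mathcal{L}\) can be sketched in plain Python. The interface `log_prob(context, next_token)` is a hypothetical stand-in for any language model; here it is a uniform model over a made-up four-token vocabulary, so the result should equal \(\log|\mathcal{V}| = \log 4\). Contexts that would run past the end of the sequence are skipped, which is one reasonable convention for the boundary.

```python
import math

def variable_context_loss(tokens, log_prob, T):
    """Average negative log-likelihood over contexts of length delta = 1..T.

    For every start t and every delta with t + delta < len(tokens), predict
    tokens[t + delta] from the context tokens[t : t + delta].
    """
    total, n = 0.0, 0
    for t in range(len(tokens)):
        for delta in range(1, T + 1):
            if t + delta >= len(tokens):
                break  # no token left to predict
            total -= log_prob(tokens[t:t + delta], tokens[t + delta])
            n += 1
    return total / n

VOCAB = ["a", "b", "c", "d"]

def uniform_log_prob(context, next_token):
    # Hypothetical model that ignores the context and spreads mass evenly over VOCAB.
    return -math.log(len(VOCAB))

tokens = ["a", "b", "c", "a", "d", "b"]
print(variable_context_loss(tokens, uniform_log_prob, T=3))  # approximately log(4)
print(math.log(len(VOCAB)))
```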

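```python
import math
import torch
import torch.nn.functional as F

# Logits of shape (B, C, T): batch size B, C classes (vocabulary size), T positions.
# Uniform random logits in [0, 1) give a nearly uniform softmax, so the
# cross-entropy (averaged over the B × T predictions) is close to (slightly above) log(C).
B, C, T = 32, 28, 128
print(F.cross_entropy(torch.rand(B, C, T), target=torch.randint(C, size=(B, T))))
print(math.log(C))
# tensor(3.3711)       <- varies from run to run (no seed is set)
# 3.332204510175204
```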
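The unigram baseline mentioned above can be sketched in the same spirit. This is a minimal, hypothetical example with a made-up four-token vocabulary and corpus (`unigram_cross_entropy` is an illustrative name); the add-\(\kappa\) smoothing is an assumption included so that tokens unseen in training do not receive zero probability.

```python
import math
from collections import Counter

def unigram_cross_entropy(train, test, vocab, kappa=1.0):
    """Cross-entropy (in nats) of a context-free unigram model on `test`.

    p(w) = (#(w) + kappa) / (N + kappa * |V|), estimated on `train`;
    kappa > 0 keeps unseen tokens from getting zero probability.
    """
    counts = Counter(train)
    denom = len(train) + kappa * len(vocab)
    return -sum(math.log((counts[w] + kappa) / denom) for w in test) / len(test)

vocab = ["a", "b", "c", "d"]
train = ["a"] * 6 + ["b"] * 2
test = ["a", "a", "a", "b"]
ce = unigram_cross_entropy(train, test, vocab)
print(round(ce, 4), round(math.log(len(vocab)), 4))  # 0.7508 1.3863
```

On this skewed toy data the unigram model beats the uniform baseline \(\log|\mathcal{V}|\), since it assigns more mass to the frequent token.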

Historically, NLP researchers have also used perplexity (PP), which is simply the exponential of the cross-entropy:

\[ \text{PP} = \exp\left(-\frac{1}{n}\sum_{t}\sum_{\delta = 1}^{T} \log p(\boldsymbol{\mathsf{x}}_{t + \delta} \mid \boldsymbol{\mathsf{x}}_{t}, \ldots, \boldsymbol{\mathsf{x}}_{t + \delta - 1})\right). \]

Note that perplexity is the inverse of the geometric mean of the probabilities assigned to the correct tokens, or equivalently the geometric mean of their inverses:

\[ \text{PP} = \frac{1}{\sqrt[n]{\prod_{t}\prod_{\delta = 1}^{T} p(\boldsymbol{\mathsf{x}}_{t + \delta} \mid \boldsymbol{\mathsf{x}}_{t}, \ldots, \boldsymbol{\mathsf{x}}_{t + \delta - 1})}} = \sqrt[n]{\prod_{t}\prod_{\delta = 1}^{T} \frac{1}{p(\boldsymbol{\mathsf{x}}_{t + \delta} \mid \boldsymbol{\mathsf{x}}_{t}, \ldots, \boldsymbol{\mathsf{x}}_{t + \delta - 1})}}. \]
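A quick numerical check that the exponential of the cross-entropy and the geometric mean of the inverse probabilities agree, using made-up probabilities that a hypothetical model assigned to the correct token at \(n = 4\) predictions:

```python
import math

# Hypothetical probabilities of the correct token at each of n predictions.
probs = [0.5, 0.25, 0.8, 0.1]
n = len(probs)

# PP as the exponential of the average negative log-likelihood.
pp_exp = math.exp(-sum(math.log(p) for p in probs) / n)

# PP as the geometric mean of the inverse probabilities.
pp_geo = math.prod(1 / p for p in probs) ** (1 / n)

print(round(pp_exp, 4), round(pp_geo, 4))  # 3.1623 3.1623
```

Here the inverse probabilities multiply to 100, so both expressions give \(100^{1/4} = \sqrt{10}\). The log-space form is the one to use in practice, since the raw product underflows for long texts.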

Hence, for a perfect model, \(\text{PP} = 1.\) On the other hand, if the model assigns probability \(p = 0\) to the correct token at even one step, then[1] \(\text{PP} = \infty\) (and \(\text{PP}\) grows without bound as \(p \to 0\)). As a baseline, a uniformly random model has \(\text{PP} = |\mathcal{V}|\), a nontrivial upper bound that any useful model must beat. So \(\text{PP}\) takes the values \(\infty\), \(|\mathcal{V}|\), and \(1\) in the three regimes (worst case, uniform baseline, perfect model)[2]. Informally, perplexity can be interpreted as the average number of tries needed to guess the correct token at each step, e.g. a single try for a perfect model.
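The three regimes can be checked with a small helper; `perplexity`, the vocabulary size `V`, and the number of predictions `n` are hypothetical choices for illustration. The last call illustrates the "number of tries" reading: a model that always spreads its mass evenly over 4 candidates, one of which is correct, has perplexity 4.

```python
import math

def perplexity(correct_probs):
    """exp of the average negative log-probability of the correct tokens."""
    if any(p == 0.0 for p in correct_probs):
        return math.inf  # log(0) = -inf, so PP is infinite
    return math.exp(-sum(math.log(p) for p in correct_probs) / len(correct_probs))

V = 1000  # hypothetical vocabulary size
n = 50    # hypothetical number of predictions

print(perplexity([1.0] * n))                 # perfect model: 1.0
print(perplexity([1 / V] * n))               # uniform model: approximately 1000
print(perplexity([0.9] * (n - 1) + [0.0]))   # one zero-probability step: inf
print(perplexity([1 / 4] * n))               # 4 equally likely candidates: approximately 4
```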

Remark. For the sake of concreteness, we evaluated the cross-entropy over predictions with contexts of varying length \(\delta = 1, \ldots, T\) starting at position \(t.\) Depending on the task, we can also use fixed-length contexts. In general, we simply average the cross-entropy over all instances of next-token prediction, regardless of how the contexts are constructed.
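For comparison, here is a minimal sketch of the fixed-length variant: every token from position \(k\) on is predicted from exactly the \(k\) tokens preceding it. The toy model `repeat_biased_log_prob` is hypothetical (probability 0.5 on repeating the last token, the rest split evenly), and leaving the first \(k\) tokens unpredicted is just one convention.

```python
import math

VOCAB = ["a", "b", "c"]

def repeat_biased_log_prob(context, next_token):
    # Hypothetical toy model: probability 0.5 on repeating the last context
    # token, the remaining 0.5 split evenly over the other |V| - 1 tokens.
    if next_token == context[-1]:
        return math.log(0.5)
    return math.log(0.5 / (len(VOCAB) - 1))

def fixed_context_loss(tokens, log_prob, k):
    """Average NLL where every token from position k on is predicted from
    exactly the k tokens that precede it (a sliding window)."""
    terms = [log_prob(tokens[i - k:i], tokens[i]) for i in range(k, len(tokens))]
    return -sum(terms) / len(terms)

tokens = ["a", "a", "b", "b", "b", "c"]
loss = fixed_context_loss(tokens, repeat_biased_log_prob, k=2)
print(round(loss, 4), round(math.exp(loss), 4))  # 1.0397 2.8284
```

The four predictions receive probabilities 0.25, 0.5, 0.5, and 0.25, so the cross-entropy is \(\tfrac{3}{2}\log 2\) and the perplexity is \(2\sqrt{2}\).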