In this post I briefly investigate an interesting architecture: the Pointer Sentinel Mixture Model, as described in the paper by Stephen Merity et al.

Introduction

Let’s say we’re given a piece of text – a sequence of words – and we want to predict the word that will appear next. So given the beginning of a sentence, A cat sat on a, we want to predict mat.

This is usually done using a classifier with a softmax on top: the model consumes a sequence and spits out probability estimates for each word in the known vocabulary.
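
For concreteness, here is a minimal sketch of that usual setup: an RNN encodes the prefix, and a linear layer followed by a softmax scores the whole vocabulary. All names and sizes in this sketch are illustrative assumptions rather than a reference to any particular implementation.

import torch as th
import torch.nn as nn
import torch.nn.functional as F

class SoftmaxLM(nn.Module):
    """ A plain softmax language model, shown only for comparison """
    def __init__(self, vocab_size=10_000, embed_size=128, hidden_size=256):
        super().__init__()
        self.embed    = nn.Embedding(vocab_size, embed_size)
        self.lstm     = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.to_vocab = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """ x :: (batch-size, seq-len) -> probabilities :: (batch-size, vocab-size) """
        H, _ = self.lstm(self.embed(x))      # ::(b, s, h)
        logits = self.to_vocab(H[:, -1, :])  # score every word in the vocabulary
        return F.softmax(logits, dim=1)      # probability of each possible next word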

This approach can work remarkably well, and in most cases it makes intuitive sense: the model has some idea about the “meaning” of the sentence, and can therefore assign higher probabilities to plausible words. In the sentence above, mat will be assigned a higher probability than, for example, Obama.

But sometimes there is a much easier way of predicting the next word – instead of searching through the whole vocabulary, maybe it makes sense to search through the words we’ve just seen?

Consider the sentence Bill Gates is an ex-CEO of Microsoft. Mr [Gates], where you want to correctly predict Gates. It is now really hard for the model to correctly pick Gates from the known vocabulary, especially if it did not see a lot of Bill Gates-related sentences.

But humans wouldn’t try to fit every name they know into the blank spot to see if it fits. They would simply point to the name that was already mentioned a few words back!

The Pointer Sentinel Mixture Model does exactly that – it tries to point to previously observed words, whenever possible. But since we can’t always find the answers in the previous words, we also need to have a “fallback” behavior – i.e. a standard softmax.

The model therefore learns three components:

  • A softmax classifier, assigning a probability to each word in the vocabulary

  • A pointer, assigning probabilities to previously seen words

  • A sentinel, which weighs the decisions of the softmax and the pointer, deciding which of them should influence the prediction more.
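
Putting these together: if g is the probability mass the attention puts on the sentinel, the final prediction is p(w) = g * p_softmax(w) + p_ptr(w), where p_ptr(w) sums the attention weights over every earlier occurrence of w and already carries the remaining 1 - g of the mass. We will see exactly this in the code below.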


Implementation

Here I briefly describe a few parts of the implementation that are crucial to the whole system.

The attention (aka. pointer-gate thingy)

Let’s focus on the pointer-sentinel part of the architecture.

For the pointer part we need a set of scores assigned to each of the previously seen words. Building on the example above, we would want a high score assigned to Gates, and low scores assigned to everything else (Bill, an, ex-CEO, ...).

For the sentinel part we only need one thing – a score that balances how much we follow what the pointer says and how much we rely on the good-ol’ softmax.

The authors propose to compute both those things (pointer scores and sentinel score) jointly – as one probability distribution.

This is done via an attention mechanism – the current hidden state is compared, via a dot product, with the hidden states of the previous words to check their “compatibility”. Because we also want to compute the sentinel score, we simply append the vector representing the sentinel as an additional “hidden state” after the hidden states of all previous words (see the code below).

One important detail here is that before computing the attention scores, we transform the current hidden state with a linear transformation W_query @ h + b_query.

In code, it means we do something like this:

def _pointer(self, H, last):
    """ The `pointer` part of the network

    Parameters
    ----------
    H : Tensor
        Hidden representations for each timestep extracted
          from the last layer of LSTM
        size : (batch-size, sequence-len, hidden-size)
    last : Tensor
        Representations of the whole sequences (Last vectors from H,
          extracted earlier for efficiency)
        size : (batch-size, hidden-size)
    """
    batch_size, _, hid_size = H.size()

    sentinel = self.sentinel.expand(batch_size, 1, hid_size)
    latents  = th.cat((H, sentinel), 1)  # ::(b, s+1, h)

    query = F.tanh(self.query(last))
    query = query.unsqueeze(2)  # ::(b, h, 1)

    # bmm == batched-matrix-multiply
    logits  = th.bmm(latents, query).squeeze(2)  # ::(b, s+1)
    weights = F.softmax(logits, dim=1)  # ::(b, s+1)

    probas = weights[:, :-1]              # attention over the previous words
    gates  = weights[:, -1].unsqueeze(1)  # attention on the sentinel

    return probas, gates

And, as mentioned earlier, self.query is simply defined as nn.Linear(in_features=lstm_size, out_features=lstm_size, bias=True).
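
For reference, here is a rough sketch of how these pieces could be declared in the module’s constructor. Only self.query, self.sentinel, self.embed and self.lstm actually appear in the snippets in this post; the class name, the argument defaults and the exact initialization of the sentinel (a single learned vector of size lstm_size) are my assumptions.

import torch as th
import torch.nn as nn

class PointerSentinelLM(nn.Module):  # hypothetical name
    def __init__(self, vocab_size, embed_size, lstm_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm  = nn.LSTM(embed_size, lstm_size,
                             num_layers=num_layers, batch_first=True)
        # linear map applied to the last hidden state before the attention
        self.query = nn.Linear(in_features=lstm_size,
                               out_features=lstm_size, bias=True)
        # the sentinel: one learned vector, broadcast over the batch in `_pointer`
        self.sentinel = nn.Parameter(th.randn(1, 1, lstm_size))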

The mixing and loss

Now that we have the scores assigned to each of the previous words, as well as the sentinel value, we can compute the final predictions of our model. We compute them as log-probabilities to simplify the computation of the loss we’re trying to minimize.

Note that during training we take a shortcut and compute only the probabilities of the target words. Also note, in the snippet below, how we sum up the scores for repeated words – so if Gates appeared twice in the previously seen text, the score for Gates would be the sum of the scores from both occurrences.

In code we do something like:

def mixture_train(self, ptr_probas, rnn_probas, gates,
                  x, y):
    """ Compute the log-probabilities assigned by the model to
    the target words specified by `y`

    Parameters
    ----------
    ptr_probas : FloatTensor
        see the return values of `forward`
    rnn_probas : FloatTensor
        see the return values of `forward`
    gates : FloatTensor
        see the return values of `forward`
    x : LongTensor
        indices of the input words for each sequence in the batch.
        size : 2-D, (batch-size, seq-length)
    y : LongTensor
        indices of the target words, one per sequence in the batch.
        size : 1-D, (batch-size,)
    """

    ptr_mask   = (x == y.unsqueeze(1).expand_as(x)).type_as(ptr_probas)
    ptr_scores = (ptr_probas * ptr_mask).sum(1)
    rnn_scores = rnn_probas.gather(dim=1, index=y.unsqueeze(1)).squeeze()
    gates      = gates.squeeze()

    p     = gates * rnn_scores + ptr_scores
    log_p = th.log(p + 1e-7)

    return log_p
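
From these log-probabilities the training loss is the usual negative log-likelihood. As a fragment (assuming a model instance and the tensors produced by its forward pass), that is simply:

log_p = model.mixture_train(ptr_probas, rnn_probas, gates, x, y)
loss  = -log_p.mean()  # average negative log-likelihood over the batch
loss.backward()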


The prediction

The authors state that during testing it’s better to use CPU. This makes intuitive sense – we need to iterate over every word in the vocabulary, check if it appeared in previous text, and sum the scores it got.

I think, however, that if you’re willing to go for a batch size of 1, you can use something like the following code and still run it efficiently on a GPU.

The idea here is that ptr_probas_expanded is non-zero exactly for the words that appeared in the previously seen text.

def mixture_sample(self, ptr_probas, rnn_probas, gates, x):
    if x.size(0) != 1:
        raise RuntimeError(f"Sampling is implemented for 'batches' "
                           f"of size (1), but {x.size(0)} was found.")

    ptr_probas_expanded = th.zeros_like(rnn_probas)
    ptr_probas_expanded.index_add_(1, x.squeeze(), ptr_probas)

    probas = rnn_probas * gates + ptr_probas_expanded

    return probas
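
At prediction time you would then pick the most probable word from the returned mixture. A tiny usage fragment, where model is a trained instance and idx2word is a hypothetical index-to-token mapping:

probas = model.mixture_sample(ptr_probas, rnn_probas, gates, x)
next_word_idx = probas.argmax(dim=1).item()  # most probable next word
print(idx2word[next_word_idx])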


The training scheme

One thing I love about Stephen’s papers is that he seems to always care deeply about explaining every detail of the model.

In this particular instance, the paper outlines the details of how the model is trained. It is a rather slow procedure – you process 100 words at a time, but after you’re done with this “chunk”, you don’t jump to the next 100. Instead, you move only 1 word forward!

Looking at the pseudo-code from the paper, the authors use k_1 = 1 and k_2 = 100.

In pytorch this would mean something like this:


def _core(self, x, s0):
    """ Core part of the network, shared between
    `softmax-rnn` and `pointer` parts. Rolls LSTM for `L` steps,
    moving `k` words each time (`L` == 100, `k` == 1 in the original paper)

    Parameters
    ----------
    x : LongTensor
        See `forward` method for details.
    s0 : Tuple[FloatTensor]
        See `lstm_state` arg in the `forward`

    Returns
    -------
    H : FloatTensor
        Outputs from the last layer of the LSTM
        Size: (batch-size, seq-len, hid-size)
    sk : Tuple[FloatTensors]
        k-th state (includes cell state) of the all LSTM layers
    hT : FloatTensor
        Last hidden state of the last LSTM layer
        Size: (batch-size, hid-size)
    """
    x = self.embed(x)

    # Run the first `k` steps separately: `sk`, the state after `k` words,
    # is what gets carried over as `s0` for the next (overlapping) window.
    H0, sk = self.lstm(x[:, :self.k, :], s0)
    # Roll through the remaining `L - k` words of the window.
    H1, _  = self.lstm(x[:, self.k:, :], sk)

    H = th.cat((H0, H1), 1)
    hT = H[:, -1, :]

    return H, sk, hT
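
To make the scheme concrete, here is a rough sketch of the outer loop over the corpus; corpus is a stand-in 1-D LongTensor of word indices, and the actual model call and optimizer step are elided:

import torch as th

corpus = th.randint(0, 10_000, (5_000,))  # stand-in for the real tokenized corpus
L, k = 100, 1                             # window length and stride from the paper

for start in range(0, corpus.size(0) - L, k):
    x = corpus[start:start + L].unsqueeze(0)  # ::(1, L)  the current input window
    y = corpus[start + L:start + L + 1]       # ::(1,)    the word right after it
    # ... run the model on `x`, call `mixture_train`, and minimize -log_p ...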


Analysis

Let’s take a closer look at what the model has learned. We’ll quickly check how often the pointer is used, and then try to investigate in which situations it is used most often.

Finally, we’ll try to reproduce the behaviour from the paper.

Gate Distribution

First of all, the pointer is used rather rarely. Here’s a distribution of the gate values on the test set for PTB:

And here’s the one for the Wiki data:

This is not surprising, of course, given that situations where it could really be useful are expected to be rare. Most sentences do not reference information presented previously.

I was wondering why the distributions for the Wiki and PTB data are so different. Eventually I came to the conclusion that when we are describing something (the way I keep referring to the pointer throughout this blog post), we are more likely to refer to previously used words. This means that in text such as Wikipedia articles the pointer should be used much more often – which, as we can see, is indeed the case!

Predicting numbers

One not-so-nice feature of this model (or simply a shortcoming of the dataset) is that on PTB the pointer is very often used to predict numbers:

This could be both a positive and a negative thing, although in the case of the PTB it is most likely negative: it makes sense for a model to try and choose a number it has seen previously (e.g. 2018) instead of predicting it via the softmax. But in PTB all numbers are replaced with a single token, N, making these predictions rather meaningless.

Predicting names

We can easily confirm the observations from the paper, and also show that our model can indeed solve the Gates example from the beginning of this article.

When appropriate, the model will just choose the name from the previous words:

Although sometimes it will fail with weird results:

And here’s the piece of text the authors included in their paper. I did not get the same, very sharp prediction, but this is most likely because my model was trained for only 5 epochs:

Paper:

Ours:

Conclusion

It’s amazing how easy it was to implement this paper. The model was very well described, the idea quite simple, and pytorch always makes things as easy as they can get.

I recommend that everyone play with models from the literature – it’s always fun to see something work with your own code.

The repo is here. I used visdom (although I no longer use it in my projects), but there is no main that you can run (there is a notebook, though).

Acknowledgements

Special thanks to Stephen Merity for answering my e-mail with some questions!