Mersenne Twister algorithm

I walk through the implementation of the Mersenne Twister algorithm, one of the most widely used pseudo-random number generators.

A pseudo-random number generator (PRNG) is an algorithm that produces sequences of numbers designed to imitate independently and identically distributed (IID) random variables from a given distribution. Despite appearing random, these sequences are entirely deterministic, controlled by an initial value called the seed.
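As a quick illustration of this determinism, here is a minimal sketch using Python's built-in `random` module (the helper name `draw` is mine, chosen for the example). Reseeding with the same value reproduces the same sequence:

```python
import random


def draw(seed: int, count: int = 3) -> list[float]:
    """Returns `count` pseudo-random floats generated from `seed`."""
    rng = random.Random(seed)  # independent generator with its own state
    return [rng.random() for _ in range(count)]


# Same seed, same sequence; a different seed gives a different sequence.
print(draw(7) == draw(7))
print(draw(7) == draw(8))
```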

Among all probability distributions, the uniform distribution $U(0,1)$ is particularly important because it serves as a building block to sample from other distributions. One of the earliest PRNGs, the middle-square method, was proposed in 1946. Since then, there have been many improvements.
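To make the "building block" idea concrete, here is a small sketch of inverse transform sampling, one standard way to turn uniform draws into draws from another distribution. The function name `sample_exponential` and the choice of the exponential distribution are assumptions made for the example:

```python
import math
import random


def sample_exponential(rate: float, rng: random.Random) -> float:
    """Draws from Exp(rate) by inverting its CDF F(x) = 1 - exp(-rate * x)."""
    u = rng.random()  # uniform on [0, 1)
    # 1 - u lies in (0, 1], so the logarithm is always defined.
    return -math.log(1.0 - u) / rate


rng = random.Random(0)
samples = [sample_exponential(2.0, rng) for _ in range(100_000)]
# The sample mean should be close to the theoretical mean 1 / rate = 0.5.
print(sum(samples) / len(samples))
```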

In this post, I focus on the Mersenne Twister algorithm, introduced in 1997. While it is no longer the best PRNG available today, it remains the default pseudo-random number generator in Python, as well as in other languages and tools such as R, Ruby, MATLAB and Microsoft Excel. Julia also used it by default until version 1.7.

The original paper can be found here.

Description

The Mersenne Twister algorithm maintains an internal state array of $n$ numbers and an index $0\leq i <n$. When a number is requested, it takes the $i$th element of the state, runs it through a tempering function, and produces the result. Once all $n$ numbers have been used, the algorithm performs a twist transformation to refresh the entire state and resets the index to $i = 0$.

Each number is an integer between $0$ and $2^w-1$, stored as a $w$-bit word (that is, a vector in $\mathbb{F}_2^w$). The tempering function is an invertible linear map $\mathbb{F}_2^w\to \mathbb{F}_2^w$ that scrambles the bits of each word. To obtain a real number in $[0, 1]$, we simply divide the output by $2^{w}-1$.
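Both properties of the tempering function can be checked numerically. The sketch below assumes the standard MT19937 tempering constants; the helpers `undo_right`, `undo_left` and `untemper` are names I made up for the illustration. Each helper undoes one shift-and-XOR step by fixed-point iteration, recovering `shift` more bits per pass:

```python
import random

# MT19937 tempering constants
U, S, B, T, C, L = 11, 7, 0x9D2C5680, 15, 0xEFC60000, 18


def temper(x: int) -> int:
    x ^= x >> U
    x ^= (x << S) & B
    x ^= (x << T) & C
    x ^= x >> L
    return x


def undo_right(y: int, shift: int) -> int:
    """Solves y = x ^ (x >> shift) for a 32-bit x."""
    x = y  # the top `shift` bits of y are already those of x
    for _ in range(32 // shift):
        x = y ^ (x >> shift)
    return x


def undo_left(y: int, shift: int, mask: int) -> int:
    """Solves y = x ^ ((x << shift) & mask) for a 32-bit x."""
    x = y  # the low `shift` bits of y are already those of x
    for _ in range(32 // shift):
        x = y ^ ((x << shift) & mask)
    return x


def untemper(y: int) -> int:
    """Inverts temper() by undoing its four steps in reverse order."""
    y = undo_right(y, L)
    y = undo_left(y, T, C)
    y = undo_left(y, S, B)
    return undo_right(y, U)


rng = random.Random(0)
for _ in range(1000):
    a, b = rng.getrandbits(32), rng.getrandbits(32)
    assert untemper(temper(a)) == a                 # invertible
    assert temper(a ^ b) == temper(a) ^ temper(b)   # linear over F_2
```

Because the map can be undone, tempering improves the statistical quality of the output but does not hide the state: each raw 32-bit output reveals one state word.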

The twist transformation, on the other hand, is a linear map $\mathbb{F}_2^{nw}\to \mathbb{F}_2^{nw}$ applied to the entire state array. It repeatedly concatenates the upper $w-r$ bits of a word with the lower $r$ bits of the next, then uses shifts, masks and XOR operations to update another entry, effectively ‘twisting’ the bits around.

Figure 1. Diagram of the Mersenne Twister algorithm.

To understand why it is called the Mersenne Twister, we need to talk about its period.

Since the algorithm generates integers between $0$ and $2^w-1$, at least one value has to repeat within $2^w+1$ draws. That is not a problem. What we want to avoid is repeating the same internal state too soon, which would create obvious cycles in the output.

Each internal state is a vector in $\mathbb{F}_2^{nw}$, and the twist transformation is a linear map on this vector space. Since the lower $r$ bits of the first word are discarded during the twist transformation, this map has a kernel of dimension $r$. This means that the longest possible period is $2^{nw-r}-1$, achieved if every state in the range of the twist transformation (except zero) is visited once before any repeats.
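The same phenomenon can be observed on a toy scale. The sketch below is not the twist transformation itself but a 4-bit linear feedback shift register, chosen by me as a stand-in because it is also a linear map over $\mathbb{F}_2$. Its state space has $2^4$ elements, and every non-zero state lies on a single cycle of length $2^4-1$:

```python
def lfsr_step(state: int) -> int:
    """One step of a 4-bit linear feedback shift register over F_2.

    Implements the recurrence a[k+4] = a[k+1] + a[k] (mod 2), whose
    characteristic polynomial x^4 + x + 1 is primitive.
    """
    new_bit = (state ^ (state >> 1)) & 1
    return (state >> 1) | (new_bit << 3)


def period(state: int) -> int:
    """Counts the steps needed to return to the starting state."""
    current, steps = lfsr_step(state), 1
    while current != state:
        current = lfsr_step(current)
        steps += 1
    return steps


print(period(1))  # 15 == 2**4 - 1
print(period(0))  # 1: the zero state is a fixed point
```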

Figure 2. Illustration of an initial state evolving in the vector space $\mathbb{F}_2^{nw}$ under iterations of the twist transformation. As long as the initial state is not in the $r$-dimensional kernel of the twist transformation, it stays outside of that subspace.

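The name comes from how the parameters are chosen: they are picked so that $2^{nw-r}-1$ is a Mersenne prime, that is, a prime number of the form $2^p-1$ (which forces $p$ itself to be prime). When the candidate period is prime, it becomes practical to verify that the maximal period $2^{nw-r}-1$ is actually reached. Mersenne primes are named after Marin Mersenne, a French friar who studied them in the early 17th century.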
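Whether $2^p-1$ is prime can be checked with the Lucas-Lehmer test. The following is a minimal sketch of that test (the function name is mine, and it assumes that $p$ itself is prime). The reduction modulo $2^p-1$ is done with a shift and a mask rather than a division, which keeps the check fast for large exponents:

```python
def is_mersenne_prime(p: int) -> bool:
    """Lucas-Lehmer test: is 2**p - 1 prime? Assumes p itself is prime."""
    if p == 2:
        return True  # 2**2 - 1 = 3
    m = (1 << p) - 1
    s = 4
    for _ in range(p - 2):
        s = s * s - 2
        # Reduce modulo 2**p - 1 using 2**p = 1 (mod 2**p - 1)
        s = (s & m) + (s >> p)
        if s >= m:
            s -= m
    return s == 0


small_primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31)
print([p for p in small_primes if is_mersenne_prime(p)])
# [2, 3, 5, 7, 13, 17, 19, 31]
print(is_mersenne_prime(19937))  # True
```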

Choice of parameters

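Instead of committing to a particular version of the algorithm, the authors of the original paper left a few parameters unspecified. We have already encountered some of them: the word size $w$, the number $n$ of words in the state, and the number $r$ of lower bits taken from the next word during the twist.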

Other parameters help define the tempering function and the twist transformation.

The most widely used version, MT19937, uses $n=624$, $w=32$ and $r=31$, so that $nw-r = 624 \times 32 - 31 = 19937$. This setup produces an (astronomical) period of $2^{19937}-1$ using only about $2.4$ KiB of memory ($624 \times 32$ bits, or $2496$ bytes).
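These numbers are easy to check, and they can be compared with what CPython stores. The sketch below relies on the layout of the tuple returned by `getstate()`, which I treat as an implementation detail of CPython rather than a guarantee:

```python
import random

N, W, R = 624, 32, 31

print(N * W - R)   # 19937, the exponent in the period 2**19937 - 1
print(N * W // 8)  # 2496 bytes of state

# CPython exposes the state as (version, internal_state, gauss_next),
# where internal_state holds the N state words followed by the index.
version, internal_state, _ = random.Random(7).getstate()
print(len(internal_state))  # 625 = N words + 1 index
```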

Implementation

Sometimes, the best way to convince oneself that the code is not doing any magic is to implement it. While most programming languages use the MT19937 version of the algorithm, each implementation has its own quirks. Most of them have to do with how the internal state is initialized and how the final output is assembled.

My goal was therefore to reproduce exactly the behavior of Python’s random module (which is implemented in C). My implementation can be found here.

import random
from mersenne_twister import MersenneTwister

random.seed(7)
print(random.random()) # 0.32383276483316237

generator = MersenneTwister()
generator.seed(7)
print(generator.random()) # 0.32383276483316237

The generator can be seeded with a user-provided seed or the current system time.

import time

W, N, M, R = (32, 624, 397, 31)
A = 0x9908B0DF
U = 11
S, B = (7, 0x9D2C5680)
T, C = (15, 0xEFC60000)
L = 18
F = 1812433253

class MersenneTwister:

    def __init__(self, seed: int | None = None) -> None:
        """Initializes the PRNG.

        Args:
            seed: Optional seed value.
        """
        self._index = 0
        self._state = [0] * N
        self.seed(seed)

    def seed(self, seed: int | None = None) -> None:
        """Initializes the internal state from a seed.
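        If no seed is provided, the current system time is used.

        Args:
            seed: Optional seed value.
        """
        if seed is None:
            seed = time.time_ns()
        # Only the lower 32 bits are kept, so outputs are only expected
        # to match Python's random module for seeds in [0, 2^32)
        seed = self._uint32(seed)
        self._initialize_state(seed)
        self._twist()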

The internal state is initialized in the same way as in the authors' reference implementation: a simple recurrence first fills the array, then the seed is mixed in with some additional bit scrambling.

    def _initialize_state(self, seed: int) -> None:
        """Initializes the internal state from a seed.

        Args:
            seed: Seed value.
        """
        state = self._state
        state[0] = 19650218

        for i in range(1, N):
            prev = state[i - 1]
            x_i = F * (prev ^ (prev >> (W - 2))) + i
            state[i] = self._uint32(x_i)

        # Bit scrambling to improve pseudo-randomness
        i = 1
        for _ in range(N):
            prev = state[i - 1]
            x_i = 1664525 * (prev ^ (prev >> (W - 2)))
            x_i ^= state[i]
            x_i += seed
            state[i] = self._uint32(x_i)
            i += 1
            if i == N:
                state[0] = state[N - 1]
                i = 1

        for _ in range(N - 1):
            prev = state[i - 1]
            x_i = 1566083941 * (prev ^ (prev >> (W - 2)))
            x_i ^= state[i]
            x_i -= i
            state[i] = self._uint32(x_i)
            i += 1
            if i == N:
                state[0] = state[N - 1]
                i = 1

        # Ensure the initial state is not in the kernel
        # of the twist transformation
        state[0] = 1 << R

    @staticmethod
    def _uint32(number: int) -> int:
        """Keeps the lower 32 bits of a number,
        since Python integers can be arbitrarily large.

        Args:
            number: Input number.

        Returns:
            Number constrained to 32 bits.
        """
        return number & 0xFFFFFFFF
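CPython appears to end its seeding in the same way, which can be checked from the outside. This sketch again relies on the layout of `getstate()` (state words followed by the index), an implementation detail rather than a documented contract:

```python
import random

rng = random.Random(7)
*words, index = rng.getstate()[1]

print(hex(words[0]))  # 0x80000000, that is 1 << 31
print(index)          # 624: a twist is due before the first output
```

One difference is visible here: CPython leaves the index at $n$ and postpones the first twist until a number is requested, whereas the implementation in this post twists right after seeding. Both orderings produce the same outputs.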

The matrix multiplications behind the tempering and the twist are purposefully chosen so that they can be expressed with bit shifts, masks and XORs. This makes them much faster to compute.

    def _twist(self) -> None:
        """Twists the internal state."""
        state = self._state
        # Upper W - R bits
        upper_mask = ((1 << (W - R)) - 1) << R
        # Lower R bits
        lower_mask = (1 << R) - 1
        for i in range(N):
            xu = state[i] & upper_mask
            xl = state[(i + 1) % N] & lower_mask
            x = xu + xl
            xa = x >> 1
            if x & 1:
                xa ^= A
            state[i] = state[(i + M) % N] ^ xa
        self._index = 0

    @staticmethod
    def _temper(x: int) -> int:
        """Tempers the output.

        Args:
            x: Unsigned 32-bit integer.

        Returns:
            Tempered unsigned 32-bit integer.
        """
        x ^= (x >> U)
        x ^= ((x << S) & B)
        x ^= ((x << T) & C)
        x ^= (x >> L)
        return x
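To see that the shift-and-XOR shortcut in the twist really is a matrix multiplication over $\mathbb{F}_2$, one can build the matrix explicitly and compare. In this sketch the names `times_a`, `ROWS` and `times_a_matrix` are mine. Row $k$ of the matrix is the image of the $k$th basis vector, and multiplying a vector by the matrix amounts to XORing the rows selected by its set bits:

```python
import random

W = 32
A = 0x9908B0DF


def times_a(x: int) -> int:
    """Shift-and-XOR shortcut used in the twist."""
    return (x >> 1) ^ (A if x & 1 else 0)


# Row k of the explicit matrix is the image of the k-th basis vector (bit k).
ROWS = [times_a(1 << k) for k in range(W)]


def times_a_matrix(x: int) -> int:
    """Vector-matrix product over F_2: XOR the rows selected by the bits of x."""
    result = 0
    for k in range(W):
        if (x >> k) & 1:
            result ^= ROWS[k]
    return result


rng = random.Random(0)
values = [rng.getrandbits(W) for _ in range(1000)]
print(all(times_a(x) == times_a_matrix(x) for x in values))
```

All rows except the first are plain one-bit shifts, and the first row is the constant `A`, which is why the whole product reduces to one shift and one conditional XOR.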

Python’s implementation does not divide the output by $2^{32}-1$ to obtain a real number in $[0, 1]$. Instead, it uses two $32$-bit outputs to produce a $53$-bit output before dividing by $2^{53}$. This has two desirable effects: it gives floating-point numbers with higher precision, and it makes it harder to reconstruct the internal state from the outputs since parts of them are discarded.

    def _generate_uint32(self) -> int:
        """Generates an unsigned 32-bit integer.

        Returns:
            Pseudo-random unsigned 32-bit integer.
        """
        if self._index == N:
            self._twist()

        x = self._state[self._index]
        self._index += 1
        return self._temper(x)

    def random(self) -> float:
        """Generates the next floating-point number
        in the range [0,1) with 53-bit precision.

        Returns:
            Number in the range [0,1).
        """
        # Keep upper 27 bits
        a = self._generate_uint32() >> 5
        # Keep upper 26 bits
        b = self._generate_uint32() >> 6
        # 9007199254740992 = 2^53
        return ((a << 26) + b) / 9007199254740992
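The same construction can be checked against CPython without any custom generator. The sketch below assumes that `getrandbits(32)` returns exactly one raw 32-bit output of the generator, which holds for CPython as far as I can tell but is an implementation detail. The helper name `float_from_words` is mine:

```python
import random


def float_from_words(first: int, second: int) -> float:
    """Combines two 32-bit outputs into a float in [0, 1) with 53 random bits."""
    a = first >> 5   # keep the upper 27 bits
    b = second >> 6  # keep the upper 26 bits
    return ((a << 26) + b) / (1 << 53)


rng = random.Random(7)
expected = rng.random()

rng.seed(7)  # rewind to the same state
first, second = rng.getrandbits(32), rng.getrandbits(32)
print(float_from_words(first, second) == expected)
```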