Writing a BitGenerator

The standard method to write a bit generator involves writing a Cython pyx file that wraps some C source code containing a high-performance implementation of a pseudo-random number generator (PRNG). This gives the best possible performance without creating any external dependencies.

UserBitGenerator provides a simple wrapper class that allows users to write bit generators in pure Python or, if performance is an issue, using Cython or by accessing functions in a compiled library (e.g., a DLL).

Here we examine the steps needed to write a pure Python bit generator and a higher-performance generator using Numba.

Using Python

The example here begins by writing a class that implements the PCG64 bit generator using the XSL-RR output transformation. While this is not a complete implementation (it does not support advance or seed), it is simple. The key to understanding PCG is that the underlying state is updated using a Linear Congruential Generator (LCG) that uses a 128-bit state, multiplier and increment. The state evolves according to

\[s_{n+1} = (m s_{n} + i) \bmod 2^{128}\]

where \(s\) is the state, \(m\) is the multiplier and \(i\) is the increment.
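As a worked illustration of the recurrence, one LCG step can be computed with plain Python integers. Reducing modulo \(2^{128}\) is the same as keeping the low 128 bits, so a bit mask gives the same result as `%`. The multiplier is the one used by PythonPCG64 below; the starting state and increment are arbitrary illustrative values (the increment is kept odd, as PCG requires).

```python
# PCG's default 128-bit multiplier (same constant as PythonPCG64 below)
PCG_DEFAULT_MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341
MODULUS = 2**128
MASK128 = MODULUS - 1


def lcg_step(state, inc, multiplier=PCG_DEFAULT_MULTIPLIER):
    """Advance a 128-bit LCG: s_{n+1} = (m * s_n + i) mod 2**128."""
    return (multiplier * state + inc) % MODULUS


# Illustrative values only; the increment must be odd
example_s0 = 12345
example_inc = 2 * 67890 + 1
example_s1 = lcg_step(example_s0, example_inc)

# Keeping the low 128 bits with a mask is equivalent to the modulo
print(example_s1 == (PCG_DEFAULT_MULTIPLIER * example_s0 + example_inc) & MASK128)
```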

The PCG generator then transforms \(s_{n+1}\) to produce the final output. The XSL-RR output function XORs the upper 64 bits of the state with the lower 64 bits and then rotates the result right by an amount given by the top 6 bits of the state.
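The output step can be sketched with plain Python integers, masking to 64 bits instead of relying on NumPy's fixed-width types. This is a minimal illustration of the XSL-RR transformation described above, not the randomgen implementation; the example state is chosen so the result is easy to check by hand.

```python
MASK64 = (1 << 64) - 1


def rotr64(value, rot):
    """Rotate a 64-bit value right by rot bits (0 <= rot < 64)."""
    return ((value >> rot) | (value << ((-rot) & 63))) & MASK64


def xsl_rr(state):
    """XSL-RR: XOR the two 64-bit halves of the 128-bit state, then
    rotate right by the value stored in the top 6 bits of the state."""
    xored = (state >> 64) ^ (state & MASK64)
    rot = state >> 122  # top 6 bits, a value in 0..63
    return rotr64(xored, rot)


# Top 6 bits equal 1, so (1 << 58) | 1 is rotated right by 1
example_state = (1 << 122) | 1
print(hex(xsl_rr(example_state)))  # 0x8200000000000000
```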

The code below implements this generator using built-in Python operations and a little NumPy.

[1]:
import numpy as np


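# Rotate a 64-bit value right by rot bits
def rotr_64(value, rot):
    value = np.uint64(value)
    rot = np.uint64(rot)
    # (64 - rot) & 63 equals (-rot) mod 64 without negating an unsigned value
    return int((value >> rot) | (value << ((np.uint64(64) - rot) & np.uint64(63))))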


class PythonPCG64:
    # A 128 bit multiplier
    PCG_DEFAULT_MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341
    MODULUS = 2**128

    def __init__(self, state, inc):
        """Directly set the state and increment, no seed support"""
        self.state = state
        self.inc = inc
        self._has_uint32 = False
        self._uinteger = 0
        self._next_32 = self._next_64 = None

    def random_raw(self):
        """Generate the next "raw" value, which is 64 bits"""
        state = self.state * self.PCG_DEFAULT_MULTIPLIER + self.inc
        state = state % self.MODULUS
        self.state = state
        return rotr_64((state >> 64) ^ (state & 0xFFFFFFFFFFFFFFFF), state >> 122)

    @property
    def next_64(self):
        """
        Return a callable that accepts a single input. The input is usually
        a void pointer that is cast to a struct that contains the PRNG's
        state. When writing a bit generator in Python, it is simpler to use
        a closure than to wrap the state in an array, pass its address as a
        ctypes void pointer, and then to get the pointer in the function.
        """

        def _next_64(void_p):
            return self.random_raw()

        self._next_64 = _next_64
        return _next_64

    @property
    def next_32(self):
        """
        Return a callable that accepts a single input. This is identical to
        ``next_64`` except that it returns a 32-bit unsigned int. Here we save
        half of the raw 64-bit output for subsequent calls.
        """

        def _next_32(void_p):
            if self._has_uint32:
                self._has_uint32 = False
                return self._uinteger
            next_value = self.random_raw()
            self._has_uint32 = True
            self._uinteger = next_value >> 32
            return next_value & 0xFFFFFFFF

        self._next_32 = _next_32
        return _next_32

    @property
    def state_getter(self):
        def f():
            return {"state": self.state, "inc": self.inc}

        return f

    @property
    def state_setter(self):
        def f(value):
            self.state = value["state"]
            self.inc = value["inc"]

        return f

Next, we use UserBitGenerator to expose the Python functions to C. Under the hood, the Python functions are wrapped in ctypes callbacks.

[2]:
from randomgen import PCG64, UserBitGenerator

pcg = PCG64(0, mode="sequence", variant="xsl-rr")
state, inc = pcg.state["state"]["state"], pcg.state["state"]["inc"]
print("Get the state from a seeded PCG64")
print(pcg.state["state"])
prng = PythonPCG64(state, inc)
print("State and increment are identical")
print(prng.state_getter())
python_pcg = UserBitGenerator(prng.next_64, 64, next_32=prng.next_32)
print("First 5 values from PythonPCG64")
print(python_pcg.random_raw(5))
print("Match official C version")
print(pcg.random_raw(5))
Get the state from a seeded PCG64
{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}
State and increment are identical
{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}
First 5 values from PythonPCG64
[11749869230777074271  4976686463289251617   755828109848996024
   304881062738325533 15002187965291974971]
Match official C version
[11749869230777074271  4976686463289251617   755828109848996024
   304881062738325533 15002187965291974971]
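The ctypes wrapping that UserBitGenerator performs is not shown in this notebook. As an illustration of the general mechanism only (the exact callback types randomgen builds internally are an assumption here), the sketch below turns a Python closure into a C function pointer with the signature `uint64_t (*)(void *)` and then calls it back through its raw address, as C code would.

```python
import ctypes

# C signature: uint64_t next_64(void *state)
NEXT_64_FUNC = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_void_p)


def make_counter_next_64():
    """Return a callable that ignores the void pointer and keeps its
    state in a closure, like PythonPCG64.next_64 does."""
    count = 0

    def _next_64(void_p):
        nonlocal count
        count += 1
        return count

    return _next_64


# Wrap the Python function; this object must stay alive while C may call it
c_callback = NEXT_64_FUNC(make_counter_next_64())
# The raw function address that a C consumer would receive
address = ctypes.cast(c_callback, ctypes.c_void_p).value
# Rebuild a callable from the address and call it with a NULL state pointer
from_pointer = NEXT_64_FUNC(address)
callback_values = [from_pointer(None) for _ in range(3)]
print(callback_values)  # [1, 2, 3]
```

Keeping `c_callback` referenced matters: if it were garbage collected, the address handed to C would dangle.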

python_pcg is a bit generator, so it can be used with a NumPy Generator. Here we see that the state changes after producing a single standard normal.

[3]:
from numpy.random import Generator

gen = Generator(python_pcg)
print(f"Before: {prng.state}")
print(f"Std. Normal : {gen.standard_normal()}")
print(f"After: {prng.state}")
Before: 133411349017971402732463711865589153492
Std. Normal : 0.36159505490948474
After: 9405893610231781608176235507540826829

Accessing python_pcg.state would raise NotImplementedError because no state accessors were provided. The state can be exposed by passing state_getter and state_setter, both of which take callables, to UserBitGenerator.

This time both are supplied, so the state can be read and set through the bit generator.

[4]:
python_pcg = UserBitGenerator(
    prng.next_64,
    64,
    next_32=prng.next_32,
    state_getter=prng.state_getter,
    state_setter=prng.state_setter,
)
python_pcg.state
[4]:
{'state': 9405893610231781608176235507540826829,
 'inc': 87136372517582989555478159403783844777}
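With a getter and setter wired up, a generator's state can be snapshotted and later restored. A minimal sketch of that round trip, using a hypothetical `TinyLCG` class (a 64-bit LCG with Knuth's MMIX constants, chosen only for brevity) that exposes `state_getter` and `state_setter` in the same way PythonPCG64 does:

```python
class TinyLCG:
    """Hypothetical toy generator using the same getter/setter pattern."""

    MULTIPLIER = 6364136223846793005
    INCREMENT = 1442695040888963407
    MASK64 = (1 << 64) - 1

    def __init__(self, state):
        self.state = state

    def random_raw(self):
        self.state = (self.MULTIPLIER * self.state + self.INCREMENT) & self.MASK64
        return self.state

    @property
    def state_getter(self):
        def f():
            # A fresh dict, so later draws do not alter the snapshot
            return {"state": self.state}

        return f

    @property
    def state_setter(self):
        def f(value):
            self.state = value["state"]

        return f


toy = TinyLCG(42)
saved = toy.state_getter()  # snapshot the state
first = [toy.random_raw() for _ in range(3)]
toy.state_setter(saved)  # restore it
second = [toy.random_raw() for _ in range(3)]
print(first == second)  # True
```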

Performance

We can time random_raw to see how slow the pure Python version is. It is nearly three orders of magnitude (about 700x) slower than the C implementation.

[5]:
%timeit python_pcg.random_raw(1000)
%timeit pcg.random_raw(1000)
3.08 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.55 µs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
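`%timeit` is an IPython magic. In a plain Python script, the standard-library `timeit` module gives comparable numbers, although the sketch below reports the best of several repeats rather than a mean and standard deviation. To stay runnable without randomgen it times a hypothetical stand-in workload (1,000 pure-Python 128-bit LCG steps) instead of `python_pcg`; passing `lambda: python_pcg.random_raw(1000)` to `timeit.repeat` would time the real wrapper.

```python
import timeit

PCG_DEFAULT_MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341
MASK128 = (1 << 128) - 1


def pure_python_draws(n=1000, state=1, inc=3):
    """Stand-in workload: n pure-Python 128-bit LCG steps."""
    out = []
    for _ in range(n):
        state = (PCG_DEFAULT_MULTIPLIER * state + inc) & MASK128
        out.append(state >> 64)
    return out


# 5 repeats of 10 calls each; keep the fastest repeat, reported per call
timings = timeit.repeat(pure_python_draws, number=10, repeat=5)
per_call = min(timings) / 10
print(f"{per_call * 1e3:.3f} ms per 1000 draws")
```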

Using Numba

A bit generator implemented in Numba can be used through the UserBitGenerator.from_cfunc interface. The block below implements the JSF generator using Numba.

The key outputs of the manager class are next_raw, next_64, next_32 and next_double, each of which returns a function compiled with Numba's @cfunc decorator.

[6]:
import ctypes

import numpy as np
from numba import cfunc, types, carray, jit

from randomgen.wrapper import UserBitGenerator


rotate64_sig = types.uint64(types.uint64, types.int_)


@jit(signature_or_function=rotate64_sig, inline="always")
def rotate64(x, k):
    return (x << k) | (x >> (64 - k))


jsf_next_sig = types.uint64(types.uint64[:])


@jit(signature_or_function=jsf_next_sig, inline="always")
def jsf_next(state):
    """
    Update the state in place

    This is a literal translation of the C code where the values p, q,
    and r are fixed.
    """
    # Default values
    p = 7
    q = 13
    r = 37
    # Update
    e = state[0] - rotate64(state[1], p)
    state[0] = state[1] ^ rotate64(state[2], q)
    state[1] = state[2] + (rotate64(state[3], r) if r else state[3])
    state[2] = state[3] + e
    state[3] = e + state[0]
    return state[3]


class NumbaJSF:
    def __init__(self, seed):
        if not isinstance(seed, (int, np.integer)) or not (0 <= seed < 2**64):
            raise ValueError("seed must be a valid uint64")
        # state[0:4] is the JSF state
        # state[4] holds the has_uint flag in bit 0 and
        #   uinteger in bits 32...63
        self._state = np.zeros(5, dtype=np.uint64)
        self._next_raw = None
        self._next_64 = None
        self._next_32 = None
        self._next_double = None
        self.seed(seed)

    def seed(self, value):
        self._state[0] = 0xF1EA5EED
        self._state[1] = value
        self._state[2] = value
        self._state[3] = value
        for _ in range(20):  # 20 warm-up rounds mix the seed into the state
            jsf_next(self._state)

    @property
    def state_address(self):
        """Get the location in memory of the state NumPy array."""
        return self._state.ctypes.data_as(ctypes.c_void_p)

    @property
    def next_64(self):
        """Same as next_raw since this is a 64-bit generator"""

        # Compile once and ensure a reference is held
        next_64 = self.next_raw
        self._next_64 = next_64

        return next_64

    @property
    def next_32(self):
        """A CFunc generating the next 32 bits"""
        sig = types.uint32(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_32(st):
            # Get the NumPy uint64 array
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # We use the first bit to indicate that 32 bits are stored in 32...63
            if bit_gen_state[4] & np.uint64(0x1):
                # Get the upper 32 bits
                out = bit_gen_state[4] >> np.uint64(32)
                # Clear the stored value
                bit_gen_state[4] = 0
                return out
            # If no bits available, generate a new value
            z = jsf_next(bit_gen_state)
            # Store the new value always with 1 in bit 0
            bit_gen_state[4] = z | np.uint64(0x1)
            # Return the lower 32 (0...31)
            return z & np.uint64(0xFFFFFFFF)

        # Ensure a reference is held
        self._next_32 = next_32

        return next_32

    @property
    def next_double(self):
        """A CFunc that generates the next double"""
        sig = types.double(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_double(st):
            # Get the state
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # Return the next value / 2**53
            return (
                np.uint64(jsf_next(bit_gen_state)) >> np.uint64(11)
            ) / 9007199254740992.0

        # Ensure a reference is held
        self._next_double = next_double

        return next_double

    @property
    def next_raw(self):
        sig = types.uint64(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_64(st):
            # Get the NumPy array containing the state
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # Return the next value
            return jsf_next(bit_gen_state)

        # Ensure a reference is held
        self._next_raw = next_64

        return next_64

    @property
    def state_getter(self):
        """A function that returns the state. This is Python and is not decorated"""

        def f() -> dict:
            return {
                "bit_gen": type(self).__name__,
                # Copy so the returned state does not change as the generator advances
                "state": self._state[:4].copy(),
                "has_uint": self._state[4] & np.uint64(0x1),
                "uinteger": self._state[4] >> np.uint64(32),
            }

        return f

    @property
    def state_setter(self):
        """A function that sets the state. This is Python and is not decorated"""

        def f(value: dict):
            name = value.get("bit_gen", None)
            if name != type(self).__name__:
                raise ValueError(f"state must be from a {type(self).__name__}")
            self._state[:4] = np.uint64(value["state"])
            temp = np.uint64(value["uinteger"]) << np.uint64(32)
            temp |= np.uint64(value["has_uint"]) & np.uint64(0x1)
            self._state[4] = temp

        return f
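Each `@cfunc` above receives a raw pointer to the state array and uses `carray` to view it as a five-element uint64 array that it updates in place. The same pointer-to-state pattern can be sketched without Numba using only ctypes: the callback below casts its `void *` argument to `uint64_t *` and runs a pure-Python translation of `jsf_next` on it. This is a hypothetical stand-in to illustrate the calling convention (the `py_`/`ct_` names are invented to avoid clashing with the notebook), and it is far slower than the Numba version.

```python
import ctypes

MASK64 = (1 << 64) - 1


def py_rotate64(x, k):
    """Rotate a 64-bit value left by k bits (0 < k < 64)."""
    return ((x << k) | (x >> (64 - k))) & MASK64


def py_jsf_next(state):
    """Pure-Python JSF step (p=7, q=13, r=37), updating state in place."""
    e = (state[0] - py_rotate64(state[1], 7)) & MASK64
    state[0] = state[1] ^ py_rotate64(state[2], 13)
    state[1] = (state[2] + py_rotate64(state[3], 37)) & MASK64
    state[2] = (state[3] + e) & MASK64
    state[3] = (e + state[0]) & MASK64
    return state[3]


# C signature: uint64_t next_64(void *state)
NEXT_64_FUNC = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_void_p)


@NEXT_64_FUNC
def ctypes_next_64(void_p):
    # Counterpart of carray(st, (5,), dtype=np.uint64): view void * as uint64_t *
    st = ctypes.cast(void_p, ctypes.POINTER(ctypes.c_uint64))
    return py_jsf_next(st)


# Buffer laid out like NumbaJSF._state: four JSF words plus the has_uint/uinteger word
ct_seed = 0x0ED13F1411B75E77  # arbitrary 64-bit seed
ct_state = (ctypes.c_uint64 * 5)(0xF1EA5EED, ct_seed, ct_seed, ct_seed, 0)
for _ in range(20):
    py_jsf_next(ct_state)  # same 20-round warm-up as NumbaJSF.seed

ct_address = ctypes.addressof(ct_state)
print([ctypes_next_64(ct_address) for _ in range(3)])
```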

We start by instantiating the class and taking a look at the initial state.

[7]:
# From random.org
state = np.array([0x77, 0x5E, 0xB7, 0x11, 0x14, 0x3F, 0xD1, 0x0E], dtype=np.uint8).view(
    np.uint64
)[0]
njsf = NumbaJSF(state)
njsf.state_getter()
[7]:
{'bit_gen': 'NumbaJSF',
 'state': array([ 1167245051188668936, 13259944246262022926,  8870424784319794977,
         9596734350428388680], dtype=uint64),
 'has_uint': 0,
 'uinteger': 0}

from_cfunc is then used to pass the CFuncs, the state address pointer, and the state getter and setter to UserBitGenerator. We see that the state changes after calling random_raw.

[8]:
jsf_ubg = UserBitGenerator.from_cfunc(
    njsf.next_raw,
    njsf.next_64,
    njsf.next_32,
    njsf.next_double,
    njsf.state_address,
    state_getter=njsf.state_getter,
    state_setter=njsf.state_setter,
)
print(jsf_ubg.state)
print(jsf_ubg.random_raw(2))
print(jsf_ubg.state)
{'bit_gen': 'NumbaJSF', 'state': array([ 1167245051188668936, 13259944246262022926,  8870424784319794977,
        9596734350428388680], dtype=uint64), 'has_uint': 0, 'uinteger': 0}
[ 602963287911976729 5264292724725465572]
{'bit_gen': 'NumbaJSF', 'state': array([ 530704699024515781, 2740075917084007745, 5336551313612926520,
       5264292724725465572], dtype=uint64), 'has_uint': 0, 'uinteger': 0}

Some Generator methods consume 32 bits at a time to save bits; random with dtype=np.float32 is one. After calling it, has_uint is 1 because the unused upper half of the 64-bit draw has been stored for the next 32-bit request.

[9]:
gen = Generator(jsf_ubg)
print(f"A 32-bit float: {gen.random(dtype=np.float32)}")
print("Notice has_uint is now 1")
jsf_ubg.state
A 32-bit float: 0.16430795192718506
Notice has_uint is now 1
[9]:
{'bit_gen': 'NumbaJSF',
 'state': array([13952735719045862400, 12103276313412614439,  5553417437478470678,
        14241860431798867506], dtype=uint64),
 'has_uint': 1,
 'uinteger': 3315941531}
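NumbaJSF packs this 32-bit buffer into the fifth state word: bit 0 is the has_uint flag and bits 32 to 63 hold the saved upper half. A pure-Python sketch of the same next_32 logic, with a hypothetical `raw_source` callable standing in for `jsf_next` and a one-element list standing in for `state[4]`, shows why has_uint alternates between 1 and 0:

```python
MASK32 = 0xFFFFFFFF


def make_next_32(raw_source):
    """Return a next_32 that serves each 64-bit draw as two 32-bit values.

    buffer[0] mimics state[4]: bit 0 is has_uint, bits 32..63 hold the
    saved upper half of the last 64-bit draw.
    """
    buffer = [0]

    def next_32():
        if buffer[0] & 1:
            out = buffer[0] >> 32  # upper half saved by the previous call
            buffer[0] = 0  # clear the flag and the stored value
            return out
        z = raw_source()
        buffer[0] = z | 1  # keep z's upper half, set has_uint in bit 0
        return z & MASK32  # return the lower half now

    return next_32, buffer


toy_draws = iter([0x1111111122222222, 0x3333333344444444])
toy_next_32, toy_buffer = make_next_32(lambda: next(toy_draws))
print(hex(toy_next_32()), toy_buffer[0] & 1)  # 0x22222222 1
print(hex(toy_next_32()), toy_buffer[0] & 1)  # 0x11111111 0
print(hex(toy_next_32()))  # 0x44444444
```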

Performance

We can use the random_raw function to assess performance and compare it with the C implementation of JSF. The Numba version is only about 5% slower, which is an impressive outcome.

[10]:
%timeit jsf_ubg.random_raw(1000000)
4.4 ms ± 62.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[11]:
from randomgen import JSF

jsf = JSF()
%timeit jsf.random_raw(1000000)
4.19 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
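The relative figures quoted in the two performance discussions follow directly from the reported mean timings; this small calculation reproduces them.

```python
# Mean timings reported in the cells, in seconds
python_pcg_1000 = 3.08e-3  # pure-Python PCG64, 1,000 draws
c_pcg_1000 = 4.55e-6  # C PCG64, 1,000 draws
numba_jsf_1e6 = 4.4e-3  # Numba JSF, 1,000,000 draws
c_jsf_1e6 = 4.19e-3  # C JSF, 1,000,000 draws

python_ratio = python_pcg_1000 / c_pcg_1000
numba_overhead = numba_jsf_1e6 / c_jsf_1e6 - 1
print(f"pure Python is ~{python_ratio:.0f}x slower")  # ~677x
print(f"Numba is ~{numba_overhead:.1%} slower")  # ~5.0%
```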

Next, we clone the state of the native JSF into the Numba implementation.

[12]:
jsf_state = jsf.state
jsf_state
[12]:
{'bit_generator': 'JSF',
 'state': {'a': 17190901158427765818,
  'b': 14501513697102443756,
  'c': 15715724510248929625,
  'd': 12712143389959007425,
  'p': 7,
  'q': 13,
  'r': 37},
 'size': 64,
 'has_uint32': 0,
 'uinteger': 0,
 'seed_size': 1}

While the structure of the state differs, the four JSF state values [a, b, c, d] map directly onto the first four elements of the Numba state array.

[13]:
st = jsf_ubg.state
# Clone the C implementation's state and set it
st["state"][:4] = [jsf_state["state"][key] for key in ("a", "b", "c", "d")]
# The flag is named has_uint in NumbaJSF but has_uint32 in JSF
st["has_uint"] = jsf_state["has_uint32"]
st["uinteger"] = jsf_state["uinteger"]
jsf_ubg.state = st
jsf_ubg.state
[13]:
{'bit_gen': 'NumbaJSF',
 'state': array([17190901158427765818, 14501513697102443756, 15715724510248929625,
        12712143389959007425], dtype=uint64),
 'has_uint': 0,
 'uinteger': 0}
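The field mapping can be captured in a small helper. `jsf_to_numba_state` is a hypothetical name, not part of randomgen; it assumes the two dictionary layouts shown above (note that the native dictionary calls the flag has_uint32 while NumbaJSF calls it has_uint) and returns a dictionary that NumbaJSF's state_setter accepts, since that setter converts the `state` list with `np.uint64`.

```python
def jsf_to_numba_state(jsf_state):
    """Convert a native 64-bit JSF state dict into the NumbaJSF layout."""
    inner = jsf_state["state"]
    # NumbaJSF hard-codes 64-bit words and p=7, q=13, r=37
    if jsf_state["size"] != 64 or (inner["p"], inner["q"], inner["r"]) != (7, 13, 37):
        raise ValueError("NumbaJSF only supports 64-bit JSF with p=7, q=13, r=37")
    return {
        "bit_gen": "NumbaJSF",
        "state": [inner[key] for key in ("a", "b", "c", "d")],
        "has_uint": jsf_state["has_uint32"],  # key name differs
        "uinteger": jsf_state["uinteger"],
    }


native = {
    "bit_generator": "JSF",
    "state": {
        "a": 17190901158427765818,
        "b": 14501513697102443756,
        "c": 15715724510248929625,
        "d": 12712143389959007425,
        "p": 7,
        "q": 13,
        "r": 37,
    },
    "size": 64,
    "has_uint32": 0,
    "uinteger": 0,
    "seed_size": 1,
}
print(jsf_to_numba_state(native))
```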

Finally, we can look at the next few values to confirm that the two implementations produce identical output.

[14]:
jsf_ubg.random_raw(5)
[14]:
array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,
       17987378307908897868, 18034113569054765009], dtype=uint64)
[15]:
jsf.random_raw(5)
[15]:
array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,
       17987378307908897868, 18034113569054765009], dtype=uint64)