Writing a BitGenerator
The standard method to write a bit generator involves writing a Cython pyx file that wraps some C source code containing a high-performance implementation of a pseudo-random number generator (PRNG). This gives best-case performance without creating any external dependencies.
UserBitGenerator provides a simple wrapper class that allows users to write bit generators in pure Python or, if performance is an issue, using Cython or by accessing functions in a compiled library (e.g., a DLL).
Here we examine the steps needed to write a pure Python bit generator and a higher-performance generator using Numba.
Using Python
The example here begins by writing a class that implements the PCG64 bit generator using the XSL-RR output transformation. While this is not a complete implementation (it does not support advance or seed), it is simple. The key to understanding PCG is that the underlying state is updated using a Linear Congruential Generator (LCG) that uses a 128-bit state, multiplier and increment. The state evolves according to
\[s_{n+1} = (m s_n + i) \bmod 2^{128}\]
where \(s\) is the state, \(m\) is the multiplier and \(i\) is the increment.
The PCG generator then transforms \(s_{n+1}\) to produce the final output. The XSL-RR output function XORs the upper 64 bits of the state with the lower 64 bits and then applies a random rotation to this value, where the rotation amount is taken from the top 6 bits of the state.
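The two steps described above can be sketched with plain Python integers before any NumPy is involved. This is a minimal illustration rather than the library's implementation; the helper names `lcg_step`, `rotr64` and `xsl_rr` are made up for this sketch, and the small inputs are chosen so the results can be checked by hand.

```python
MASK64 = (1 << 64) - 1
MASK128 = (1 << 128) - 1
# The 128-bit multiplier used by the class in this section
MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341


def lcg_step(state, inc, mult=MULTIPLIER):
    """Advance the 128-bit LCG state by one step."""
    return (state * mult + inc) & MASK128


def rotr64(value, rot):
    """Rotate a 64-bit integer right by rot bits (0 <= rot < 64)."""
    return ((value >> rot) | (value << ((64 - rot) & 63))) & MASK64


def xsl_rr(state):
    """XOR the two 64-bit halves, then rotate by the top 6 bits of the state."""
    folded = (state >> 64) ^ (state & MASK64)
    return rotr64(folded, state >> 122)


# Tiny worked cases that can be checked by hand
print(lcg_step(1, 3, mult=5))   # 5 * 1 + 3 = 8
print(hex(rotr64(1, 1)))        # the low bit wraps around to bit 63
print(xsl_rr((7 << 64) | 7))    # equal halves XOR to 0
```

Using Python's unbounded integers with explicit masks keeps the arithmetic exact, at the cost of speed.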
The code below implements this generator using built-in Python operations and a little NumPy.
[1]:
import numpy as np


# Rotate a 64-bit value right by rot bits
def rotr_64(value, rot):
    value = np.uint64(value)
    rot = np.uint64(rot)
    # (64 - rot) & 63 avoids negating an unsigned integer
    left = (np.uint64(64) - rot) & np.uint64(63)
    return int((value >> rot) | (value << left))


class PythonPCG64:
    # A 128 bit multiplier
    PCG_DEFAULT_MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341
    MODULUS = 2**128

    def __init__(self, state, inc):
        """Directly set the state and increment, no seed support"""
        self.state = state
        self.inc = inc
        self._has_uint32 = False
        self._uinteger = 0
        self._next_32 = self._next_64 = None

    def random_raw(self):
        """Generate the next "raw" value, which is 64 bits"""
        state = self.state * self.PCG_DEFAULT_MULTIPLIER + self.inc
        state = state % self.MODULUS
        self.state = state
        # XOR the two halves, then rotate by the top 6 bits of the state
        return rotr_64((state >> 64) ^ (state & 0xFFFFFFFFFFFFFFFF), state >> 122)

    @property
    def next_64(self):
        """
        Return a callable that accepts a single input. The input is usually
        a void pointer that is cast to a struct that contains the PRNG's
        state. When writing a bit generator in Python, it is simpler to use
        a closure than to wrap the state in an array, pass its address as a
        ctypes void pointer, and then to get the pointer in the function.
        """

        def _next_64(void_p):
            return self.random_raw()

        # Ensure a reference is held
        self._next_64 = _next_64
        return _next_64

    @property
    def next_32(self):
        """
        Return a callable that accepts a single input. This is identical to
        ``next_64`` except that it returns a 32-bit unsigned int. Here we save
        half of the raw 64 bit output for subsequent calls.
        """

        def _next_32(void_p):
            if self._has_uint32:
                self._has_uint32 = False
                return self._uinteger
            next_value = self.random_raw()
            self._has_uint32 = True
            # Keep the upper 32 bits for the next call
            self._uinteger = next_value >> 32
            return next_value & 0xFFFFFFFF

        # Ensure a reference is held
        self._next_32 = _next_32
        return _next_32

    @property
    def state_getter(self):
        def f():
            return {"state": self.state, "inc": self.inc}

        return f

    @property
    def state_setter(self):
        def f(value):
            self.state = value["state"]
            self.inc = value["inc"]

        return f
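The 32-bit buffering used by next_32 can be seen in isolation with a small stand-alone sketch. The class name `Split32` and the fixed list of draws are hypothetical and exist only for illustration; the logic follows the pattern in the class above, returning the lower half first and saving the upper half.

```python
class Split32:
    """Serve 32-bit values by splitting each 64-bit draw in two."""

    def __init__(self, next_64):
        self._next_64 = next_64  # callable returning a 64-bit int
        self._has_uint32 = False
        self._uinteger = 0

    def next_32(self):
        if self._has_uint32:
            # Hand out the half saved by the previous call
            self._has_uint32 = False
            return self._uinteger
        value = self._next_64()
        self._has_uint32 = True
        self._uinteger = value >> 32  # save the upper half
        return value & 0xFFFFFFFF     # return the lower half


# Hypothetical source that yields two known 64-bit values
draws = iter([0x1111111122222222, 0x3333333344444444])
splitter = Split32(lambda: next(draws))
out = [splitter.next_32() for _ in range(4)]
print([hex(v) for v in out])
# ['0x22222222', '0x11111111', '0x44444444', '0x33333333']
```

Only two 64-bit draws are consumed to serve four 32-bit requests.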
Next, we use UserBitGenerator to expose the Python functions to C. The Python functions are wrapped in ctypes callbacks under the hood.
[2]:
from randomgen import PCG64, UserBitGenerator
pcg = PCG64(0, mode="sequence", variant="xsl-rr")
state, inc = pcg.state["state"]["state"], pcg.state["state"]["inc"]
print("Get the state from a seeded PCG64")
print(pcg.state["state"])
prng = PythonPCG64(state, inc)
print("State and increment are identical")
print(prng.state_getter())
python_pcg = UserBitGenerator(prng.next_64, 64, next_32=prng.next_32)
print("First 5 values from PythonPCG64")
print(python_pcg.random_raw(5))
print("Match official C version")
print(pcg.random_raw(5))
Get the state from a seeded PCG64
{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}
State and increment are identical
{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}
First 5 values from PythonPCG64
[11749869230777074271 4976686463289251617 755828109848996024
304881062738325533 15002187965291974971]
Match official C version
[11749869230777074271 4976686463289251617 755828109848996024
304881062738325533 15002187965291974971]
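To give a feel for what wrapping a Python function in a ctypes callback means, here is a minimal stand-alone sketch using only the standard library. It is not the code UserBitGenerator runs internally; the signature chosen here (a function taking a void pointer and returning a 64-bit unsigned integer) and the helper `make_counter` are assumptions made for illustration.

```python
import ctypes

# C signature: uint64_t (*)(void *state)
NEXT_64_TYPE = ctypes.CFUNCTYPE(ctypes.c_uint64, ctypes.c_void_p)


def make_counter(start):
    """Return a closure that ignores the pointer and keeps its own state."""
    count = start

    def _next_64(void_p):
        nonlocal count
        count += 1
        return count

    return _next_64


# Keep a reference to the callback object for as long as C code may call it
callback = NEXT_64_TYPE(make_counter(41))
first = callback(None)   # None is passed as a NULL void pointer
second = callback(None)
print(first, second)     # 42 43
```

Every call crosses from C back into the Python interpreter, so a callback built this way is convenient but carries per-call overhead.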
python_pcg is a bit generator, and so can be used with a NumPy Generator. Here we see the state changes after producing a single standard normal.
[3]:
from numpy.random import Generator
gen = Generator(python_pcg)
print(f"Before: {prng.state}")
print(f"Std. Normal : {gen.standard_normal()}")
print(f"After: {prng.state}")
Before: 133411349017971402732463711865589153492
Std. Normal : 0.36159505490948474
After: 9405893610231781608176235507540826829
Accessing python_pcg.state would raise NotImplementedError. It is possible to wire up this function by setting state_setter and state_getter in UserBitGenerator. These both take callable functions.
This time the state_getter and state_setter are used so that the state can be read and set through the bit generator.
[4]:
python_pcg = UserBitGenerator(
prng.next_64,
64,
next_32=prng.next_32,
state_getter=prng.state_getter,
state_setter=prng.state_setter,
)
python_pcg.state
[4]:
{'state': 9405893610231781608176235507540826829,
'inc': 87136372517582989555478159403783844777}
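A getter and setter pair like this makes it possible to snapshot a generator and replay it later. The sketch below shows the pattern on its own with a hypothetical `TinyLCG` class (a 64-bit LCG that is not part of randomgen), so it runs without any third-party package.

```python
class TinyLCG:
    """Hypothetical 64-bit LCG, used only to illustrate state handling."""

    def __init__(self, state, inc):
        self.state = state
        self.inc = inc

    def random_raw(self):
        self.state = (self.state * 6364136223846793005 + self.inc) % 2**64
        return self.state

    @property
    def state_getter(self):
        def f():
            return {"state": self.state, "inc": self.inc}

        return f

    @property
    def state_setter(self):
        def f(value):
            self.state = value["state"]
            self.inc = value["inc"]

        return f


gen = TinyLCG(12345, 1)
saved = gen.state_getter()   # snapshot before drawing
first = [gen.random_raw() for _ in range(3)]
gen.state_setter(saved)      # rewind to the snapshot
replay = [gen.random_raw() for _ in range(3)]
print(first == replay)       # True
```

Returning a fresh dict from the getter means the snapshot is not affected by later draws.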
Performance
We can time random_raw to see how fast (slow) the pure Python version is. In the run below it is roughly 700 times slower than the C implementation, which is close to 3 orders of magnitude.
[5]:
%timeit python_pcg.random_raw(1000)
%timeit pcg.random_raw(1000)
3.08 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.55 µs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Using numba
A bit generator implemented in Numba can be used through the UserBitGenerator.from_cfunc interface. The block below implements the JSF generator using Numba.
The key outputs of the manager class are next_64, next_32, and next_double, which are all decorated using Numba's @cfunc.
[6]:
import ctypes

import numpy as np
from numba import cfunc, types, carray, jit
from randomgen.wrapper import UserBitGenerator

rotate64_sig = types.uint64(types.uint64, types.int_)


@jit(signature_or_function=rotate64_sig, inline="always")
def rotate64(x, k):
    return (x << k) | (x >> (64 - k))


jsf_next_sig = types.uint64(types.uint64[:])


@jit(signature_or_function=jsf_next_sig, inline="always")
def jsf_next(state):
    """
    Update the state in place

    This is a literal translation of the C code where the values p, q,
    and r are fixed.
    """
    # Default values
    p = 7
    q = 13
    r = 37
    # Update
    e = state[0] - rotate64(state[1], p)
    state[0] = state[1] ^ rotate64(state[2], q)
    state[1] = state[2] + (rotate64(state[3], r) if r else state[3])
    state[2] = state[3] + e
    state[3] = e + state[0]
    return state[3]


class NumbaJSF:
    def __init__(self, seed):
        if not isinstance(seed, (int, np.integer)) or not (0 <= seed < 2**64):
            raise ValueError("seed must be a valid uint64")
        # state[0:4] is the JSF state
        # state[4] holds the has_uint flag in bit 0 and
        # uinteger in bits 32...63
        self._state = np.zeros(5, dtype=np.uint64)
        self._next_raw = None
        self._next_64 = None
        self._next_32 = None
        self._next_double = None
        self.seed(seed)

    def seed(self, value):
        self._state[0] = 0xF1EA5EED
        self._state[1] = value
        self._state[2] = value
        self._state[3] = value
        # Warm up the state
        for _ in range(20):
            jsf_next(self._state)
    @property
    def state_address(self):
        """Get the location in memory of the state NumPy array."""
        return self._state.ctypes.data_as(ctypes.c_void_p)

    @property
    def next_64(self):
        """Same as raw since a 64 bit generator"""
        # Ensure a reference is held to the function that is returned
        self._next_64 = self.next_raw
        return self._next_64

    @property
    def next_32(self):
        """A CFunc generating the next 32 bits"""
        sig = types.uint32(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_32(st):
            # Get the NumPy uint64 array
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # We use the first bit to indicate that 32 bits are stored in 32...63
            if bit_gen_state[4] & np.uint64(0x1):
                # Get the upper 32 bits
                out = bit_gen_state[4] >> np.uint64(32)
                # Clear the stored value
                bit_gen_state[4] = 0
                return out
            # If no bits available, generate a new value
            z = jsf_next(bit_gen_state)
            # Store the new value always with 1 in bit 0
            bit_gen_state[4] = z | np.uint64(0x1)
            # Return the lower 32 (0...31)
            return z & 0xFFFFFFFF

        # Ensure a reference is held
        self._next_32 = next_32
        return next_32

    @property
    def next_double(self):
        """A CFunc that generates the next double"""
        sig = types.double(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_double(st):
            # Get the state
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # Return the top 53 bits of the next value divided by 2**53
            return (
                np.uint64(jsf_next(bit_gen_state)) >> np.uint64(11)
            ) / 9007199254740992.0

        # Ensure a reference is held
        self._next_double = next_double
        return next_double

    @property
    def next_raw(self):
        """A CFunc generating the next raw 64-bit value"""
        sig = types.uint64(types.CPointer(types.uint64))

        @cfunc(sig)
        def next_64(st):
            # Get the NumPy array containing the state
            bit_gen_state = carray(st, (5,), dtype=np.uint64)
            # Return the next value
            return jsf_next(bit_gen_state)

        # Ensure a reference is held
        self._next_raw = next_64
        return next_64

    @property
    def state_getter(self):
        """A function that returns the state. This is Python and is not decorated"""

        def f() -> dict:
            return {
                "bit_gen": type(self).__name__,
                "state": self._state[:4],
                "has_uint": self._state[4] & np.uint64(0x1),
                "uinteger": self._state[4] >> np.uint64(32),
            }

        return f

    @property
    def state_setter(self):
        """A function that sets the state. This is Python and is not decorated"""

        def f(value: dict):
            name = value.get("bit_gen", None)
            if name != type(self).__name__:
                raise ValueError(f"state must be from a {type(self).__name__}")
            self._state[:4] = np.uint64(value["state"])
            # Repack uinteger into bits 32...63 and has_uint into bit 0
            temp = np.uint64(value["uinteger"]) << np.uint64(32)
            temp |= np.uint64(value["has_uint"]) & np.uint64(0x1)
            self._state[4] = temp

        return f
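The conversion used in next_double, which keeps the top 53 bits of a 64-bit draw and divides by 2**53, can be checked with plain Python integers. The function name `uint64_to_double` is made up for this sketch.

```python
def uint64_to_double(value):
    """Map a 64-bit unsigned integer to a float in [0, 1)."""
    # Keep the top 53 bits, then scale by 2**-53
    return (value >> 11) / 9007199254740992.0


print(uint64_to_double(0))                # 0.0
print(uint64_to_double(1 << 63))          # 0.5
print(uint64_to_double(2**64 - 1) < 1.0)  # True
```

A double has a 53-bit significand, so every one of the 2**53 possible results is represented exactly and the largest value stays strictly below 1.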
We start by instantiating the class and taking a look at the initial state.
[7]:
# From random.org
state = np.array([0x77, 0x5E, 0xB7, 0x11, 0x14, 0x3F, 0xD1, 0x0E], dtype=np.uint8).view(
np.uint64
)[0]
njsf = NumbaJSF(state)
njsf.state_getter()
[7]:
{'bit_gen': 'NumbaJSF',
'state': array([ 1167245051188668936, 13259944246262022926, 8870424784319794977,
9596734350428388680], dtype=uint64),
'has_uint': 0,
'uinteger': 0}
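For readers who want to follow the update rule without Numba, the sketch below is a pure Python translation of jsf_next and the seeding loop, with explicit 64-bit masking in place of fixed-width arithmetic. The names `rotl64` and `jsf_seed` are introduced here for illustration; the constants p, q and r are the ones used above.

```python
MASK64 = (1 << 64) - 1


def rotl64(x, k):
    """Rotate a 64-bit integer left by k bits (0 < k < 64)."""
    return ((x << k) | (x >> (64 - k))) & MASK64


def jsf_next(state, p=7, q=13, r=37):
    """Update a 4-element list in place and return the new output."""
    a, b, c, d = state
    e = (a - rotl64(b, p)) & MASK64
    a = b ^ rotl64(c, q)
    b = (c + rotl64(d, r)) & MASK64
    c = (d + e) & MASK64
    d = (e + a) & MASK64
    state[:] = [a, b, c, d]
    return d


def jsf_seed(seed):
    """Seed as in the class above: fixed first word, then 20 warm-up rounds."""
    state = [0xF1EA5EED, seed, seed, seed]
    for _ in range(20):
        jsf_next(state)
    return state


# The same 8 bytes as above, read as a little-endian 64-bit integer
seed = int.from_bytes(bytes([0x77, 0x5E, 0xB7, 0x11, 0x14, 0x3F, 0xD1, 0x0E]), "little")
print(jsf_seed(seed))
```

If the translation is faithful, and on a little-endian machine, the printed list should agree with the state array shown above. That agreement has not been verified here, so treat it as a cross-check to run rather than a guaranteed result.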
from_cfunc is then used to pass the CFuncs, the state address pointer and the state getter and setter to UserBitGenerator. We see that the state changes after calling random_raw.
[8]:
jsf_ubg = UserBitGenerator.from_cfunc(
njsf.next_raw,
njsf.next_64,
njsf.next_32,
njsf.next_double,
njsf.state_address,
state_getter=njsf.state_getter,
state_setter=njsf.state_setter,
)
print(jsf_ubg.state)
print(jsf_ubg.random_raw(2))
print(jsf_ubg.state)
{'bit_gen': 'NumbaJSF', 'state': array([ 1167245051188668936, 13259944246262022926, 8870424784319794977,
9596734350428388680], dtype=uint64), 'has_uint': 0, 'uinteger': 0}
[ 602963287911976729 5264292724725465572]
{'bit_gen': 'NumbaJSF', 'state': array([ 530704699024515781, 2740075917084007745, 5336551313612926520,
5264292724725465572], dtype=uint64), 'has_uint': 0, 'uinteger': 0}
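The state changes because the compiled functions receive only the address of the state buffer and write through it. The standard-library sketch below mimics that arrangement with a ctypes array standing in for the NumPy array; it illustrates the idea and is not how from_cfunc is implemented.

```python
import ctypes

# Five 64-bit words, mirroring the layout used above (an illustrative
# stand-in for the NumPy array)
state = (ctypes.c_uint64 * 5)(1, 2, 3, 4, 0)

# A void pointer to the first element, similar in spirit to state_address
address = ctypes.cast(state, ctypes.c_void_p)

# Code that receives only the address can recover a typed view
view = ctypes.cast(address, ctypes.POINTER(ctypes.c_uint64))
view[3] = 99  # write through the pointer

print(list(state))  # [1, 2, 3, 99, 0]
```

Because the pointer does not own the memory, the object that does own it must stay alive for as long as the pointer is in use.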
Some Generator functions use 32-bit integers to save bits. random with dtype=np.float32 is one. After calling this function we see that has_uint is now 1.
[9]:
gen = Generator(jsf_ubg)
print(f"A 32-bit float: {gen.random(dtype=np.float32)}")
print("Notice has_uint is now 1")
jsf_ubg.state
A 32-bit float: 0.16430795192718506
Notice has_uint is now 1
[9]:
{'bit_gen': 'NumbaJSF',
'state': array([13952735719045862400, 12103276313412614439, 5553417437478470678,
14241860431798867506], dtype=uint64),
'has_uint': 1,
'uinteger': 3315941531}
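The has_uint and uinteger values above both live in the fifth word of the state array. A short sketch of that packing with Python integers follows; the helper names `pack` and `unpack` are made up for illustration.

```python
def pack(has_uint, uinteger):
    """Store the flag in bit 0 and the 32-bit value in bits 32...63."""
    return ((uinteger & 0xFFFFFFFF) << 32) | (has_uint & 0x1)


def unpack(word):
    """Recover (has_uint, uinteger) from the packed 64-bit word."""
    return word & 0x1, word >> 32


word = pack(1, 3315941531)
print(unpack(word))               # (1, 3315941531)
# Bits 1...31 are ignored when unpacking, so leftover bits there are harmless
print(unpack(word | 0xFFFFFFFE))  # (1, 3315941531)
```

This is why next_32 can store the whole 64-bit draw with bit 0 set: only bit 0 and the upper 32 bits are ever read back.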
Performance
We can use random_raw to assess the performance and compare it to the C implementation, JSF. In the runs below it is about 5% slower, which is an impressive outcome.
[10]:
%timeit jsf_ubg.random_raw(1000000)
4.4 ms ± 62.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[11]:
from randomgen import JSF
jsf = JSF()
%timeit jsf.random_raw(1000000)
4.19 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Next, we will clone the state of the native JSF to the Numba implementation.
[12]:
jsf_state = jsf.state
jsf_state
[12]:
{'bit_generator': 'JSF',
'state': {'a': 17190901158427765818,
'b': 14501513697102443756,
'c': 15715724510248929625,
'd': 12712143389959007425,
'p': 7,
'q': 13,
'r': 37},
'size': 64,
'has_uint32': 0,
'uinteger': 0,
'seed_size': 1}
While the structure of the state is different, the core values are the same: the native a, b, c and d correspond to the four entries of the Numba state array, [a, b, c, d].
[13]:
st = jsf_ubg.state
# Clone the C implementation's state and set it
st["state"][:4] = [jsf_state["state"][key] for key in ("a", "b", "c", "d")]
# The Numba class names this flag has_uint rather than has_uint32
st["has_uint"] = jsf_state["has_uint32"]
st["uinteger"] = jsf_state["uinteger"]
jsf_ubg.state = st
jsf_ubg.state
[13]:
{'bit_gen': 'NumbaJSF',
'state': array([17190901158427765818, 14501513697102443756, 15715724510248929625,
12712143389959007425], dtype=uint64),
'has_uint': 0,
'uinteger': 0}
Finally, we can take a look at the next few values to show that the implementations of the two generators are identical.
[14]:
jsf_ubg.random_raw(5)
[14]:
array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,
17987378307908897868, 18034113569054765009], dtype=uint64)
[15]:
jsf.random_raw(5)
[15]:
array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,
17987378307908897868, 18034113569054765009], dtype=uint64)