{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Writing a BitGenerator\n", "\n", "The standard method to write a bit generator involves writing a Cython `pyx`\n", "file that wraps some C source code containing a high-performance\n", "implementation of a pseudorandom number generator (PRNG). This leads to\n", "best-case performance without creating any external dependencies.\n", "\n", "`UserBitGenerator` provides a simple wrapper class that allows users to write\n", "bit generators in pure Python or, if performance is an issue, using Cython or by\n", "accessing functions in a compiled library (e.g., a DLL).\n", "\n", "Here we examine the steps needed to write a pure Python bit\n", "generator and a higher-performance generator using Numba." ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Using Python\n", "\n", "The example here begins by writing a class that implements the PCG64\n", "bit generator using the XSL-RR output transformation. While this is not\n", "a complete implementation (it does not support `advance` or `seed`),\n", "it is simple. The key to understanding PCG is that the underlying state\n", "is updated using a [Linear Congruential Generator (LCG)](https://en.wikipedia.org/wiki/Linear_congruential_generator)\n", "that uses a 128-bit state, multiplier and increment. The state evolves according to\n", "\n", "$$ s_{n+1} = (m s_{n} + i) \\bmod 2^{128} $$\n", "\n", "where $s$ is the state, $m$ is the multiplier and $i$ is the increment.\n", "\n", "The PCG generator then transforms $s_{n+1}$ to produce the final output. The XSL-RR\n", "output function XORs the upper 64 bits of the state with the lower\n", "64 bits and then rotates the result right by an amount taken from the\n", "top 6 bits of the state.\n", "\n", "The code below implements this generator using built-in Python operations\n", "and a little NumPy." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "# Rotate a 64-bit value right by rot bits\n", "def rotr_64(value, rot):\n", "    value = np.uint64(value)\n", "    rot = np.uint64(rot)\n", "    return int((value >> rot) | (value << ((np.uint64(64) - rot) & np.uint64(63))))\n", "\n", "\n", "class PythonPCG64:\n", "    # A 128 bit multiplier\n", "    PCG_DEFAULT_MULTIPLIER = (2549297995355413924 << 64) + 4865540595714422341\n", "    MODULUS = 2**128\n", "\n", "    def __init__(self, state, inc):\n", "        \"\"\"Directly set the state and increment, no seed support\"\"\"\n", "        self.state = state\n", "        self.inc = inc\n", "        self._has_uint32 = False\n", "        self._uinteger = 0\n", "        self._next_32 = self._next_64 = None\n", "\n", "    def random_raw(self):\n", "        \"\"\"Generate the next \"raw\" value, which is 64 bits\"\"\"\n", "        state = self.state * self.PCG_DEFAULT_MULTIPLIER + self.inc\n", "        state = state % self.MODULUS\n", "        self.state = state\n", "        return rotr_64((state >> 64) ^ (state & 0xFFFFFFFFFFFFFFFF), state >> 122)\n", "\n", "    @property\n", "    def next_64(self):\n", "        \"\"\"\n", "        Return a callable that accepts a single input. The input is usually\n", "        a void pointer that is cast to a struct that contains the PRNG's\n", "        state. When writing a bit generator in Python, it is simpler to use\n", "        a closure than to wrap the state in an array, pass its address as a\n", "        ctypes void pointer, and then to get the pointer in the function.\n", "        \"\"\"\n", "\n", "        def _next_64(void_p):\n", "            return self.random_raw()\n", "\n", "        self._next_64 = _next_64\n", "        return _next_64\n", "\n", "    @property\n", "    def next_32(self):\n", "        \"\"\"\n", "        Return a callable that accepts a single input. This is identical to\n", "        ``next_64`` except that it returns a 32-bit unsigned int. 
Here we save\n", "        half of the raw 64 bit output for subsequent calls.\n", "        \"\"\"\n", "\n", "        def _next_32(void_p):\n", "            if self._has_uint32:\n", "                self._has_uint32 = False\n", "                return self._uinteger\n", "            next_value = self.random_raw()\n", "            self._has_uint32 = True\n", "            self._uinteger = next_value >> 32\n", "            return next_value & 0xFFFFFFFF\n", "\n", "        self._next_32 = _next_32\n", "        return _next_32\n", "\n", "    @property\n", "    def state_getter(self):\n", "        def f():\n", "            return {\"state\": self.state, \"inc\": self.inc}\n", "\n", "        return f\n", "\n", "    @property\n", "    def state_setter(self):\n", "        def f(value):\n", "            self.state = value[\"state\"]\n", "            self.inc = value[\"inc\"]\n", "\n", "        return f" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Next, we use `UserBitGenerator` to expose the Python functions to C.\n", "The Python functions are wrapped in ctypes callbacks under the hood." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Get the state from a seeded PCG64\n", "{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}\n", "State and increment are identical\n", "{'state': 35399562948360463058890781895381311971, 'inc': 87136372517582989555478159403783844777}\n", "First 5 values from PythonPCG64\n", "[11749869230777074271 4976686463289251617 755828109848996024\n", " 304881062738325533 15002187965291974971]\n", "Match official C version\n", "[11749869230777074271 4976686463289251617 755828109848996024\n", " 304881062738325533 15002187965291974971]\n" ] } ], "source": [ "from randomgen import PCG64, UserBitGenerator\n", "\n", "pcg = PCG64(0, mode=\"sequence\", variant=\"xsl-rr\")\n", "state, inc = pcg.state[\"state\"][\"state\"], pcg.state[\"state\"][\"inc\"]\n", "print(\"Get the state from a 
seeded PCG64\")\n", "print(pcg.state[\"state\"])\n", "prng = PythonPCG64(state, inc)\n", "print(\"State and increment are identical\")\n", "print(prng.state_getter())\n", "python_pcg = UserBitGenerator(prng.next_64, 64, next_32=prng.next_32)\n", "print(\"First 5 values from PythonPCG64\")\n", "print(python_pcg.random_raw(5))\n", "print(\"Match official C version\")\n", "print(pcg.random_raw(5))" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "`python_pcg` _is_ a bit generator, and so can be used with a NumPy `Generator`.\n", "Here we see the state changes after producing a single standard normal." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before: 133411349017971402732463711865589153492\n", "Std. Normal : 0.36159505490948474\n", "After: 9405893610231781608176235507540826829\n" ] } ], "source": [ "from numpy.random import Generator\n", "\n", "gen = Generator(python_pcg)\n", "print(f\"Before: {prng.state}\")\n", "print(f\"Std. Normal : {gen.standard_normal()}\")\n", "print(f\"After: {prng.state}\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Accessing `python_pcg.state` would raise `NotImplementedError`. It is possible to\n", "wire up this function by setting `state_setter` and `state_getter` in `UserBitGenerator`.\n", "These both take callable functions.\n", "\n", "This time the `state_getter` and `state_setter` are used so that the state can be read\n", "and set through the bit generator." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "{'state': 9405893610231781608176235507540826829,\n", " 'inc': 87136372517582989555478159403783844777}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "python_pcg = UserBitGenerator(\n", "    prng.next_64,\n", "    64,\n", "    next_32=prng.next_32,\n", "    state_getter=prng.state_getter,\n", "    state_setter=prng.state_setter,\n", ")\n", "python_pcg.state" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Performance\n", "We can time `random_raw` to see how fast (**slow**) the pure Python version is. In the timings below, generating 1,000 values takes about 3 ms in pure Python and about 4.5 µs in C, so the Python version is nearly 3 orders of magnitude (roughly 700x) slower than the C implementation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.08 ms ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", "4.55 µs ± 21.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" ] } ], "source": [ "%timeit python_pcg.random_raw(1000)\n", "%timeit pcg.random_raw(1000)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Using Numba\n", "\n", "A bit generator implemented in Numba can be used through the `UserBitGenerator.from_cfunc` interface. The block below implements the JSF generator using Numba. \n", "\n", "The key outputs of the manager class are `next_64`, `next_32`, and `next_double`, which are all decorated using Numba's `@cfunc`. 
\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import ctypes\n", "\n", "import numpy as np\n", "from numba import cfunc, types, carray, jit\n", "\n", "from randomgen.wrapper import UserBitGenerator\n", "\n", "\n", "rotate64_sig = types.uint64(types.uint64, types.int_)\n", "\n", "\n", "@jit(signature_or_function=rotate64_sig, inline=\"always\")\n", "def rotate64(x, k):\n", "    return (x << k) | (x >> (64 - k))\n", "\n", "\n", "jsf_next_sig = types.uint64(types.uint64[:])\n", "\n", "\n", "@jit(signature_or_function=jsf_next_sig, inline=\"always\")\n", "def jsf_next(state):\n", "    \"\"\"\n", "    Update the state in place\n", "\n", "    This is a literal translation of the C code where the values p, q,\n", "    and r are fixed.\n", "    \"\"\"\n", "    # Default values\n", "    p = 7\n", "    q = 13\n", "    r = 37\n", "    # Update\n", "    e = state[0] - rotate64(state[1], p)\n", "    state[0] = state[1] ^ rotate64(state[2], q)\n", "    state[1] = state[2] + (rotate64(state[3], r) if r else state[3])\n", "    state[2] = state[3] + e\n", "    state[3] = e + state[0]\n", "    return state[3]\n", "\n", "\n", "class NumbaJSF:\n", "    def __init__(self, seed):\n", "        if not isinstance(seed, (int, np.integer)) or not (0 <= seed < 2**64):\n", "            raise ValueError(\"seed must be a valid uint64\")\n", "        # state[0:4] is the JSF state\n", "        # state[4] contains both the has_uint flag in bit 0 and\n", "        # uinteger in bits 32...63\n", "        self._state = np.zeros(5, dtype=np.uint64)\n", "        self._next_raw = None\n", "        self._next_64 = None\n", "        self._next_32 = None\n", "        self._next_double = None\n", "        self.seed(seed)\n", "\n", "    def seed(self, value):\n", "        self._state[0] = 0xF1EA5EED\n", "        self._state[1] = value\n", "        self._state[2] = value\n", "        self._state[3] = value\n", "        for i in range(20):\n", "            jsf_next(self._state)\n", "\n", "    @property\n", "    def state_address(self):\n", "        \"\"\"Get the 
location in memory of the state NumPy array.\"\"\"\n", "        return self._state.ctypes.data_as(ctypes.c_void_p)\n", "\n", "    @property\n", "    def next_64(self):\n", "        \"\"\"Same as raw since a 64 bit generator\"\"\"\n", "\n", "        # Ensure a reference is held\n", "        self._next_64 = self.next_raw\n", "\n", "        return self._next_64\n", "\n", "    @property\n", "    def next_32(self):\n", "        \"\"\"A CFunc generating the next 32 bits\"\"\"\n", "        sig = types.uint32(types.CPointer(types.uint64))\n", "\n", "        @cfunc(sig)\n", "        def next_32(st):\n", "            # Get the NumPy uint64 array\n", "            bit_gen_state = carray(st, (5,), dtype=np.uint64)\n", "            # We use the first bit to indicate that 32 bits are stored in 32...63\n", "            if bit_gen_state[4] & np.uint64(0x1):\n", "                # Get the upper 32 bits\n", "                out = bit_gen_state[4] >> np.uint64(32)\n", "                # Clear the stored value\n", "                bit_gen_state[4] = 0\n", "                return out\n", "            # If no bits available, generate a new value\n", "            z = jsf_next(bit_gen_state)\n", "            # Store the new value always with 1 in bit 0\n", "            bit_gen_state[4] = z | np.uint64(0x1)\n", "            # Return the lower 32 (0...31)\n", "            return z & 0xFFFFFFFF\n", "\n", "        # Ensure a reference is held\n", "        self._next_32 = next_32\n", "\n", "        return next_32\n", "\n", "    @property\n", "    def next_double(self):\n", "        \"\"\"A CFunc that generates the next double\"\"\"\n", "        sig = types.double(types.CPointer(types.uint64))\n", "\n", "        @cfunc(sig)\n", "        def next_double(st):\n", "            # Get the state\n", "            bit_gen_state = carray(st, (5,), dtype=np.uint64)\n", "            # Return the top 53 bits of the next value divided by 2**53\n", "            return (\n", "                np.uint64(jsf_next(bit_gen_state)) >> np.uint64(11)\n", "            ) / 9007199254740992.0\n", "\n", "        # Ensure a reference is held\n", "        self._next_double = next_double\n", "\n", "        return next_double\n", "\n", "    @property\n", "    def next_raw(self):\n", "        sig = types.uint64(types.CPointer(types.uint64))\n", "\n", "        @cfunc(sig)\n", "        def next_64(st):\n", "            # Get the NumPy array containing the state\n", "            bit_gen_state = carray(st, 
(5,), dtype=np.uint64)\n", "            # Return the next value\n", "            return jsf_next(bit_gen_state)\n", "\n", "        # Ensure a reference is held\n", "        self._next_raw = next_64\n", "\n", "        return next_64\n", "\n", "    @property\n", "    def state_getter(self):\n", "        \"\"\"A function that returns the state. This is Python and is not decorated\"\"\"\n", "\n", "        def f() -> dict:\n", "            return {\n", "                \"bit_gen\": type(self).__name__,\n", "                \"state\": self._state[:4],\n", "                \"has_uint\": self._state[4] & np.uint64(0x1),\n", "                \"uinteger\": self._state[4] >> np.uint64(32),\n", "            }\n", "\n", "        return f\n", "\n", "    @property\n", "    def state_setter(self):\n", "        \"\"\"A function that sets the state. This is Python and is not decorated\"\"\"\n", "\n", "        def f(value: dict):\n", "            name = value.get(\"bit_gen\", None)\n", "            if name != type(self).__name__:\n", "                raise ValueError(f\"state must be from a {type(self).__name__}\")\n", "            self._state[:4] = np.uint64(value[\"state\"])\n", "            temp = np.uint64(value[\"uinteger\"]) << np.uint64(32)\n", "            temp |= np.uint64(value[\"has_uint\"]) & np.uint64(0x1)\n", "            self._state[4] = temp\n", "\n", "        return f" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "We start by instantiating the class and taking a look at the initial state." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "{'bit_gen': 'NumbaJSF',\n", " 'state': array([ 1167245051188668936, 13259944246262022926, 8870424784319794977,\n", " 9596734350428388680], dtype=uint64),\n", " 'has_uint': 0,\n", " 'uinteger': 0}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# From random.org\n", "state = np.array([0x77, 0x5E, 0xB7, 0x11, 0x14, 0x3F, 0xD1, 0x0E], dtype=np.uint8).view(\n", " np.uint64\n", ")[0]\n", "njsf = NumbaJSF(state)\n", "njsf.state_getter()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "`from_cfunc` is then used to pass the `CFunc`s, state address pointer and the state getter and setter to `UserBitGenerator`. We see that the state changes after calling `random_raw`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'bit_gen': 'NumbaJSF', 'state': array([ 1167245051188668936, 13259944246262022926, 8870424784319794977,\n", " 9596734350428388680], dtype=uint64), 'has_uint': 0, 'uinteger': 0}\n", "[ 602963287911976729 5264292724725465572]\n", "{'bit_gen': 'NumbaJSF', 'state': array([ 530704699024515781, 2740075917084007745, 5336551313612926520,\n", " 5264292724725465572], dtype=uint64), 'has_uint': 0, 'uinteger': 0}\n" ] } ], "source": [ "jsf_ubg = UserBitGenerator.from_cfunc(\n", " njsf.next_raw,\n", " njsf.next_64,\n", " njsf.next_32,\n", " njsf.next_double,\n", " njsf.state_address,\n", " state_getter=njsf.state_getter,\n", " state_setter=njsf.state_setter,\n", ")\n", "print(jsf_ubg.state)\n", "print(jsf_ubg.random_raw(2))\n", "print(jsf_ubg.state)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { 
"name": "#%% md\n" } }, "source": [ "Some `Generator` function use 32-bit integers to save bits. `random` with `dtype=np.float32` is one. After calling this function we see that `has_uint` is now 1." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A 32-bit float: 0.16430795192718506\n", "Notice has_uint is now 1\n" ] }, { "data": { "text/plain": [ "{'bit_gen': 'NumbaJSF',\n", " 'state': array([13952735719045862400, 12103276313412614439, 5553417437478470678,\n", " 14241860431798867506], dtype=uint64),\n", " 'has_uint': 1,\n", " 'uinteger': 3315941531}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gen = Generator(jsf_ubg)\n", "print(f\"A 32-bit float: {gen.random(dtype=np.float32)}\")\n", "print(\"Notice has_uint is now 1\")\n", "jsf_ubg.state" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Performance \n", "We can use `random_raw` function to assess the performance and compare it to the C-implementation ``JSF``. It is about 6% slower which is an impressive outcome." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.4 ms ± 62.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%timeit jsf_ubg.random_raw(1000000)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.19 ms ± 31.7 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "from randomgen import JSF\n", "\n", "jsf = JSF()\n", "%timeit jsf.random_raw(1000000)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Next, we will clone the state of the native ``JSF`` to the Numba implementation." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "{'bit_generator': 'JSF',\n", " 'state': {'a': 17190901158427765818,\n", " 'b': 14501513697102443756,\n", " 'c': 15715724510248929625,\n", " 'd': 12712143389959007425,\n", " 'p': 7,\n", " 'q': 13,\n", " 'r': 37},\n", " 'size': 64,\n", " 'has_uint32': 0,\n", " 'uinteger': 0,\n", " 'seed_size': 1}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsf_state = jsf.state\n", "jsf_state" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "While the structure of the state is different, the values are the same: the native ``a``, ``b``, ``c`` and ``d`` correspond, in order, to the four elements ``[a, b, c, d]`` of the Numba state array." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "{'bit_gen': 'NumbaJSF',\n", " 'state': array([17190901158427765818, 14501513697102443756, 15715724510248929625,\n", " 12712143389959007425], dtype=uint64),\n", " 'has_uint': 0,\n", " 'uinteger': 0}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "st = jsf_ubg.state\n", "# Clone the C implementation's state and set it\n", "st[\"state\"][:4] = [jsf_state[\"state\"][key] for key in (\"a\", \"b\", \"c\", \"d\")]\n", "# NumbaJSF stores the flag under \"has_uint\"; the native JSF calls it \"has_uint32\"\n", "st[\"has_uint\"] = jsf_state[\"has_uint32\"]\n", "st[\"uinteger\"] = jsf_state[\"uinteger\"]\n", "jsf_ubg.state = st\n", "jsf_ubg.state" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Finally, we can take a look at the next few values to show that the two implementations produce identical output." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,\n", " 17987378307908897868, 18034113569054765009], dtype=uint64)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsf_ubg.random_raw(5)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "array([ 3814417803339974021, 15780814468893899944, 17400468283504521969,\n", " 17987378307908897868, 18034113569054765009], dtype=uint64)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jsf.random_raw(5)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", 
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }