Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Mersenne Twister PRNG #984

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import Batteries.Data.MLList
import Batteries.Data.Nat
import Batteries.Data.PairingHeap
import Batteries.Data.RBMap
import Batteries.Data.Random
import Batteries.Data.Range
import Batteries.Data.Rat
import Batteries.Data.Stream
import Batteries.Data.String
import Batteries.Data.Sum
import Batteries.Data.UInt
Expand Down
1 change: 1 addition & 0 deletions Batteries/Data/Random.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import Batteries.Data.Random.MersenneTwister
165 changes: 165 additions & 0 deletions Batteries/Data/Random/MersenneTwister.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
-/
import Batteries.Data.Vector

/-! # Mersenne Twister

Generic implementation for the Mersenne Twister pseudorandom number generator.

All choices of parameters from Matsumoto and Nishimura (1998) are supported, along with later
refinements. Parameters for the standard 32-bit MT19937 and 64-bit MT19937-64 algorithms are
provided. Both `RandomGen` and `Stream` interfaces are provided.

Use `mt19937.init seed` to create a MT19937 PRNG with a 32 bit seed value; use
`mt19937_64.init seed` to create a MT19937-64 PRNG with a 64 bit seed value. If omitted, default
seed choices will be used.

Sample usage:
```
import Batteries.Data.Random.MersenneTwister

open Batteries.Random.MersenneTwister

def mtgen := mt19937.init -- default seed 4357

#eval (Stream.take mtgen 5).fst -- #[874448474, 2424656266, 2174085406, 1265871120, 3155244894]
```

### References:

- Matsumoto, Makoto and Nishimura, Takuji (1998),
[**Mersenne twister: A 623-dimensionally equidistributed uniform pseudo-random number generator**](https://doi.org/10.1145/272991.272995),
ACM Trans. Model. Comput. Simul. 8, No. 1, 3-30.
[ZBL0917.65005](https://zbmath.org/?q=an:0917.65005).

- Nishimura, Takuji (2000),
[**Tables of 64-bit Mersenne twisters**](https://doi.org/10.1145/369534.369540),
ACM Trans. Model. Comput. Simul. 10, No. 4, 348-357.
[ZBL1390.65014](https://zbmath.org/?q=an:1390.65014).
-/

namespace Batteries.Random.MersenneTwister

/--
Mersenne Twister configuration.

Letters in parentheses correspond to variable names used by Matsumoto and Nishimura (1998) and
Nishimura (2000).

The types of `shiftSize` and `maskBits` force `stateSize` and `wordSize` to be positive, so
every configuration describes a nonempty state made of nonempty words.
-/
structure Config where
/-- Word size in bits (`w`). State words and generated values are `BitVec wordSize`. -/
wordSize : Nat
/-- Degree of recurrence (`n`). This is the number of words stored in the state vector. -/
stateSize : Nat
/-- Middle word (`m`). Offset from the current index, wrapping around the state vector, of the
word that is xored into the twisted value. -/
shiftSize : Fin stateSize
/-- Twist value (`r`). The low `maskBits` bits of a word are selected by the lower mask and the
remaining high bits by the upper mask. -/
maskBits : Fin wordSize
/-- Coefficients of the twist matrix (`a`). Xored into the shifted word when its low bit was
set before shifting. -/
xorMask : BitVec wordSize
/-- Tempering shift parameters, in the order (`u`, `s`, `t`, `l`). -/
temperingShifts : Nat × Nat × Nat × Nat
/-- Tempering mask parameters, in the order (`d`, `b`, `c`). -/
temperingMasks : BitVec wordSize × BitVec wordSize × BitVec wordSize
/-- Initialization multiplier (`f`) used to derive the state words from the seed. -/
initMult : BitVec wordSize
/-- Default initialization seed value, used when no seed is passed to `Config.init`. -/
initSeed : BitVec wordSize

/-- Mask selecting the bits of a word at positions `cfg.maskBits` and above. -/
private abbrev Config.uMask (cfg : Config) : BitVec cfg.wordSize :=
  let ones := BitVec.allOnes cfg.wordSize
  ones <<< cfg.maskBits.val

/-- Mask selecting the bits of a word strictly below position `cfg.maskBits`. -/
private abbrev Config.lMask (cfg : Config) : BitVec cfg.wordSize :=
  let ones := BitVec.allOnes cfg.wordSize
  ones >>> (cfg.wordSize - cfg.maskBits.val)

/-- The word size is positive, since `cfg.maskBits` is a member of `Fin cfg.wordSize`. -/
@[simp] theorem Config.zero_lt_wordSize (cfg : Config) : 0 < cfg.wordSize :=
  Nat.lt_of_le_of_lt (Nat.zero_le _) cfg.maskBits.isLt

/-- The state size is positive, since `cfg.shiftSize` is a member of `Fin cfg.stateSize`. -/
@[simp] theorem Config.zero_lt_stateSize (cfg : Config) : 0 < cfg.stateSize :=
  Nat.lt_of_le_of_lt (Nat.zero_le _) cfg.shiftSize.isLt

/-- Mersenne Twister State: the vector of state words together with a cursor into it. -/
structure State (cfg : Config) where
/-- Data for current state: exactly `cfg.stateSize` words of `cfg.wordSize` bits each. -/
data : Vector (BitVec cfg.wordSize) cfg.stateSize
/-- Current data index: the position of the word that the next twist replaces. -/
index : Fin cfg.stateSize

/--
Mersenne Twister initialization given an optional seed.

The state words are generated by the recurrence `x₀ = seed` and
`xₖ = initMult * (xₖ₋₁ ^^^ (xₖ₋₁ >>> (wordSize - 2))) + k` for `0 < k < stateSize`; the index of
the resulting state is `0`. When `seed` is omitted, `cfg.initSeed` is used.

Note: earlier revisions wrote the shift as `w >>> cfg.wordSize - 2`, which parses as
`(w >>> cfg.wordSize) - 2` and therefore produced a different state. Output values recorded
with that form must be regenerated.
-/
@[specialize cfg] protected def Config.init (cfg : MersenneTwister.Config)
    (seed : BitVec cfg.wordSize := cfg.initSeed) : State cfg :=
  ⟨loop seed (.mkEmpty cfg.stateSize) (Nat.zero_le _), 0, cfg.zero_lt_stateSize⟩
where
  /-- Inner loop for Mersenne Twister initialization: pushes `w` and computes the next word
  until `v` holds `cfg.stateSize` words. -/
  loop (w : BitVec cfg.wordSize) (v : Array (BitVec cfg.wordSize)) (h : v.size ≤ cfg.stateSize) :=
    if heq : v.size = cfg.stateSize then ⟨v, heq⟩ else
      let v := v.push w
      -- `>>>` binds tighter than `-`, so the shift amount has to be parenthesized.
      -- `v.size` is read after the push, so it is the index of the word being computed.
      let w := cfg.initMult * (w ^^^ (w >>> (cfg.wordSize - 2))) + v.size
      loop w v (by simp only [v, Array.size_push]; omega)

/--
Apply the twisting transformation to the given state.

The high bits of the current word and the low bits of its successor are merged, shifted right by
one bit, and xored with `cfg.xorMask` when the merged word had its lowest bit set. The result is
xored with the word `cfg.shiftSize` positions ahead (wrapping around), stored at the current
index, and the index moves to the successor position.
-/
@[specialize cfg] protected def State.twist (state : State cfg) : State cfg :=
  let cur := state.index
  -- Successor of `cur`, wrapping back to `0` at the end of the state vector.
  let nxt : Fin cfg.stateSize :=
    if h : cur.val + 1 < cfg.stateSize then ⟨cur.val + 1, h⟩ else ⟨0, cfg.zero_lt_stateSize⟩
  let mixed := (state.data[cur] &&& cfg.uMask) ||| (state.data[nxt] &&& cfg.lMask)
  let twisted := bif mixed[0] then (mixed >>> 1) ^^^ cfg.xorMask else mixed >>> 1
  let fresh := state.data[cur + cfg.shiftSize] ^^^ twisted
  ⟨state.data.set cur fresh, nxt⟩

/-- Advance the state by `steps` generation steps (default 1), twisting once per step. -/
-- TODO: optimize to `O(log(steps))` using the minimal polynomial
protected def State.update (state : State cfg) : (steps : Nat := 1) → State cfg
  | 0 => state
  | remaining + 1 => State.update state.twist remaining

/--
Mersenne Twister iteration.

Twists the state once and returns the tempered value of the word that was just replaced,
together with the twisted state.
-/
@[specialize cfg] protected def State.next (state : State cfg) : BitVec cfg.wordSize × State cfg :=
  let advanced := state.twist
  (temper advanced.data[state.index], advanced)
where
  /-- Tempering step for Mersenne Twister: four successive xor-shift-mask passes. -/
  @[inline] temper (x : BitVec cfg.wordSize) :=
    let (u, s, t, l) := cfg.temperingShifts
    let (d, b, c) := cfg.temperingMasks
    let x₁ := x ^^^ ((x >>> u) &&& d)
    let x₂ := x₁ ^^^ ((x₁ <<< s) &&& b)
    let x₃ := x₂ ^^^ ((x₂ <<< t) &&& c)
    x₃ ^^^ (x₃ >>> l)

/-- `RandomGen` interface: values range over all `cfg.wordSize`-bit naturals. -/
instance (cfg) : RandomGen (State cfg) where
  range _ := (0, 2 ^ cfg.wordSize - 1)
  next s :=
    let (word, s') := s.next
    (word.toNat, s')
  split s :=
    -- TODO: use `(s, s.update (2 ^ 128))` once `update` is optimized.
    -- Until then the second generator is seeded with the next output of the first.
    let (word, s') := s.next
    (s', cfg.init word)

/-- `Stream` interface: the stream never ends, each step yields one generated word. -/
instance (cfg) : Stream (State cfg) (BitVec cfg.wordSize) where
  next? s := some s.next

/-- 32 bit Mersenne Twister (MT19937) configuration.

The letters in the comments are the parameter names documented on the fields of `Config`. -/
def mt19937 : Config where
wordSize := 32 -- `w`
stateSize := 624 -- `n`
shiftSize := 397 -- `m`
maskBits := 31 -- `r`
xorMask := 0x9908b0df -- `a`
temperingShifts := (11, 7, 15, 18) -- (`u`, `s`, `t`, `l`)
temperingMasks := (0xffffffff, 0x9d2c5680, 0xefc60000) -- (`d`, `b`, `c`)
initMult := 1812433253 -- `f`
initSeed := 4357 -- seed used when `mt19937.init` is called without an argument

/-- 64 bit Mersenne Twister (MT19937-64) configuration.

The letters in the comments are the parameter names documented on the fields of `Config`. -/
def mt19937_64 : Config where
wordSize := 64 -- `w`
stateSize := 312 -- `n`
shiftSize := 156 -- `m`
maskBits := 31 -- `r`
xorMask := 0xb5026f5aa96619e9 -- `a`
temperingShifts := (29, 17, 37, 43) -- (`u`, `s`, `t`, `l`)
temperingMasks := (0x5555555555555555, 0x71d67fffeda60000, 0xfff7eee000000000) -- (`d`, `b`, `c`)
initMult := 6364136223846793005 -- `f`
initSeed := 19650218 -- seed used when `mt19937_64.init` is called without an argument
23 changes: 23 additions & 0 deletions Batteries/Data/Stream.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
-/

/-- Discard up to `n` values from the stream `s`, stopping early if the stream runs out. -/
def Stream.drop [Stream σ α] (s : σ) : Nat → σ
  | 0 => s
  | k+1 =>
    if let some (_, rest) := Stream.next? s then drop rest k else s

/--
Collect up to `n` values from the stream `s`. Returns the values read, in order, together with
the remaining stream; fewer than `n` values are returned if the stream runs out first.
-/
def Stream.take [Stream σ α] (s : σ) (n : Nat) : Array α × σ :=
  loop s (Array.mkEmpty n) n
where
  /-- Inner loop for `Stream.take`: `acc` holds the values read so far. -/
  loop (s : σ) (acc : Array α)
    | 0 => (acc, s)
    | k+1 =>
      if let some (v, rest) := Stream.next? s then loop rest (acc.push v) k else (acc, s)
85 changes: 85 additions & 0 deletions test/mersenne_twister.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import Batteries.Data.Random.MersenneTwister
import Batteries.Data.Stream

open Batteries.Random.MersenneTwister

-- First five outputs of MT19937 for the default seed (4357).
-- NOTE(review): the key array recorded in the comment below starts `4357, 1673174024`. The
-- reference seeding step `1812433253 * (x ^^^ (x >>> 30)) + 1` gives `2621793274` for `x = 4357`,
-- so these expected values pin the state built by `mt19937.init`, not the reference MT19937
-- seeding. Confirm this is intended, and regenerate the values if the seeding changes.
#guard (Stream.take mt19937.init 5).1 == #[874448474, 2424656266, 2174085406, 1265871120, 3155244894]

/- Sample output was generated using `numpy`'s implementation of MT19937:
```python
from numpy import array, uint32
from numpy.random import MT19937

mt = MT19937()
mt.state = {
'bit_generator' : 'MT19937',
'state' : {
'pos' : 624,
'key' : array([
4357, 1673174024, 1301878288, 1129097449, 2180885271, 2495295730, 3729202114, 3451529139, 2624228201, 696045212,
2296245684, 4097888573, 2110311931, 1672374534, 381896678, 2887874951, 3859861197, 420983856, 1691952728, 4233606289,
1707944415, 3515687962, 4265198858, 1433261659, 1131854641, 228846788, 3811811324, 873525989, 588291779, 2854617646,
948269870, 3798261295, 3422826645, 340138072, 3671734944, 3961007161, 2839350439, 3264455490, 310719058, 2570596611,
3750039289, 648992492, 3816674884, 2210726029, 371217291, 196912982, 3046892150, 470118103, 1302935133, 362465408,
1360220904, 2946174945, 1630294895, 3570642538, 1798333338, 1196832683, 226789057, 2740096276, 1062441100, 1875507765,
2599873619, 1037523070, 4029519294, 3231722367, 2232344613, 3458909352, 2906353456, 3064815497, 3166305847,
3658630546, 3632421090, 885320275, 1621369481, 1258557244, 2827734740, 3209486301, 131295515, 2191201702, 44141830,
1183978535, 4202966509, 801836240, 2303299448, 333191985, 4114943231, 1490315450, 453120554, 759253243, 1381163601,
3455606116, 1027445020, 1144697221, 3040135651, 4176273102, 798935118, 49817807, 2492997557, 3171983608, 2742334400,
1282687705, 1047297991, 3697219554, 1400278898, 3276297123, 843040281, 354711436, 4156544868, 2873126701, 3990490795,
3966874614, 1376536470, 4189022583, 2283386237, 3645931808, 1312021512, 679663233, 3054458511, 1152865034, 1927729338,
538380875, 374984161, 2453495220, 514433452, 1271601365, 3737270131, 630101278, 1292962526, 2908018207, 1209528133,
413117768, 3762161744, 2194986537, 1414304087, 379722290, 2862208514, 3551161587, 3402627497, 2411204572, 3033657332,
4161252989, 2267825211, 963150406, 2081690150, 4014304967, 1977732365, 2412979568, 613038232, 418857425, 3682807839,
3416550746, 3692470090, 2764012443, 3255912817, 2160692740, 3914318396, 3437441061, 2828481795, 3655629678, 582770030,
2946380655, 3506851541, 612362648, 3394202848, 1530337657, 3360830183, 570641538, 153365650, 1624454723, 80526649,
1365694508, 2272925828, 34250189, 3066169803, 631734422, 3706776758, 3443270679, 659846301, 3707435456, 3573851432,
1017208097, 1100519855, 1824765866, 3284762074, 2887949547, 569464065, 3057970772, 1726477004, 3119183733, 3349922451,
4162228670, 249085950, 3854319807, 1155219045, 811161064, 207675760, 50531529, 141911159, 3819613906, 2655884066,
3517624211, 514724041, 2094583932, 3681571092, 3518053661, 2207473499, 961982182, 1423628102, 628853095, 3823741997,
1450180112, 1817911736, 384378993, 1749521215, 4080873978, 2604100714, 2468900411, 1718743185, 3679944356, 623522652,
2974445253, 351789091, 776787982, 4087231118, 395771407, 2634989045, 2547249720, 2502583808, 3550523417, 648947207,
2361409826, 2639137202, 4179155171, 3136025689, 3233151180, 3765213604, 459508845, 412632299, 3365801270, 1208603094,
1978375863, 3608769469, 2648322656, 994422344, 1463198657, 1938300111, 1983437898, 3617090298, 582545291, 604707873,
615071476, 1976468460, 4251555349, 2373160371, 4138683998, 927249694, 4178996063, 3071856005, 3264724616, 2539911824,
1383596905, 3639900055, 2590770034, 1029541954, 369472051, 3757991913, 1470517532, 2317808180, 1065978813, 3301489275,
4087716742, 2662718566, 678716423, 274451277, 1625396912, 3598469848, 3639725841, 726808159, 1490990746, 4062476682,
2411471067, 1395972017, 1390554948, 1854727292, 2494590309, 1377225539, 2540041390, 3288614830, 706906287, 1416719637,
609008344, 2311429920, 821102265, 2034260263, 3587569090, 3115591378, 3545840515, 4166871929, 139581804, 2421643972,
1250638605, 4212965387, 2794805718, 3306616566, 2466109783, 2200482525, 1496197888, 381089640, 2743249505, 4221427695,
1247199466, 1746114586, 2065302059, 1348936513, 2997505940, 3911013644, 428274869, 2816055507, 580438782, 135588414,
916674047, 445684901, 1016784680, 654791600, 1282652681, 92916407, 1411782674, 1367985506, 1207661779, 3531669257,
627085756, 1857409876, 4107311709, 1384928667, 2576697382, 2875531654, 4151312039, 116927085, 1281879888, 414036984,
3931190705, 4100135295, 1170799418, 3130902186, 4055536507, 3692691153, 480878564, 2201474460, 3663014917, 4155766371,
1987039566, 4121861326, 2525025103, 2465094709, 2536129400, 1843468352, 2926058841, 533253191, 1988389474, 1209435122,
4141112867, 2699109017, 2373614092, 1694129124, 2730600877, 2249161515, 1355638390, 3319290902, 2209534967,
1463955965, 204923808, 1025015944, 214266113, 3382305551, 2455594378, 1861944634, 1820710091, 449145441, 4119339060,
2660525612, 3515028309, 3466454003, 1024657310, 50945886, 2913140895, 721595333, 3416444872, 2701847760, 2352361641,
234184151, 3927502002, 3834792578, 3469473651, 4193637929, 2873594460, 1994191988, 1690724605, 1956524219, 476427462,
212379302, 1370380615, 327076237, 1984104432, 682581272, 2521259089, 3543809183, 3275489242, 241390538, 3496199707,
2497799665, 770560132, 1626015420, 2776148645, 3717161347, 3970592238, 710750702, 3421625839, 876972885, 2108460056,
1195168096, 1195766777, 3121053543, 2819333890, 1916084498, 717897923, 3627489721, 1970264748, 1813355780, 4148615245,
556824139, 411448086, 4228776246, 1732939415, 3206934813, 1949588544, 3291105704, 1044314017, 222045743, 3079457322,
638497370, 1849452395, 921039233, 1115861204, 3019093836, 2828923381, 4185943827, 3344827454, 3923907710, 760572735,
3828284133, 1559197800, 724485616, 1828677449, 2985767159, 4119101778, 1077348258, 3518446099, 2585587017, 1855673084,
3495712148, 3265984413, 2998815707, 760668518, 2487249862, 3060757479, 3249514669, 4222804112, 1010910776, 3893641969,
395812799, 2591540346, 1194664170, 49789115, 1363873041, 1005502756, 1164343260, 3646613829, 459869347, 3679832718,
1137706766, 4189431951, 1412889205, 622040248, 1536739968, 3066727065, 666661511, 1672188834, 2714762802, 4135248739,
35606745, 2775710540, 4083752484, 3680159469, 1950331243, 251641782, 1501029974, 486869303, 1720971325, 241603808,
28070600, 2737782337, 910469455, 3810848458, 118398842, 3078470155, 2559096993, 2933522804, 2264615020, 3793195157,
1614887475, 45727966, 3193899422, 1157273055, 2178255365, 2646663432, 724754192, 168779241, 4048503831, 3483948530,
3996648642, 939343027, 917914729, 3030111132, 3908302516, 29247037, 3568084731, 1034472966, 1408004326, 1693666951,
3712665549, 3120003376, 3374542680, 2868373905, 1362838239, 1421625626, 4275252746, 548825947, 622261297, 3152835012,
2926192892, 423356389, 151058371, 3820087086, 1673993262, 252457775, 1317185941, 2594135384, 817169312, 2016796985,
2292688295, 1654933570, 2158435154, 2703640067, 3260663801, 3267419116, 2293555012, 2721936781, 1727868043, 91884630,
265685878, 1143096279, 961294173, 403541376, 2338233320, 1725318369, 4101205103, 4268086122, 3418016922, 1065995435,
1936572353, 265163284, 3043694988, 2167402293, 2057323859, 4033232254, 3258990270, 1137868927, 2142656805, 4216785320,
1188509744, 1051071625, 196974391, 2445666962, 3092595170, 2833121107, 2474761097, 2190021692, 1852037076, 3577763037,
3794354715, 2124118694, 2641147398, 1551493415, 1913661165, 1313919440, 2232801400, 1781682225, 1340417535, 994676154,
251493162, 2162155003, 1678056273, 3810976356, 1505106460, 3361449605, 1041703651, 1727972302, 3959583054, 3140845007,
3202914485, 2878334456, 2354150592, 3334993881, 1015617735, 506838242, 4168775794, 839674019, 4238769945, 849116300,
4189642852, 1596908589, 556328875, 2369067254, 2431152278, 1004682871], dtype=uint32)}}

print(mt.random_raw(5))
```
-/
Loading