rng_philox.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
"""RNG imitating torch cuda randn on CPU. You are welcome.

Usage:
```
g = Generator(seed=0)
print(g.randn(shape=(3, 4)))
```
Expected output:
```
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
```
"""
  14. import numpy as np
# Philox-4x32 round multipliers: philox4_round multiplies counter word 0 by
# philox_m[0] and counter word 2 by philox_m[1].
philox_m = [0xD2511F53, 0xCD9E8D57]
# Key-schedule increments: philox4_32 adds philox_w[0] / philox_w[1] to key
# words 0 / 1 between rounds.
philox_w = [0x9E3779B9, 0xBB67AE85]
# 2**-32 as float32; box_muller uses it to map a uint32 word into the unit
# interval. Stored as a 1-element float32 array rather than a Python float,
# presumably to control NumPy type promotion -- keep as-is so outputs stay
# bit-identical. (Presumably mirrors curand's CURAND_2POW32_INV -- confirm.)
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
# 2**-32 * 2*pi: the product is computed in float64 and rounded once to
# float32; box_muller uses it to map a uint32 word to an angle.
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
  19. def uint32(x):
  20. """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
  21. return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
  22. def philox4_round(counter, key):
  23. """A single round of the Philox 4x32 random number generator."""
  24. v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
  25. v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
  26. counter[0] = v2[1] ^ counter[1] ^ key[0]
  27. counter[1] = v2[0]
  28. counter[2] = v1[1] ^ counter[3] ^ key[1]
  29. counter[3] = v1[0]
  30. def philox4_32(counter, key, rounds=10):
  31. """Generates 32-bit random numbers using the Philox 4x32 random number generator.
  32. Parameters:
  33. counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
  34. key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
  35. rounds (int): The number of rounds to perform.
  36. Returns:
  37. numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
  38. """
  39. for _ in range(rounds - 1):
  40. philox4_round(counter, key)
  41. key[0] = key[0] + philox_w[0]
  42. key[1] = key[1] + philox_w[1]
  43. philox4_round(counter, key)
  44. return counter
  45. def box_muller(x, y):
  46. """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
  47. u = x * two_pow32_inv + two_pow32_inv / 2
  48. v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
  49. s = np.sqrt(-2.0 * np.log(u))
  50. r1 = s * np.sin(v)
  51. return r1.astype(np.float32)
  52. class Generator:
  53. """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
  54. def __init__(self, seed):
  55. self.seed = seed
  56. self.offset = 0
  57. def randn(self, shape):
  58. """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
  59. n = 1
  60. for x in shape:
  61. n *= x
  62. counter = np.zeros((4, n), dtype=np.uint32)
  63. counter[0] = self.offset
  64. counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
  65. self.offset += 1
  66. key = np.empty(n, dtype=np.uint64)
  67. key.fill(self.seed)
  68. key = uint32(key)
  69. g = philox4_32(counter, key)
  70. return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]