rng.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import torch
  2. from modules import devices, rng_philox, shared
  3. def randn(seed, shape, generator=None):
  4. """Generate a tensor with random numbers from a normal distribution using seed.
  5. Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
  6. manual_seed(seed)
  7. if shared.opts.randn_source == "NV":
  8. return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
  9. if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
  10. return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
  11. return torch.randn(shape, device=devices.device, generator=generator)
  12. def randn_local(seed, shape):
  13. """Generate a tensor with random numbers from a normal distribution using seed.
  14. Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
  15. if shared.opts.randn_source == "NV":
  16. rng = rng_philox.Generator(seed)
  17. return torch.asarray(rng.randn(shape), device=devices.device)
  18. local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
  19. local_generator = torch.Generator(local_device).manual_seed(int(seed))
  20. return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
  21. def randn_like(x):
  22. """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
  23. Use either randn() or manual_seed() to initialize the generator."""
  24. if shared.opts.randn_source == "NV":
  25. return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
  26. if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
  27. return torch.randn_like(x, device=devices.cpu).to(x.device)
  28. return torch.randn_like(x)
  29. def randn_without_seed(shape, generator=None):
  30. """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
  31. Use either randn() or manual_seed() to initialize the generator."""
  32. if shared.opts.randn_source == "NV":
  33. return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
  34. if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
  35. return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
  36. return torch.randn(shape, device=devices.device, generator=generator)
  37. def manual_seed(seed):
  38. """Set up a global random number generator using the specified seed."""
  39. if shared.opts.randn_source == "NV":
  40. global nv_rng
  41. nv_rng = rng_philox.Generator(seed)
  42. return
  43. torch.manual_seed(seed)
  44. def create_generator(seed):
  45. if shared.opts.randn_source == "NV":
  46. return rng_philox.Generator(seed)
  47. device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
  48. generator = torch.Generator(device).manual_seed(int(seed))
  49. return generator
  50. # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
  51. def slerp(val, low, high):
  52. low_norm = low/torch.norm(low, dim=1, keepdim=True)
  53. high_norm = high/torch.norm(high, dim=1, keepdim=True)
  54. dot = (low_norm*high_norm).sum(1)
  55. if dot.mean() > 0.9995:
  56. return low * val + high * (1 - val)
  57. omega = torch.acos(dot)
  58. so = torch.sin(omega)
  59. res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
  60. return res
  61. class ImageRNG:
  62. def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
  63. self.shape = tuple(map(int, shape))
  64. self.seeds = seeds
  65. self.subseeds = subseeds
  66. self.subseed_strength = subseed_strength
  67. self.seed_resize_from_h = seed_resize_from_h
  68. self.seed_resize_from_w = seed_resize_from_w
  69. self.generators = [create_generator(seed) for seed in seeds]
  70. self.is_first = True
  71. def first(self):
  72. noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
  73. xs = []
  74. for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
  75. subnoise = None
  76. if self.subseeds is not None and self.subseed_strength != 0:
  77. subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
  78. subnoise = randn(subseed, noise_shape)
  79. if noise_shape != self.shape:
  80. noise = randn(seed, noise_shape)
  81. else:
  82. noise = randn(seed, self.shape, generator=generator)
  83. if subnoise is not None:
  84. noise = slerp(self.subseed_strength, noise, subnoise)
  85. if noise_shape != self.shape:
  86. x = randn(seed, self.shape, generator=generator)
  87. dx = (self.shape[2] - noise_shape[2]) // 2
  88. dy = (self.shape[1] - noise_shape[1]) // 2
  89. w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
  90. h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
  91. tx = 0 if dx < 0 else dx
  92. ty = 0 if dy < 0 else dy
  93. dx = max(-dx, 0)
  94. dy = max(-dy, 0)
  95. x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
  96. noise = x
  97. xs.append(noise)
  98. eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
  99. if eta_noise_seed_delta:
  100. self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
  101. return torch.stack(xs).to(shared.device)
  102. def next(self):
  103. if self.is_first:
  104. self.is_first = False
  105. return self.first()
  106. xs = []
  107. for generator in self.generators:
  108. x = randn_without_seed(self.shape, generator=generator)
  109. xs.append(x)
  110. return torch.stack(xs).to(shared.device)
  111. devices.randn = randn
  112. devices.randn_local = randn_local
  113. devices.randn_like = randn_like
  114. devices.randn_without_seed = randn_without_seed
  115. devices.manual_seed = manual_seed