devices.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import sys, os, shlex
  2. import contextlib
  3. import torch
  4. from modules import errors
  5. from packaging import version
  # MPS needs macOS 12.3+ and a PyTorch build with MPS compiled in. Newer PyTorch
  # exposes torch.backends.mps.is_available(); early (nightly) builds only had the
  # older torch.has_mps flag, so fall back to it via getattr for compatibility.
  def has_mps() -> bool:
      """Return True only if a tensor can actually be placed on the ``mps`` device.

      The reported availability flag is used as a cheap first filter; the
      definitive check is moving a one-element tensor to ``mps``. Any exception
      raised by that probe is treated as "MPS not usable" and yields False.
      """
      mps_backend = getattr(torch.backends, "mps", None)
      if mps_backend is not None and hasattr(mps_backend, "is_available"):
          reported = mps_backend.is_available()
      else:
          reported = getattr(torch, 'has_mps', False)
      if not reported:
          return False
      try:
          torch.zeros(1).to(torch.device("mps"))
          return True
      except Exception:
          # Deliberate best-effort probe: any failure simply means "no MPS".
          return False
  def extract_device_id(args, name):
      """Return the value supplied for option *name* in the token list *args*.

      Both ``name value`` and ``name=value`` spellings are accepted. The first
      token containing *name* (substring match) decides the result.

      Returns:
          The option's value as a string, or None if *name* does not appear or
          is the final token with nothing after it.
      """
      inline_prefix = name + "="
      for index, token in enumerate(args):
          if name not in token:
              continue
          if token.startswith(inline_prefix):
              return token[len(inline_prefix):]
          # Guard against the option being the last token (previously IndexError).
          if index + 1 < len(args):
              return args[index + 1]
          return None
      return None
  def get_cuda_device_string():
      """Build the CUDA device spec string for torch.

      Returns ``"cuda:<id>"`` when ``shared.cmd_opts.device_id`` is set
      (presumably from a command-line option), otherwise plain ``"cuda"``.
      """
      from modules import shared

      chosen_id = shared.cmd_opts.device_id
      return "cuda" if chosen_id is None else f"cuda:{chosen_id}"
  def get_optimal_device():
      """Pick a torch device by preference: CUDA, then MPS, then the module-level ``cpu``."""
      if torch.cuda.is_available():
          best = torch.device(get_cuda_device_string())
      elif has_mps():
          best = torch.device("mps")
      else:
          best = cpu
      return best
  def get_device_for(task):
      """Return ``cpu`` if *task* is listed in ``shared.cmd_opts.use_cpu``, else the optimal device."""
      from modules import shared

      forced_to_cpu = task in shared.cmd_opts.use_cpu
      return cpu if forced_to_cpu else get_optimal_device()
  def torch_gc():
      """Call ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()`` on the configured CUDA device.

      Does nothing when CUDA is unavailable.
      """
      if not torch.cuda.is_available():
          return
      with torch.cuda.device(get_cuda_device_string()):
          torch.cuda.empty_cache()
          torch.cuda.ipc_collect()
  def enable_tf32():
      """Enable TF32 for CUDA matmuls and cuDNN when CUDA is available.

      Additionally turns on ``torch.backends.cudnn.benchmark`` if any visible
      GPU reports compute capability (7, 5); see the why-comment below.
      """
      if torch.cuda.is_available():
          # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
          # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
          # Generator (not a list) so the capability scan stops at the first match.
          if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(torch.cuda.device_count())):
              torch.backends.cudnn.benchmark = True
          torch.backends.cuda.matmul.allow_tf32 = True
          torch.backends.cudnn.allow_tf32 = True


  # Run at import time through modules.errors.run, labelled "Enabling TF32"
  # (presumably so a failure is reported rather than aborting the import — see modules.errors).
  errors.run(enable_tf32, "Enabling TF32")
  # Shared CPU device object; returned as the fallback by the device helpers.
  cpu = torch.device("cpu")

  # Per-purpose devices start out unset (None); presumably assigned by startup
  # code elsewhere — not visible in this module.
  device = None
  device_interrogate = None
  device_gfpgan = None
  device_esrgan = None
  device_codeformer = None

  # Default half-precision dtypes for the main model and the VAE.
  dtype = torch.float16
  dtype_vae = torch.float16
  def randn(seed, shape):
      """Seed torch's global RNG with *seed*, then return standard-normal noise of *shape* on ``device``.

      When ``device`` is MPS the noise is drawn on the CPU and moved over
      (presumably so a seed gives the same noise as on CPU — TODO confirm).
      """
      torch.manual_seed(seed)
      on_mps = device.type == 'mps'
      noise = torch.randn(shape, device=cpu if on_mps else device)
      return noise.to(device) if on_mps else noise
  def randn_without_seed(shape):
      """Return standard-normal noise of *shape* on ``device`` without reseeding the RNG.

      On MPS the noise is sampled on the CPU first and then moved to ``device``.
      """
      if device.type != 'mps':
          return torch.randn(shape, device=device)
      cpu_noise = torch.randn(shape, device=cpu)
      return cpu_noise.to(device)
  def autocast(disable=False):
      """Return the context manager to wrap inference in.

      A no-op ``contextlib.nullcontext()`` when *disable* is set, when the
      module ``dtype`` is float32, or when ``shared.cmd_opts.precision`` is
      ``"full"``; otherwise ``torch.autocast("cuda")``.
      """
      from modules import shared

      # Single short-circuiting condition: cmd_opts is only read when not disabled
      # and dtype is not float32.
      if disable or dtype == torch.float32 or shared.cmd_opts.precision == "full":
          return contextlib.nullcontext()
      return torch.autocast("cuda")
  71. # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
  72. orig_tensor_to = torch.Tensor.to
  73. def tensor_to_fix(self, *args, **kwargs):
  74. if self.device.type != 'mps' and \
  75. ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
  76. (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
  77. self = self.contiguous()
  78. return orig_tensor_to(self, *args, **kwargs)
  # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
  orig_layer_norm = torch.nn.functional.layer_norm


  def layer_norm_fix(*args, **kwargs):
      """Replacement for ``torch.nn.functional.layer_norm`` applying the workaround for the issue above.

      If the input tensor lives on MPS it is made contiguous before the call.
      The input is taken from the first positional argument, or, when no
      positional arguments are given, from the ``input=`` keyword. Everything
      else is forwarded unchanged to the original ``layer_norm``.
      """
      if args:
          first = args[0]
          if isinstance(first, torch.Tensor) and first.device.type == 'mps':
              args = (first.contiguous(), *args[1:])
      else:
          # Keyword-only call: previously this path skipped the workaround.
          keyword_input = kwargs.get('input')
          if isinstance(keyword_input, torch.Tensor) and keyword_input.device.type == 'mps':
              kwargs['input'] = keyword_input.contiguous()
      return orig_layer_norm(*args, **kwargs)
  # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working.
  # So on MPS-capable machines running PyTorch older than 1.13, swap in the
  # contiguous()-forcing wrappers tensor_to_fix and layer_norm_fix.
  if has_mps():
      if version.parse(torch.__version__) < version.parse("1.13"):
          torch.nn.functional.layer_norm = layer_norm_fix
          torch.Tensor.to = tensor_to_fix