npu_specific.py 649 B

12345678910111213141516171819202122232425262728293031
  1. import importlib
  2. import torch
  3. from modules import shared
  def check_for_npu():
      """Report whether an NPU device is usable through the ``torch_npu`` extension.

      Returns:
          False when the ``torch_npu`` package cannot be found, or when querying
          the NPU device count raises ``RuntimeError``; otherwise the result of
          ``torch.npu.is_available()``.
      """
      # ``import importlib`` alone does not guarantee that the ``importlib.util``
      # submodule has been loaded, so import it explicitly before using find_spec.
      import importlib.util

      if importlib.util.find_spec("torch_npu") is None:
          return False

      # NOTE(review): ``torch.npu`` used below is presumably made available as a
      # side effect of importing torch_npu — this import must stay before it.
      import torch_npu

      try:
          # Will raise a RuntimeError if no NPU is found
          _ = torch_npu.npu.device_count()
          return torch.npu.is_available()
      except RuntimeError:
          return False
  def get_npu_device_string():
      """Build the ``npu:<index>`` device string for the configured NPU.

      The index is taken from ``shared.cmd_opts.device_id``; when that option
      is ``None`` the index defaults to 0.
      """
      requested = shared.cmd_opts.device_id
      index = 0 if requested is None else requested
      return f"npu:{index}"
  def torch_npu_gc():
      """Call ``torch.npu.empty_cache()`` with the configured NPU device active.

      The device is chosen by ``get_npu_device_string()`` and is only active
      for the duration of the ``with`` block.
      """
      target = get_npu_device_string()
      with torch.npu.device(target):
          torch.npu.empty_cache()
  # Probe for a usable NPU once, when this module is imported.
  has_npu: bool = check_for_npu()