12345678910111213141516171819202122232425262728293031 |
- import importlib
- import torch
- from modules import shared
def check_for_npu():
    """Report whether an NPU can be used through the ``torch_npu`` extension.

    ``torch_npu`` is only imported when it is installed, so environments
    without it simply get ``False``.

    Returns:
        bool: ``False`` if ``torch_npu`` is not installed or if querying the
        NPU device count raises ``RuntimeError``; otherwise the result of
        ``torch.npu.is_available()``.
    """
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule has been loaded, so `importlib.util` may not exist as an
    # attribute. Import the submodule explicitly so `find_spec` is reachable.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False

    # NOTE(review): `torch.npu` below is presumably made available by importing
    # torch_npu; confirm against the torch_npu documentation.
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
def get_npu_device_string():
    """Return the torch device string for the NPU chosen on the command line.

    Returns:
        str: ``"npu:<device_id>"`` when ``shared.cmd_opts.device_id`` is set,
        otherwise ``"npu:0"``.
    """
    requested_id = shared.cmd_opts.device_id
    index = 0 if requested_id is None else requested_id
    return f"npu:{index}"
def torch_npu_gc():
    """Call ``torch.npu.empty_cache()`` on the configured NPU device.

    The call is wrapped in a ``torch.npu.device`` context for the device
    string returned by ``get_npu_device_string()``.
    """
    npu_backend = torch.npu
    with npu_backend.device(get_npu_device_string()):
        npu_backend.empty_cache()
# Probed once, when this module is imported; holds the result of
# check_for_npu().
has_npu: bool = check_for_npu()
|