Browse Source

Add NPU Support

wangshuai09 1 year ago
parent
commit
ec124607f4

+ 7 - 2
modules/devices.py

@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 
 import torch
-from modules import errors, shared
+from modules import errors, shared, npu_specific
 
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -40,6 +40,9 @@ def get_optimal_device_name():
     if has_xpu():
         return xpu_specific.get_xpu_device_string()
 
+    if npu_specific.has_npu:
+        return npu_specific.get_npu_device_string()
+
     return "cpu"
 
 
@@ -67,6 +70,9 @@ def torch_gc():
     if has_xpu():
         xpu_specific.torch_xpu_gc()
 
+    if npu_specific.has_npu:
+        npu_specific.torch_npu_gc()
+
 
 def enable_tf32():
     if torch.cuda.is_available():
@@ -164,4 +170,3 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
-

+ 5 - 1
modules/initialize.py

@@ -143,13 +143,17 @@ def initialize_rest(*, reload_script_modules=False):
         its optimization may be None because the list of optimizaers has neet been filled
         by that time, so we apply optimization again.
         """
+        from modules import devices
+        # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+        if devices.npu_specific.has_npu:
+            import torch
+            torch.npu.set_device(0)
 
         shared.sd_model  # noqa: B018
 
         if sd_hijack.current_optimizer is None:
             sd_hijack.apply_optimizations()
 
-        from modules import devices
         devices.first_time_calculation()
     if not shared.cmd_opts.skip_load_model_at_start:
         Thread(target=load_model).start()

+ 34 - 0
modules/npu_specific.py

@@ -0,0 +1,34 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+    if importlib.util.find_spec("torch_npu") is None:
+        return False
+    import torch_npu
+    torch_npu.npu.set_device(0)
+
+    try:
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
+
+
+def get_npu_device_string():
+    if shared.cmd_opts.device_id is not None:
+        return f"npu:{shared.cmd_opts.device_id}"
+    return "npu:0"
+
+
+def torch_npu_gc():
+    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+    torch.npu.set_device(0)
+    with torch.npu.device(get_npu_device_string()):
+        torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()

+ 4 - 0
modules/textual_inversion/textual_inversion.py

@@ -151,6 +151,10 @@ class EmbeddingDatabase:
         return embedding
 
     def get_expected_shape(self):
+        # workaround
+        if devices.npu_specific.has_npu:
+            import torch
+            torch.npu.set_device(0)
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         return vec.shape[1]
 

+ 4 - 0
requirements.txt

@@ -5,6 +5,8 @@ accelerate
 basicsr
 blendmodes
 clean-fid
+cloudpickle
+decorator
 einops
 fastapi>=0.90.1
 gfpgan
@@ -26,9 +28,11 @@ resize-right
 
 safetensors
 scikit-image>=0.19
+synr==0.5.0
 timm
 tomesd
 torch
 torchdiffeq
 torchsde
+tornado
 transformers==4.30.2

+ 4 - 0
requirements_versions.txt

@@ -4,6 +4,8 @@ accelerate==0.21.0
 basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
+cloudpickle==3.0.0
+decorator==5.1.1
 einops==0.4.1
 fastapi==0.94.0
 gfpgan==1.3.8
@@ -23,10 +25,12 @@ realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
 scikit-image==0.21.0
+synr==0.5.0
 timm==0.9.2
 tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.6
+tornado==6.4
 transformers==4.30.2
 httpx==0.24.1

+ 4 - 0
webui.sh

@@ -159,6 +159,10 @@ then
     if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
     then
         export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+    elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
+    then
+        export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
+    
     fi
 fi