Ver Fonte

Add NPU Support

wangshuai09 há 1 ano atrás
pai
commit
ec124607f4

+ 7 - 2
modules/devices.py

@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 from functools import lru_cache
 
 
 import torch
 import torch
-from modules import errors, shared
+from modules import errors, shared, npu_specific
 
 
 if sys.platform == "darwin":
 if sys.platform == "darwin":
     from modules import mac_specific
     from modules import mac_specific
@@ -40,6 +40,9 @@ def get_optimal_device_name():
     if has_xpu():
     if has_xpu():
         return xpu_specific.get_xpu_device_string()
         return xpu_specific.get_xpu_device_string()
 
 
+    if npu_specific.has_npu:
+        return npu_specific.get_npu_device_string()
+
     return "cpu"
     return "cpu"
 
 
 
 
@@ -67,6 +70,9 @@ def torch_gc():
     if has_xpu():
     if has_xpu():
         xpu_specific.torch_xpu_gc()
         xpu_specific.torch_xpu_gc()
 
 
+    if npu_specific.has_npu:
+        npu_specific.torch_npu_gc()
+
 
 
 def enable_tf32():
 def enable_tf32():
     if torch.cuda.is_available():
     if torch.cuda.is_available():
@@ -164,4 +170,3 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
     conv2d(x)
-

+ 5 - 1
modules/initialize.py

@@ -143,13 +143,17 @@ def initialize_rest(*, reload_script_modules=False):
         its optimization may be None because the list of optimizaers has neet been filled
         its optimization may be None because the list of optimizaers has neet been filled
         by that time, so we apply optimization again.
         by that time, so we apply optimization again.
         """
         """
+        from modules import devices
+        # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+        if devices.npu_specific.has_npu:
+            import torch
+            torch.npu.set_device(0)
 
 
         shared.sd_model  # noqa: B018
         shared.sd_model  # noqa: B018
 
 
         if sd_hijack.current_optimizer is None:
         if sd_hijack.current_optimizer is None:
             sd_hijack.apply_optimizations()
             sd_hijack.apply_optimizations()
 
 
-        from modules import devices
         devices.first_time_calculation()
         devices.first_time_calculation()
     if not shared.cmd_opts.skip_load_model_at_start:
     if not shared.cmd_opts.skip_load_model_at_start:
         Thread(target=load_model).start()
         Thread(target=load_model).start()

+ 34 - 0
modules/npu_specific.py

@@ -0,0 +1,34 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+    """Report whether an NPU is usable through the ``torch_npu`` package.
+
+    Returns:
+        bool: ``False`` when ``torch_npu`` is not installed or when probing
+        the device raises ``RuntimeError``; otherwise the result of
+        ``torch.npu.is_available()``.
+    """
+    # ``import importlib`` alone does not guarantee that the ``util``
+    # submodule is loaded, so import it explicitly before using find_spec.
+    import importlib.util
+
+    if importlib.util.find_spec("torch_npu") is None:
+        return False
+    import torch_npu
+
+    try:
+        # Keep every call that touches the NPU runtime inside the ``try`` so
+        # that a host with torch_npu installed but no usable device reports
+        # False instead of failing while this module is being imported.
+        torch_npu.npu.set_device(0)
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
+
+
+def get_npu_device_string():
+    """Build the torch device string for the NPU.
+
+    Uses ``shared.cmd_opts.device_id`` when it is set; otherwise falls back
+    to device index 0.
+    """
+    device_id = shared.cmd_opts.device_id
+    if device_id is None:
+        return "npu:0"
+    return f"npu:{device_id}"
+
+
+def torch_npu_gc():
+    """Call ``torch.npu.empty_cache()`` on the device named by get_npu_device_string()."""
+    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+    torch.npu.set_device(0)
+    with torch.npu.device(get_npu_device_string()):
+        torch.npu.empty_cache()
+
+
+# Evaluated once when this module is imported; callers read the flag as
+# ``npu_specific.has_npu`` (an attribute, not a function call).
+has_npu = check_for_npu()

+ 4 - 0
modules/textual_inversion/textual_inversion.py

@@ -151,6 +151,10 @@ class EmbeddingDatabase:
         return embedding
         return embedding
 
 
     def get_expected_shape(self):
     def get_expected_shape(self):
+        # Workaround: re-select NPU device 0 before encoding the probe text.
+        # NOTE(review): presumably the same torch_npu bug as the set_device(0)
+        # workaround in modules/npu_specific.py — confirm and remove once fixed.
+        if devices.npu_specific.has_npu:
+            import torch
+            torch.npu.set_device(0)
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         return vec.shape[1]
         return vec.shape[1]
 
 

+ 4 - 0
requirements.txt

@@ -5,6 +5,8 @@ accelerate
 basicsr
 basicsr
 blendmodes
 blendmodes
 clean-fid
 clean-fid
+cloudpickle
+decorator
 einops
 einops
 fastapi>=0.90.1
 fastapi>=0.90.1
 gfpgan
 gfpgan
@@ -26,9 +28,11 @@ resize-right
 
 
 safetensors
 safetensors
 scikit-image>=0.19
 scikit-image>=0.19
+synr==0.5.0
 timm
 timm
 tomesd
 tomesd
 torch
 torch
 torchdiffeq
 torchdiffeq
 torchsde
 torchsde
+tornado
 transformers==4.30.2
 transformers==4.30.2

+ 4 - 0
requirements_versions.txt

@@ -4,6 +4,8 @@ accelerate==0.21.0
 basicsr==1.4.2
 basicsr==1.4.2
 blendmodes==2022
 blendmodes==2022
 clean-fid==0.1.35
 clean-fid==0.1.35
+cloudpickle==3.0.0
+decorator==5.1.1
 einops==0.4.1
 einops==0.4.1
 fastapi==0.94.0
 fastapi==0.94.0
 gfpgan==1.3.8
 gfpgan==1.3.8
@@ -23,10 +25,12 @@ realesrgan==0.3.0
 resize-right==0.0.2
 resize-right==0.0.2
 safetensors==0.3.1
 safetensors==0.3.1
 scikit-image==0.21.0
 scikit-image==0.21.0
+synr==0.5.0
 timm==0.9.2
 timm==0.9.2
 tomesd==0.1.3
 tomesd==0.1.3
 torch
 torch
 torchdiffeq==0.2.3
 torchdiffeq==0.2.3
 torchsde==0.2.6
 torchsde==0.2.6
+tornado==6.4
 transformers==4.30.2
 transformers==4.30.2
 httpx==0.24.1
 httpx==0.24.1

+ 4 - 0
webui.sh

@@ -159,6 +159,10 @@ then
     if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
     if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
     then
     then
         export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
         export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+    elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
+    then
+        export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
+    
     fi
     fi
 fi
 fi