wangshuai09 1 éve
szülő
commit
cc3f604310

+ 7 - 0
modules/devices.py

@@ -88,9 +88,16 @@ def torch_gc():
         xpu_specific.torch_xpu_gc()
 
     if npu_specific.has_npu:
+        torch_npu_set_device()
         npu_specific.torch_npu_gc()
 
 
+def torch_npu_set_device():
+    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+    if npu_specific.has_npu:
+        torch.npu.set_device(0)
+
+
 def enable_tf32():
     if torch.cuda.is_available():
 

+ 1 - 4
modules/initialize.py

@@ -143,10 +143,7 @@ def initialize_rest(*, reload_script_modules=False):
         by that time, so we apply optimization again.
         """
         from modules import devices
-        # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
-        if devices.npu_specific.has_npu:
-            import torch
-            torch.npu.set_device(0)
+        devices.torch_npu_set_device()
 
         shared.sd_model  # noqa: B018
 

+ 8 - 0
modules/launch_utils.py

@@ -338,6 +338,7 @@ def prepare_environment():
             torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
             torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+    requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
 
     xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
@@ -421,6 +422,13 @@ def prepare_environment():
         run_pip(f"install -r \"{requirements_file}\"", "requirements")
         startup_timer.record("install requirements")
 
+    if not os.path.isfile(requirements_file_for_npu):
+        requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
+
+    if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
+        run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
+        startup_timer.record("install requirements_for_npu")
+
     if not args.skip_install:
         run_extensions_installers(settings_file=args.ui_settings_file)
 

+ 1 - 4
modules/npu_specific.py

@@ -8,11 +8,10 @@ def check_for_npu():
     if importlib.util.find_spec("torch_npu") is None:
         return False
     import torch_npu
-    torch_npu.npu.set_device(0)
 
     try:
         # Will raise a RuntimeError if no NPU is found
-        _ = torch.npu.device_count()
+        _ = torch_npu.npu.device_count()
         return torch.npu.is_available()
     except RuntimeError:
         return False
@@ -25,8 +24,6 @@ def get_npu_device_string():
 
 
 def torch_npu_gc():
-    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
-    torch.npu.set_device(0)
     with torch.npu.device(get_npu_device_string()):
         torch.npu.empty_cache()
 

+ 1 - 4
modules/textual_inversion/textual_inversion.py

@@ -150,10 +150,7 @@ class EmbeddingDatabase:
         return embedding
 
     def get_expected_shape(self):
-        # workaround
-        if devices.npu_specific.has_npu:
-            import torch
-            torch.npu.set_device(0)
+        devices.torch_npu_set_device()
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         return vec.shape[1]
 

+ 0 - 4
requirements.txt

@@ -4,8 +4,6 @@ accelerate
 
 blendmodes
 clean-fid
-cloudpickle
-decorator
 einops
 facexlib
 fastapi>=0.90.1
@@ -26,10 +24,8 @@ resize-right
 
 safetensors
 scikit-image>=0.19
-synr==0.5.0
 tomesd
 torch
 torchdiffeq
 torchsde
-tornado
 transformers==4.30.2

+ 4 - 0
requirements_npu.txt

@@ -0,0 +1,4 @@
+cloudpickle
+decorator
+synr==0.5.0
+tornado

+ 0 - 4
requirements_versions.txt

@@ -3,8 +3,6 @@ Pillow==9.5.0
 accelerate==0.21.0
 blendmodes==2022
 clean-fid==0.1.35
-cloudpickle==3.0.0
-decorator==5.1.1
 einops==0.4.1
 facexlib==0.3.0
 fastapi==0.94.0
@@ -23,12 +21,10 @@ pytorch_lightning==1.9.4
 resize-right==0.0.2
 safetensors==0.4.2
 scikit-image==0.21.0
-synr==0.5.0
 spandrel==0.1.6
 tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.6
-tornado==6.4
 transformers==4.30.2
 httpx==0.24.1