Browse Source

stuff related to torch version change

AUTOMATIC 2 năm trước cách đây
mục cha
commit
ee71eee181
4 tập tin đã thay đổi với 10 bổ sung11 xóa
  1. 4 4
      launch.py
  2. 1 4
      modules/safe.py
  3. 1 1
      requirements_versions.txt
  4. 4 2
      webui.py

+ 4 - 4
launch.py

@@ -121,12 +121,12 @@ def run_python(code, desc=None, errdesc=None):
     return run(f'"{python}" -c "{code}"', desc, errdesc)
 
 
-def run_pip(args, desc=None):
+def run_pip(args, desc=None, live=False):
     if skip_install:
         return
 
     index_url_line = f' --index-url {index_url}' if index_url != '' else ''
-    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
 
 
 def check_run_python(code):
@@ -225,7 +225,7 @@ def run_extensions_installers(settings_file):
 def prepare_environment():
     global skip_install
 
-    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118")
+    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 
     xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
@@ -271,7 +271,7 @@ def prepare_environment():
     if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
         if platform.system() == "Windows":
             if platform.python_version().startswith("3.10"):
-                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
             else:
                 print("Installation of xformers is not supported in this version of Python.")
                 print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")

+ 1 - 4
modules/safe.py

@@ -1,6 +1,5 @@
 # this code is adapted from the script contributed by anon from /h/
 
-import io
 import pickle
 import collections
 import sys
@@ -12,11 +11,9 @@ import _codecs
 import zipfile
 import re
 
-
 # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
 TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
 
-
 def encode(*args):
     out = _codecs.encode(*args)
     return out
@@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
 
     def persistent_load(self, saved_id):
         assert saved_id[0] == 'storage'
-        return TypedStorage()
+        return TypedStorage(_internal=True)
 
     def find_class(self, module, name):
         if self.extra_handler is not None:

+ 1 - 1
requirements_versions.txt

@@ -25,6 +25,6 @@ lark==1.1.2
 inflection==0.5.1
 GitPython==3.1.30
 torchsde==0.2.5
-safetensors==0.3.0
+safetensors==0.3.1
 httpcore<=0.15
 fastapi==0.94.0

+ 4 - 2
webui.py

@@ -21,6 +21,8 @@ import torch
 import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
 warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
 warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+warnings.filterwarnings(action='ignore', category=UserWarning, message='TypedStorage is deprecated')
+
 
 startup_timer.record("import torch")
 
@@ -113,7 +115,7 @@ def check_versions():
     if shared.cmd_opts.skip_version_check:
         return
 
-    expected_torch_version = "1.13.1"
+    expected_torch_version = "2.0.0"
 
     if version.parse(torch.__version__) < version.parse(expected_torch_version):
         errors.print_error_explanation(f"""
@@ -126,7 +128,7 @@ there are reports of issues with training tab on the latest version.
 Use --skip-version-check commandline argument to disable this check.
         """.strip())
 
-    expected_xformers_version = "0.0.16rc425"
+    expected_xformers_version = "0.0.17"
     if shared.xformers_available:
         import xformers