
run a basic torch calculation at startup, in parallel, to reduce the performance impact of the first generation

AUTOMATIC 2 years ago
parent commit 8faac8b963
2 changed files with 21 additions and 1 deletion
  1. modules/devices.py  (+18, −0)
  2. webui.py  (+3, −1)

+ 18 - 0
modules/devices.py

@@ -1,5 +1,7 @@
 import sys
 import contextlib
+from functools import lru_cache
+
 import torch
 from modules import errors
 
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."
 
     raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+    """
+    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVidia.
+    """
+
+    x = torch.zeros((1, 1)).to(device, dtype)
+    linear = torch.nn.Linear(1, 1).to(device, dtype)
+    linear(x)
+
+    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+    conv2d(x)
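
A note on the bare @lru_cache decorator (valid without parentheses since Python 3.8): it memoizes the zero-argument call, so the warm-up body runs at most once per process and any later call returns immediately. A minimal sketch of that behavior, with time.sleep standing in for the torch work:

    import time
    from functools import lru_cache

    @lru_cache
    def warm_up():
        # stand-in for the expensive first torch calculation
        time.sleep(2.7)

    start = time.time()
    warm_up()  # first call pays the full cost
    warm_up()  # cache hit: returns immediately
    print(f"elapsed: {time.time() - start:.1f}s")  # ~2.7s, not ~5.4s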

+ 3 - 1
webui.py

@@ -20,7 +20,7 @@ import logging
 
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
-from modules import paths, timer, import_hook, errors  # noqa: F401
+from modules import paths, timer, import_hook, errors, devices  # noqa: F401
 
 startup_timer = timer.Timer()
 
@@ -295,6 +295,8 @@ def initialize_rest(*, reload_script_modules=False):
     # (when reloading, this does nothing)
     Thread(target=lambda: shared.sd_model).start()
 
+    Thread(target=devices.first_time_calculation).start()
+
     shared.reload_hypernetworks()
     startup_timer.record("reload hypernetworks")
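
The webui.py side of the change simply starts the warm-up on a background thread, so the one-time allocation overlaps with the rest of startup instead of delaying the first generation. A self-contained sketch of the same pattern (the function bodies here are stand-ins, not the real startup code):

    import time
    from threading import Thread

    def warm_up():
        time.sleep(2.7)  # stand-in for devices.first_time_calculation

    def load_model():
        time.sleep(2.0)  # stand-in for the rest of initialize_rest()

    start = time.time()
    Thread(target=warm_up).start()  # fire and forget, as in webui.py
    load_model()                    # startup continues without waiting
    # blocking startup work finishes in ~2.0s rather than ~4.7s; the
    # non-daemon thread keeps the process alive until the warm-up is done
    print(f"blocking startup work: {time.time() - start:.1f}s")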