
Merge pull request #11958 from AUTOMATIC1111/conserve-ram

Use less RAM when creating models
AUTOMATIC1111 2 years ago
Parent
Commit
ac81c1dd1f
4 changed files with 114 additions and 13 deletions
  1. modules/cmd_args.py (+1, -0)
  2. modules/sd_disable_initialization.py (+101, -5)
  3. modules/sd_models.py (+10, -6)
  4. webui.py (+2, -2)

+ 1 - 0
modules/cmd_args.py

@@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)

+ 101 - 5
modules/sd_disable_initialization.py

@@ -3,8 +3,31 @@ import open_clip
 import torch
 import transformers.utils.hub

+from modules import shared

-class DisableInitialization:
+
+class ReplaceHelper:
+    def __init__(self):
+        self.replaced = []
+
+    def replace(self, obj, field, func):
+        original = getattr(obj, field, None)
+        if original is None:
+            return None
+
+        self.replaced.append((obj, field, original))
+        setattr(obj, field, func)
+
+        return original
+
+    def restore(self):
+        for obj, field, original in self.replaced:
+            setattr(obj, field, original)
+
+        self.replaced.clear()
+
+
+class DisableInitialization(ReplaceHelper):
     """
     """
     When an object of this class enters a `with` block, it starts:
     When an object of this class enters a `with` block, it starts:
     - preventing torch's layer initialization functions from working
     - preventing torch's layer initialization functions from working
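The new `ReplaceHelper` base class factors the monkeypatch-and-restore bookkeeping out of `DisableInitialization` so the two context managers added below can reuse it. A small self-contained demonstration of the pattern (the `target` namespace and `greet` function are invented for illustration):

```
import types

# Hypothetical stand-in for a module whose attribute we want to patch.
target = types.SimpleNamespace(greet=lambda name: f"hello, {name}")

helper = ReplaceHelper()

# replace() records the original and installs the override; it returns the
# original so the override can delegate to it, the same idiom
# InitializeOnMeta uses with torch.nn.Linear.__init__ further down.
original_greet = helper.replace(target, 'greet', lambda name: original_greet(name).upper())

print(target.greet("world"))  # HELLO, WORLD
helper.restore()              # reinstalls every recorded original
print(target.greet("world"))  # hello, world
```

Because `replace` returns `None` and installs nothing when the attribute is absent, patching is silently skipped for objects that lack the field.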
@@ -21,7 +44,7 @@ class DisableInitialization:
     """
     """
 
 
     def __init__(self, disable_clip=True):
     def __init__(self, disable_clip=True):
-        self.replaced = []
+        super().__init__()
         self.disable_clip = disable_clip

     def replace(self, obj, field, func):
@@ -86,8 +109,81 @@ class DisableInitialization:
             self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

     def __exit__(self, exc_type, exc_val, exc_tb):
-        for obj, field, original in self.replaced:
-            setattr(obj, field, original)
+        self.restore()

-        self.replaced.clear()

+class InitializeOnMeta(ReplaceHelper):
+    """
+    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
+    which results in those parameters having no values and taking no memory. model.to() will be broken and
+    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
+
+    Usage:
+    ```
+    with sd_disable_initialization.InitializeOnMeta():
+        sd_model = instantiate_from_config(sd_config.model)
+    ```
+    """
+
+    def __enter__(self):
+        if shared.cmd_opts.disable_model_loading_ram_optimization:
+            return
+
+        def set_device(x):
+            x["device"] = "meta"
+            return x
+
+        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
+        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
+        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
+        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
+
+
+class LoadStateDictOnMeta(ReplaceHelper):
+    """
+    Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
+    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
+    Meant to be used together with InitializeOnMeta above.
+
+    Usage:
+    ```
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
+        model.load_state_dict(state_dict, strict=False)
+    ```
+    """
+
+    def __init__(self, state_dict, device):
+        super().__init__()
+        self.state_dict = state_dict
+        self.device = device
+
+    def __enter__(self):
+        if shared.cmd_opts.disable_model_loading_ram_optimization:
+            return
+
+        sd = self.state_dict
+        device = self.device
+
+        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+
+            for name, param in params:
+                if param.is_meta:
+                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+
+            original(self, state_dict, prefix, *args, **kwargs)
+
+            for name, _ in params:
+                key = prefix + name
+                if key in sd:
+                    del sd[key]
+
+        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
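For context, this is the meta-device mechanism the two classes above rely on, shown as a plain-PyTorch sketch outside the webui: a tensor on the `meta` device carries only shape and dtype, no storage, so building a model there costs almost no RAM, and each parameter is materialized only when its checkpoint value arrives.

```
import torch

# Built on the meta device: full module structure, no parameter storage.
layer = torch.nn.Linear(4096, 4096, device="meta")
assert layer.weight.is_meta

# Materialize parameters the way LoadStateDictOnMeta does: allocate real
# zeros on the target device, then let load_state_dict copy values in.
checkpoint = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}
for name, param in list(layer._parameters.items()):
    if param is not None and param.is_meta:
        layer._parameters[name] = torch.nn.parameter.Parameter(
            torch.zeros_like(param, device="cpu"),
            requires_grad=param.requires_grad,
        )
layer.load_state_dict(checkpoint)

# LoadStateDictOnMeta also deletes each consumed key, so the state dict's
# copy of a weight is freed while the rest of the model is still loading.
for key in ("weight", "bias"):
    del checkpoint[key]
```

Roughly speaking, the old path briefly held two full copies of the weights in RAM (the randomly initialized model plus the loaded state dict); with meta initialization and incremental deletion, peak usage stays close to a single copy.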

+ 10 - 6
modules/sd_models.py

@@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
         return sd_model.cond_stage_model([""])


-
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
-            sd_model = instantiate_from_config(sd_config.model)
-    except Exception:
-        pass
+            with sd_disable_initialization.InitializeOnMeta():
+                sd_model = instantiate_from_config(sd_config.model)
+
+    except Exception as e:
+        errors.display(e, "creating model quickly", full_traceback=True)

     if sd_model is None:
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
-        sd_model = instantiate_from_config(sd_config.model)
+
+        with sd_disable_initialization.InitializeOnMeta():
+            sd_model = instantiate_from_config(sd_config.model)

     sd_model.used_config = checkpoint_config

     timer.record("create model")

-    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)

     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
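Condensed, the new loading path reads like this (a sketch assembled from the hunks above; the slow-path retry, error handling, and the lowvram branch are omitted, and `disable_clip` is simplified):

```
# 1. Build the model with meta-device parameters: no RAM is spent on
#    random initial weights that the checkpoint would overwrite anyway.
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)

# 2. While weights load, each meta parameter is materialized as CPU zeros
#    just before its checkpoint value is copied in, and the consumed
#    state_dict entry is deleted to release that memory immediately.
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
```

Note also that the bare `except Exception: pass` around the fast path is gone; failures are now reported through `errors.display` before the slow method is attempted.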

+ 2 - 2
webui.py

@@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
         if modules.sd_hijack.current_optimizer is None:
             modules.sd_hijack.apply_optimizations()

-    Thread(target=load_model).start()
+        devices.first_time_calculation()

-    Thread(target=devices.first_time_calculation).start()
+    Thread(target=load_model).start()

     shared.reload_hypernetworks()
     startup_timer.record("reload hypernetworks")