
send weights to target device instead of CPU memory

AUTOMATIC1111 2 years ago
commit eaba3d7349
2 changed files with 31 additions and 10 deletions
  1. modules/sd_disable_initialization.py (+15 -9)
  2. modules/sd_models.py (+16 -1)

modules/sd_disable_initialization.py (+15 -9)

@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """

-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
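The new `weight_dtype_conversion` argument maps the first dot-separated term of a state-dict key to a target dtype. The `''` entry acts as the default for every other prefix, and a value of `None` leaves a weight at its stored dtype, since `Tensor.to(dtype=None)` is a no-op. A minimal sketch of the lookup pulled out of the class, with illustrative checkpoint keys:

import torch

# Same mapping shape that load_model builds below: keep 'first_stage_model'
# (the VAE) at its stored dtype, convert everything else to float16.
weight_dtype_conversion = {
    'first_stage_model': None,
    '': torch.float16,
}
default_dtype = weight_dtype_conversion.get('')

def get_weight_dtype(key):
    # The first dot-separated term of the key selects the dtype rule.
    key_first_term, _ = key.split('.', 1)
    return weight_dtype_conversion.get(key_first_term, default_dtype)

print(get_weight_dtype('model.diffusion_model.out.0.weight'))        # torch.float16
print(get_weight_dtype('first_stage_model.decoder.conv_in.weight'))  # None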
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device

-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []

-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue

                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)

                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name

                 sd_param = sd.pop(key, None)
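Renaming the hooked method's first argument from `self` to `module` is not cosmetic: inside the closure, `self` must now refer to the enclosing `LoadStateDictOnMeta` instance so that `self.get_weight_dtype(key)` resolves correctly; previously the name was shadowed by the module being loaded. The `param.is_meta` branch then gives each parameter real storage on the target device before torch's original loader copies the checkpoint value in. A toy illustration of that step, using a plain `Linear` on the meta device and CPU as a stand-in target:

import torch

# A module built on the meta device has shapes but no storage.
lin = torch.nn.Linear(4, 4, device='meta')
sd_param = torch.randn(4, 4)  # weight as it would come out of the checkpoint

# What the patched _load_from_state_dict does per parameter: swap the meta
# tensor for real zeros on the target device, so the original loader can
# copy the checkpoint value into it.
param = lin._parameters['weight']
if param.is_meta:
    lin._parameters['weight'] = torch.nn.parameter.Parameter(
        torch.zeros_like(param, device='cpu', dtype=sd_param.dtype),
        requires_grad=param.requires_grad,
    )

lin._parameters['weight'].data.copy_(sd_param)  # stand-in for original(...)
print(lin.weight.is_meta)  # False: the parameter is now materialized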
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                     state_dict[key] = sd_param
                     used_param_keys.append(key)

-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)

             for key in used_param_keys:
                 state_dict.pop(key, None)

-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)

         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
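The `state_dict == sd` branch implements the RAM trick described in the docstring: torch's `load_state_dict` machinery makes several copies of the dict it receives, so the hook hands it one whose tensors live on the meta device. The copies then carry no storage at all, and the real weights are injected one key at a time by `load_from_state_dict` above. A standalone illustration:

import torch

sd = {'weight': torch.randn(1024, 1024)}  # ~4 MB of real float32 data

# Moving tensors to the meta device keeps keys, shapes and dtypes but
# drops the storage, so any number of copies of this dict is free.
meta_sd = {k: v.to(device="meta", dtype=v.dtype) for k, v in sd.items()}

print(meta_sd['weight'].is_meta)  # True
print(meta_sd['weight'].shape)    # torch.Size([1024, 1024])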

modules/sd_models.py (+16 -1)

@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()


+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
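`model_target_device` encodes the commit's core decision. Under `--lowvram`/`--medvram`, the model has to start in CPU RAM because `lowvram.setup_for_low_vram` shuttles submodules to the GPU on demand; in the normal case, weights can now go straight to `devices.device` while they are being loaded, skipping the detour through CPU memory that gives the commit its title. A self-contained sketch of the same branching, with stand-ins for `shared.cmd_opts` and `devices`:

import torch

class _Opts:  # stand-in for shared.cmd_opts
    lowvram = False
    medvram = True  # as if launched with --medvram

cmd_opts = _Opts()
cpu = torch.device('cpu')
device = torch.device('cuda:0')  # what devices.device would be on a GPU box

def model_target_device():
    # lowvram/medvram manage GPU residency themselves, so weights must
    # land in CPU RAM first; otherwise load directly onto the GPU.
    if cmd_opts.lowvram or cmd_opts.medvram:
        return cpu
    return device

print(model_target_device())  # cpu, because medvram is set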
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

     timer.record("create model")

-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     timer.record("load weights from state dict")
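Taken together: with `--no-half`, `weight_dtype_conversion` is `None` and weights keep their stored precision; otherwise everything except the `first_stage_model` (VAE) subtree is converted to float16 as it is placed on the target device, rather than after a full float32 copy has landed in RAM. The VAE is presumably left alone so the separate half/no-half-VAE handling in `load_model_weights` can decide its dtype later. A small end-to-end demonstration of the net dtype effect, with hypothetical checkpoint keys:

import torch

state_dict = {  # float32, as stored on disk
    'model.diffusion_model.out.0.weight': torch.randn(4, 4),
    'first_stage_model.decoder.conv_in.weight': torch.randn(4, 4),
}
weight_dtype_conversion = {'first_stage_model': None, '': torch.float16}
default_dtype = weight_dtype_conversion.get('')

for key, value in state_dict.items():
    target = weight_dtype_conversion.get(key.split('.', 1)[0], default_dtype)
    # .to(dtype=None) is a no-op, so the VAE weight stays float32.
    print(key, '->', value.to(dtype=target).dtype)
# model.diffusion_model.out.0.weight -> torch.float16
# first_stage_model.decoder.conv_in.weight -> torch.float32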