Jelajahi Sumber

send weights to target device instead of CPU memory

AUTOMATIC1111 2 tahun lalu
induk
melakukan
eaba3d7349
2 mengubah file dengan 31 tambahan dan 10 penghapusan
  1. 15 9
      modules/sd_disable_initialization.py
  2. 16 1
      modules/sd_models.py

+ 15 - 9
modules/sd_disable_initialization.py

@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """
 
-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
 
     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device
 
-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []
 
-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue
 
                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)
 
                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
 
-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name
 
                 sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                     state_dict[key] = sd_param
                     used_param_keys.append(key)
 
-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)
 
             for key in used_param_keys:
                 state_dict.pop(key, None)
 
-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
 
-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)
 
         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))

+ 16 - 1
modules/sd_models.py

@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     timer.record("load weights from state dict")