Lora: add an option to use old method of applying loras

AUTOMATIC, 2 years ago
commit ec0da07236
2 files changed, 55 insertions and 6 deletions
  1. extensions-builtin/Lora/lora.py (+50 -6)
  2. extensions-builtin/Lora/scripts/lora_script.py (+5 -0)
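
Note on the change: this commit gives Lora two code paths. The default path, lora_apply_weights, bakes each Lora's delta into the layer weights once (keeping a backup so it can be undone), while the new lora_functional option re-enables the old behaviour of computing the delta on every forward call via the added lora_forward below. A minimal sketch of that functional math, assuming a toy setup (ToyLora and base are illustrative names, not code from the repository):

    import torch
    import torch.nn as nn

    class ToyLora(nn.Module):
        def __init__(self, features=16, rank=4, alpha=4.0, multiplier=0.8):
            super().__init__()
            self.down = nn.Linear(features, rank, bias=False)  # project down to the Lora rank
            self.up = nn.Linear(rank, features, bias=False)    # up.weight.shape[1] == rank
            self.alpha = alpha
            self.multiplier = multiplier

        def forward_delta(self, x):
            # same scaling as lora_forward in the diff: multiplier * (alpha / rank)
            scale = self.multiplier * (self.alpha / self.up.weight.shape[1])
            return self.up(self.down(x)) * scale

    base = nn.Linear(16, 16)
    lora = ToyLora()
    x = torch.randn(2, 16)

    # functional path: one extra down/up matmul pair per active Lora, on every forward
    y = base(x) + lora.forward_delta(x)

Stacking several Loras this way repeats those extra matmuls for each of them, which is why the option text below describes the old method as slower when multiple Loras are active.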

extensions-builtin/Lora/lora.py (+50 -6)

@@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target):
         return updown


+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "lora_weights_backup", None)
+
+    if weights_backup is None:
+        return
+
+    if isinstance(self, torch.nn.MultiheadAttention):
+        self.in_proj_weight.copy_(weights_backup[0])
+        self.out_proj.weight.copy_(weights_backup[1])
+    else:
+        self.weight.copy_(weights_backup)
+
+
 def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
@@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
         self.lora_weights_backup = weights_backup

     if current_names != wanted_names:
-        if weights_backup is not None:
-            if isinstance(self, torch.nn.MultiheadAttention):
-                self.in_proj_weight.copy_(weights_backup[0])
-                self.out_proj.weight.copy_(weights_backup[1])
-            else:
-                self.weight.copy_(weights_backup)
+        lora_restore_weights_from_backup(self)
 
         for lora in loaded_loras:
             module = lora.modules.get(lora_layer_name, None)
@@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
         setattr(self, "lora_current_names", wanted_names)


+def lora_forward(module, input, original_forward):
+    """
+    Old way of applying Lora by executing operations during layer's forward.
+    Stacking many loras this way results in big performance degradation.
+    """
+
+    if len(loaded_loras) == 0:
+        return original_forward(module, input)
+
+    input = devices.cond_cast_unet(input)
+
+    lora_restore_weights_from_backup(module)
+    lora_reset_cached_weight(module)
+
+    res = original_forward(module, input)
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is None:
+            continue
+
+        module.up.to(device=devices.device)
+        module.down.to(device=devices.device)
+
+        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+    return res
+
+
 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     setattr(self, "lora_current_names", ())
     setattr(self, "lora_weights_backup", None)


 def lora_Linear_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
     lora_apply_weights(self)

     return torch.nn.Linear_forward_before_lora(self, input)
@@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):


 def lora_Conv2d_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
     lora_apply_weights(self)

     return torch.nn.Conv2d_forward_before_lora(self, input)
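
For contrast, a rough sketch (not code from the repository) of what the default weight-merge path amounts to: back up the original weight once, add the precomputed delta in place, and copy the backup back when the wanted set of Loras changes, which is the job the newly factored-out lora_restore_weights_from_backup helper performs. Values below are toy numbers.

    import torch
    import torch.nn as nn

    layer = nn.Linear(16, 16)

    # keep a copy of the untouched weight (lora_apply_weights stores this on the module)
    weights_backup = layer.weight.detach().clone()

    # toy delta standing in for up @ down, scaled by multiplier * (alpha / rank)
    rank, alpha, multiplier = 4, 4.0, 0.8
    up = torch.randn(16, rank)
    down = torch.randn(rank, 16)
    updown = (up @ down) * multiplier * (alpha / rank)

    with torch.no_grad():
        layer.weight += updown              # merged once; per-forward cost is unchanged
        # ... later, when a different set of Loras is wanted:
        layer.weight.copy_(weights_backup)  # restore the original weight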

extensions-builtin/Lora/scripts/lora_script.py (+5 -0)

@@ -55,3 +55,8 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
 }))
+
+
+shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
+    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+}))
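
Both paths share the same scaling term, multiplier * (alpha / rank). A toy calculation with illustrative numbers (not defaults from the repository):

    multiplier = 0.8   # per-Lora strength, e.g. the 0.8 in <lora:name:0.8>
    alpha = 4.0        # read from the Lora file; when it is absent the whole factor falls back to 1.0
    rank = 8           # up.weight.shape[1], the Lora's inner dimension

    scale = multiplier * (alpha / rank)
    print(scale)       # 0.4; each Lora's delta is multiplied by this before being added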