Pārlūkot izejas kodu

store patches for Lora in a specialized module

AUTOMATIC1111 2 gadi atpakaļ
vecāks
revīzija
f01682ee01

+ 31 - 0
extensions-builtin/Lora/lora_patches.py

@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+    def __init__(self):
+        self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+        self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+        self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+        self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+        self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+        self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+        self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+        self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+        self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+        self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+    def undo(self):
+        self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+        self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+        self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+        self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+        self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+        self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+        self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+        self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+        self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+        self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+

+ 18 - 14
extensions-builtin/Lora/networks.py

@@ -2,6 +2,7 @@ import logging
 import os
 import os
 import re
 import re
 
 
+import lora_patches
 import network
 import network
 import network_lora
 import network_lora
 import network_hada
 import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
 
 
 def network_Linear_forward(self, input):
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Linear_forward_before_network)
+        return network_forward(self, input, originals.Linear_forward)
 
 
     network_apply_weights(self)
     network_apply_weights(self)
 
 
-    return torch.nn.Linear_forward_before_network(self, input)
+    return originals.Linear_forward(self, input)
 
 
 
 
 def network_Linear_load_state_dict(self, *args, **kwargs):
 def network_Linear_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
     network_reset_cached_weight(self)
 
 
-    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Linear_load_state_dict(self, *args, **kwargs)
 
 
 
 
 def network_Conv2d_forward(self, input):
 def network_Conv2d_forward(self, input):
     if shared.opts.lora_functional:
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+        return network_forward(self, input, originals.Conv2d_forward)
 
 
     network_apply_weights(self)
     network_apply_weights(self)
 
 
-    return torch.nn.Conv2d_forward_before_network(self, input)
+    return originals.Conv2d_forward(self, input)
 
 
 
 
 def network_Conv2d_load_state_dict(self, *args, **kwargs):
 def network_Conv2d_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
     network_reset_cached_weight(self)
 
 
-    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Conv2d_load_state_dict(self, *args, **kwargs)
 
 
 
 
 def network_GroupNorm_forward(self, input):
 def network_GroupNorm_forward(self, input):
     if shared.opts.lora_functional:
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+        return network_forward(self, input, originals.GroupNorm_forward)
 
 
     network_apply_weights(self)
     network_apply_weights(self)
 
 
-    return torch.nn.GroupNorm_forward_before_network(self, input)
+    return originals.GroupNorm_forward(self, input)
 
 
 
 
 def network_GroupNorm_load_state_dict(self, *args, **kwargs):
 def network_GroupNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
     network_reset_cached_weight(self)
 
 
-    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
 
 
 
 
 def network_LayerNorm_forward(self, input):
 def network_LayerNorm_forward(self, input):
     if shared.opts.lora_functional:
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+        return network_forward(self, input, originals.LayerNorm_forward)
 
 
     network_apply_weights(self)
     network_apply_weights(self)
 
 
-    return torch.nn.LayerNorm_forward_before_network(self, input)
+    return originals.LayerNorm_forward(self, input)
 
 
 
 
 def network_LayerNorm_load_state_dict(self, *args, **kwargs):
 def network_LayerNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
     network_reset_cached_weight(self)
 
 
-    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
 
 
 
 
 def network_MultiheadAttention_forward(self, *args, **kwargs):
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)
     network_apply_weights(self)
 
 
-    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_forward(self, *args, **kwargs)
 
 
 
 
 def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
 def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
     network_reset_cached_weight(self)
 
 
-    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
 
 
 
 
 def list_available_networks():
 def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
     if added:
     if added:
         params["Prompt"] += "\n" + "".join(added)
         params["Prompt"] += "\n" + "".join(added)
 
 
+
+originals: lora_patches.LoraPatches = None
+
 extra_network_lora = None
 extra_network_lora = None
 
 
 available_networks = {}
 available_networks = {}

+ 5 - 47
extensions-builtin/Lora/scripts/lora_script.py

@@ -7,17 +7,14 @@ from fastapi import FastAPI
 import network
 import network
 import networks
 import networks
 import lora  # noqa:F401
 import lora  # noqa:F401
+import lora_patches
 import extra_networks_lora
 import extra_networks_lora
 import ui_extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
+
 
 
 def unload():
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+    networks.originals.undo()
 
 
 
 
 def before_ui():
 def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
     extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
     extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
 
 
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
-    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
-    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
-    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
-    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
-    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
-    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
-    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
-    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
-    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
-    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
 
 
 script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_script_unloaded(unload)

+ 64 - 0
modules/patches.py

@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+    """Replaces a function in a module or a class.
+
+    Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
+    If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+    Arguments:
+        key: identifying information for who is doing the replacement. You can use __name__.
+        obj: the module or the class
+        field: name of the function as a string
+        replacement: the new function
+
+    Returns:
+        the original function
+    """
+
+    patch_key = (obj, field)
+    if patch_key in originals[key]:
+        raise RuntimeError(f"patch for {field} is already applied")
+
+    original_func = getattr(obj, field)
+    originals[key][patch_key] = original_func
+
+    setattr(obj, field, replacement)
+
+    return original_func
+
+
+def undo(key, obj, field):
+    """Undoes the peplacement by the patch().
+
+    If the function is not replaced, raises an exception.
+
+    Arguments:
+        key: identifying information for who is doing the replacement. You can use __name__.
+        obj: the module or the class
+        field: name of the function as a string
+
+    Returns:
+        Always None
+    """
+
+    patch_key = (obj, field)
+
+    if patch_key not in originals[key]:
+        raise RuntimeError(f"there is no patch for {field} to undo")
+
+    original_func = originals[key].pop(patch_key)
+    setattr(obj, field, original_func)
+
+    return None
+
+
+def original(key, obj, field):
+    """Returns the original function for the patch created by the patch() function"""
+    patch_key = (obj, field)
+
+    return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)