@@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target):
     return updown
 
 
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "lora_weights_backup", None)
+
+    if weights_backup is None:
+        return
+
+    if isinstance(self, torch.nn.MultiheadAttention):
+        self.in_proj_weight.copy_(weights_backup[0])
+        self.out_proj.weight.copy_(weights_backup[1])
+    else:
+        self.weight.copy_(weights_backup)
+
+
 def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
@@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
         self.lora_weights_backup = weights_backup
 
     if current_names != wanted_names:
-        if weights_backup is not None:
-            if isinstance(self, torch.nn.MultiheadAttention):
-                self.in_proj_weight.copy_(weights_backup[0])
-                self.out_proj.weight.copy_(weights_backup[1])
-            else:
-                self.weight.copy_(weights_backup)
+        lora_restore_weights_from_backup(self)
 
         for lora in loaded_loras:
            module = lora.modules.get(lora_layer_name, None)
@@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
     setattr(self, "lora_current_names", wanted_names)
 
 
+def lora_forward(module, input, original_forward):
+    """
+    Old way of applying Lora by executing operations during layer's forward.
+    Stacking many loras this way results in big performance degradation.
+    """
+
+    if len(loaded_loras) == 0:
+        return original_forward(module, input)
+
+    input = devices.cond_cast_unet(input)
+
+    lora_restore_weights_from_backup(module)
+    lora_reset_cached_weight(module)
+
+    res = original_forward(module, input)
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is None:
+            continue
+
+        module.up.to(device=devices.device)
+        module.down.to(device=devices.device)
+
+        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+    return res
+
+
 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     setattr(self, "lora_current_names", ())
     setattr(self, "lora_weights_backup", None)
 
 
 def lora_Linear_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Linear_forward_before_lora(self, input)
@@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
 
 
 def lora_Conv2d_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Conv2d_forward_before_lora(self, input)
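
The res line in lora_forward above is the usual low-rank update applied at inference time: the layer output gets up(down(input)) added to it, scaled by the Lora multiplier and by alpha / rank. Below is a minimal standalone sketch, not part of the patch, showing for a single Linear layer that this functional path matches the merged-weight path that lora_apply_weights uses; all names, sizes and values in it are invented for illustration.

# Illustrative sketch only, not part of the patch: the functional Lora path
# (delta added during forward) and the merged-weight path (delta baked into
# the weight once) produce the same output for a Linear layer.
import torch

torch.manual_seed(0)

base = torch.nn.Linear(16, 16, bias=False)   # original layer, left untouched
down = torch.nn.Linear(16, 4, bias=False)    # Lora "down" projection (rank 4)
up = torch.nn.Linear(4, 16, bias=False)      # Lora "up" projection
multiplier = 0.8                             # per-Lora strength
alpha = 4.0                                  # optional alpha; scale = alpha / rank

x = torch.randn(2, 16)
scale = alpha / up.weight.shape[1] if alpha else 1.0

# functional path: add the low-rank delta to the output on every forward
res_functional = base(x) + up(down(x)) * multiplier * scale

# merged path: fold the same delta into the weight once and reuse it
merged_weight = base.weight + up.weight @ down.weight * multiplier * scale
res_merged = torch.nn.functional.linear(x, merged_weight)

print(torch.allclose(res_functional, res_merged, atol=1e-5))  # True

The trade-off the patch is about: the functional path pays the extra up/down matmuls on every forward for every loaded Lora, while the merged path pays once when the weights are swapped in, which is why lora_functional is kept only as an opt-in compatibility option.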