Sfoglia il codice sorgente

fix: AttributeError when attempting to reshape rescale by org_module weight

v0xie 1 anno fa
parent
commit
07805cbeee
1 file modificato con 7 aggiunte e 7 eliminazioni
  1. 7 7
      extensions-builtin/Lora/network_oft.py

+ 7 - 7
extensions-builtin/Lora/network_oft.py

@@ -36,13 +36,6 @@ class NetworkModuleOFT(network.NetworkModule):
             # self.alpha is unused
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
 
 
-        # LyCORIS BOFT
-        if self.oft_blocks.dim() == 4:
-            self.is_boft = True
-        self.rescale = weights.w.get('rescale', None)
-        if self.rescale is not None:
-            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
-
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
@@ -54,6 +47,13 @@ class NetworkModuleOFT(network.NetworkModule):
         elif is_other_linear:
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim
             self.out_dim = self.sd_module.embed_dim
 
 
+        # LyCORIS BOFT
+        if self.oft_blocks.dim() == 4:
+            self.is_boft = True
+        self.rescale = weights.w.get('rescale', None)
+        if self.rescale is not None and not is_other_linear:
+            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
         self.num_blocks = self.dim
         self.num_blocks = self.dim
         self.block_size = self.out_dim // self.dim
         self.block_size = self.out_dim // self.dim
         self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
         self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim