
Fix some Lora's not working

Leo Mozoloa, 2 years ago
Commit c3eced22fc
1 file changed, 5 insertions(+), 1 deletion(-)

extensions-builtin/Lora/lora.py (+5 −1)

@@ -165,8 +165,10 @@ def load_lora(name, filename):
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.MultiheadAttention:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d:
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
             print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
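
The first hunk changes module creation in load_lora: previously every Conv2d LoRA weight was mapped onto a 1×1 convolution, so networks whose down-projection uses a 3×3 kernel (LoCon-style LoRAs) could not be loaded. A minimal sketch of why the created module's kernel size must match the stored weight; the shapes and the copy step are illustrative assumptions, not the extension's exact code:

    import torch

    # Hypothetical LoRA weights for one Conv2d layer (LoCon-style shapes):
    down_weight = torch.randn(4, 320, 3, 3)  # (rank, in_channels, kH, kW)
    up_weight = torch.randn(320, 4, 1, 1)    # (out_channels, rank, 1, 1)

    for weight in (down_weight, up_weight):
        # Pick the module's kernel size from the weight, as the patch does.
        if weight.shape[2:] == (1, 1):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif weight.shape[2:] == (3, 3):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
        # Before the patch, a (rank, in, 3, 3) tensor was forced into a 1x1 conv,
        # and copying the weight raised a shape-mismatch error.
        with torch.no_grad():
            module.weight.copy_(weight)  # succeeds now that the shapes agree
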
@@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target):
 
         if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
             updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
         else:
             updown = up @ down
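
The second hunk extends lora_calc_updown: for a 3×3 LoRA, chaining the down convolution (3×3 kernel) and the up convolution (1×1 kernel) is equivalent to a single 3×3 convolution whose merged kernel is the conv2d/permute expression added above. A quick numerical check of that identity, with illustrative shapes:

    import torch
    import torch.nn.functional as F

    c_in, c_out, rank = 8, 16, 4
    down = torch.randn(rank, c_in, 3, 3)  # LoRA "down" conv weight
    up = torch.randn(c_out, rank, 1, 1)   # LoRA "up" conv weight
    x = torch.randn(1, c_in, 32, 32)      # dummy activation

    # Applying down then up as two separate convolutions...
    two_step = F.conv2d(F.conv2d(x, down, padding=1), up)

    # ...matches one convolution with the merged kernel from the patch.
    updown = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
    one_step = F.conv2d(x, updown, padding=1)

    print(torch.allclose(two_step, one_step, atol=1e-4))  # True

The trick is that permuting down to (in_channels, rank, 3, 3) lets conv2d treat it as a batch of in_channels "images" with rank channels, so convolving with the 1×1 up weight contracts the rank dimension and yields an (out_channels, in_channels, 3, 3) delta that can be merged into the base weight like any other updown.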