lyco_helpers.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import torch
  2. def make_weight_cp(t, wa, wb):
  3. temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
  4. return torch.einsum('i j k l, i r -> r j k l', temp, wa)
  5. def rebuild_conventional(up, down, shape, dyn_dim=None):
  6. up = up.reshape(up.size(0), -1)
  7. down = down.reshape(down.size(0), -1)
  8. if dyn_dim is not None:
  9. up = up[:, :dyn_dim]
  10. down = down[:dyn_dim, :]
  11. return (up @ down).reshape(shape)
  12. def rebuild_cp_decomposition(up, down, mid):
  13. up = up.reshape(up.size(0), -1)
  14. down = down.reshape(down.size(0), -1)
  15. return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
  16. # copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
  17. def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
  18. '''
  19. return a tuple of two value of input dimension decomposed by the number closest to factor
  20. second value is higher or equal than first value.
  21. In LoRA with Kroneckor Product, first value is a value for weight scale.
  22. secon value is a value for weight.
  23. Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
  24. examples)
  25. factor
  26. -1 2 4 8 16 ...
  27. 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
  28. 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
  29. 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
  30. 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
  31. 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
  32. 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
  33. '''
  34. if factor > 0 and (dimension % factor) == 0:
  35. m = factor
  36. n = dimension // factor
  37. if m > n:
  38. n, m = m, n
  39. return m, n
  40. if factor < 0:
  41. factor = dimension
  42. m, n = 1, dimension
  43. length = m + n
  44. while m<n:
  45. new_m = m + 1
  46. while dimension%new_m != 0:
  47. new_m += 1
  48. new_n = dimension // new_m
  49. if new_m + new_n > length or new_m>factor:
  50. break
  51. else:
  52. m, n = new_m, new_n
  53. if m > n:
  54. n, m = m, n
  55. return m, n