1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import torch
def make_weight_cp(t, wa, wb):
    """Rebuild a 4-D weight from a core tensor ``t`` and two factor matrices.

    ``t`` is first contracted with ``wb`` over its second axis. The
    intermediate is then contracted with ``wa`` over its first axis. The
    result has shape ``(wa.shape[1], wb.shape[1], t.shape[2], t.shape[3])``.
    """
    contracted_with_wb = torch.einsum('i j k l, j r -> i r k l', t, wb)
    result = torch.einsum('i j k l, i r -> r j k l', contracted_with_wb, wa)
    return result
def rebuild_conventional(up, down, shape, dyn_dim=None):
    """Rebuild a weight as the matrix product of ``up`` and ``down``.

    Each factor is flattened to 2-D while keeping its leading dimension.
    ``up`` becomes ``(up.size(0), -1)`` and ``down`` becomes
    ``(down.size(0), -1)``. If ``dyn_dim`` is given, only the first
    ``dyn_dim`` entries of the shared inner dimension are used (columns of
    ``up``, rows of ``down``). The product is reshaped to ``shape``.
    """
    up_mat = up.reshape(up.size(0), -1)
    down_mat = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        # Truncate the contracted dimension on both sides so the two factors stay aligned.
        up_mat, down_mat = up_mat[:, :dyn_dim], down_mat[:dyn_dim, :]
    product = up_mat @ down_mat
    return product.reshape(shape)
def rebuild_cp_decomposition(up, down, mid):
    """Rebuild a 4-D weight from a core tensor ``mid`` and two factor matrices.

    ``up`` and ``down`` are flattened to 2-D (leading dimension kept). They are
    then combined as
    ``out[i, j, k, l] = sum over n, m of up[i, n] * mid[n, m, k, l] * down[m, j]``.
    """
    up_mat, down_mat = (x.reshape(x.size(0), -1) for x in (up, down))
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up_mat, down_mat)
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """Split ``dimension`` into two integer factors ``(m, n)`` with ``m * n == dimension``.

    The returned pair is always ordered so that ``m <= n``.

    * If ``factor`` is positive and divides ``dimension``, the pair is made of
      ``factor`` and ``dimension // factor``.
    * Otherwise the divisors of ``dimension`` are walked in increasing order,
      and the last one not exceeding ``factor`` is kept. The walk also ends once
      a kept divisor is at least its cofactor, i.e. at the most balanced split.
      A negative ``factor`` (the default) removes the upper limit, so the most
      balanced split is returned. ``factor == 0`` admits no divisor and yields
      ``(1, dimension)``.

    In LoRA with Kronecker product (LoKr), the first value is used for the
    weight scale and the second for the weight. The Kronecker product is not
    commutative (A⊗B ≠ B⊗A), so the two matrices have slightly different
    meanings.

    Examples (``dimension -> (m, n)`` for each ``factor``)::

        factor:   -1          2           4           8           16
        127   ->  (1, 127)    (1, 127)    (1, 127)    (1, 127)    (1, 127)
        128   ->  (8, 16)     (2, 64)     (4, 32)     (8, 16)     (8, 16)
        250   ->  (10, 25)    (2, 125)    (2, 125)    (5, 50)     (10, 25)
        360   ->  (18, 20)    (2, 180)    (4, 90)     (8, 45)     (15, 24)
        512   ->  (16, 32)    (2, 256)    (4, 128)    (8, 64)     (16, 32)
        1024  ->  (32, 32)    (2, 512)    (4, 256)    (8, 128)    (16, 64)
    """
    if factor > 0 and dimension % factor == 0:
        m, n = factor, dimension // factor
        return (m, n) if m <= n else (n, m)

    if factor < 0:
        # No explicit limit: any divisor up to ``dimension`` itself is allowed.
        factor = dimension

    m, n = 1, dimension
    # NOTE: upstream also broke out when new_m + new_n exceeded 1 + dimension.
    # No divisor pair of ``dimension`` can exceed that bound, so the check is omitted.
    while m < n:
        # Smallest divisor of ``dimension`` strictly greater than the current m.
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        if new_m > factor:
            break
        m, n = new_m, dimension // new_m
    return (m, n) if m <= n else (n, m)
|