network_ia3.py 908 B

123456789101112131415161718192021222324252627282930
  1. import network
  2. class ModuleTypeIa3(network.ModuleType):
  3. def create_module(self, net: network.Network, weights: network.NetworkWeights):
  4. if all(x in weights.w for x in ["weight"]):
  5. return NetworkModuleIa3(net, weights)
  6. return None
  7. class NetworkModuleIa3(network.NetworkModule):
  8. def __init__(self, net: network.Network, weights: network.NetworkWeights):
  9. super().__init__(net, weights)
  10. self.w = weights.w["weight"]
  11. self.on_input = weights.w["on_input"].item()
  12. def calc_updown(self, orig_weight):
  13. w = self.w.to(orig_weight.device)
  14. output_shape = [w.size(0), orig_weight.size(1)]
  15. if self.on_input:
  16. output_shape.reverse()
  17. else:
  18. w = w.reshape(-1, 1)
  19. updown = orig_weight * w
  20. return self.finalize_updown(updown, orig_weight, output_shape)