network_norm.py 889 B

12345678910111213141516171819202122232425262728
  1. import network
  2. class ModuleTypeNorm(network.ModuleType):
  3. def create_module(self, net: network.Network, weights: network.NetworkWeights):
  4. if all(x in weights.w for x in ["w_norm", "b_norm"]):
  5. return NetworkModuleNorm(net, weights)
  6. return None
  7. class NetworkModuleNorm(network.NetworkModule):
  8. def __init__(self, net: network.Network, weights: network.NetworkWeights):
  9. super().__init__(net, weights)
  10. self.w_norm = weights.w.get("w_norm")
  11. self.b_norm = weights.w.get("b_norm")
  12. def calc_updown(self, orig_weight):
  13. output_shape = self.w_norm.shape
  14. updown = self.w_norm.to(orig_weight.device)
  15. if self.b_norm is not None:
  16. ex_bias = self.b_norm.to(orig_weight.device)
  17. else:
  18. ex_bias = None
  19. return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)