network_full.py 904 B

123456789101112131415161718192021222324252627
  1. import network
  2. class ModuleTypeFull(network.ModuleType):
  3. def create_module(self, net: network.Network, weights: network.NetworkWeights):
  4. if all(x in weights.w for x in ["diff"]):
  5. return NetworkModuleFull(net, weights)
  6. return None
  7. class NetworkModuleFull(network.NetworkModule):
  8. def __init__(self, net: network.Network, weights: network.NetworkWeights):
  9. super().__init__(net, weights)
  10. self.weight = weights.w.get("diff")
  11. self.ex_bias = weights.w.get("diff_b")
  12. def calc_updown(self, orig_weight):
  13. output_shape = self.weight.shape
  14. updown = self.weight.to(orig_weight.device)
  15. if self.ex_bias is not None:
  16. ex_bias = self.ex_bias.to(orig_weight.device)
  17. else:
  18. ex_bias = None
  19. return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)