safe.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # this code is adapted from the script contributed by anon from /h/
  2. import pickle
  3. import collections
  4. import sys
  5. import traceback
  6. import torch
  7. import numpy
  8. import _codecs
  9. import zipfile
  10. import re
  11. # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
  12. TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
  13. def encode(*args):
  14. out = _codecs.encode(*args)
  15. return out
  16. class RestrictedUnpickler(pickle.Unpickler):
  17. extra_handler = None
  18. def persistent_load(self, saved_id):
  19. assert saved_id[0] == 'storage'
  20. try:
  21. return TypedStorage(_internal=True)
  22. except TypeError:
  23. return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
  24. def find_class(self, module, name):
  25. if self.extra_handler is not None:
  26. res = self.extra_handler(module, name)
  27. if res is not None:
  28. return res
  29. if module == 'collections' and name == 'OrderedDict':
  30. return getattr(collections, name)
  31. if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
  32. return getattr(torch._utils, name)
  33. if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
  34. return getattr(torch, name)
  35. if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
  36. return getattr(torch.nn.modules.container, name)
  37. if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
  38. return getattr(numpy.core.multiarray, name)
  39. if module == 'numpy' and name in ['dtype', 'ndarray']:
  40. return getattr(numpy, name)
  41. if module == '_codecs' and name == 'encode':
  42. return encode
  43. if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
  44. import pytorch_lightning.callbacks
  45. return pytorch_lightning.callbacks.model_checkpoint
  46. if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
  47. import pytorch_lightning.callbacks.model_checkpoint
  48. return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
  49. if module == "__builtin__" and name == 'set':
  50. return set
  51. # Forbid everything else.
  52. raise Exception(f"global '{module}/{name}' is forbidden")
  53. # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
  54. allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
  55. data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
  56. def check_zip_filenames(filename, names):
  57. for name in names:
  58. if allowed_zip_names_re.match(name):
  59. continue
  60. raise Exception(f"bad file inside {filename}: {name}")
  61. def check_pt(filename, extra_handler):
  62. try:
  63. # new pytorch format is a zip file
  64. with zipfile.ZipFile(filename) as z:
  65. check_zip_filenames(filename, z.namelist())
  66. # find filename of data.pkl in zip file: '<directory name>/data.pkl'
  67. data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
  68. if len(data_pkl_filenames) == 0:
  69. raise Exception(f"data.pkl not found in {filename}")
  70. if len(data_pkl_filenames) > 1:
  71. raise Exception(f"Multiple data.pkl found in {filename}")
  72. with z.open(data_pkl_filenames[0]) as file:
  73. unpickler = RestrictedUnpickler(file)
  74. unpickler.extra_handler = extra_handler
  75. unpickler.load()
  76. except zipfile.BadZipfile:
  77. # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
  78. with open(filename, "rb") as file:
  79. unpickler = RestrictedUnpickler(file)
  80. unpickler.extra_handler = extra_handler
  81. for _ in range(5):
  82. unpickler.load()
  83. def load(filename, *args, **kwargs):
  84. return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
  85. def load_with_extra(filename, extra_handler=None, *args, **kwargs):
  86. """
  87. this function is intended to be used by extensions that want to load models with
  88. some extra classes in them that the usual unpickler would find suspicious.
  89. Use the extra_handler argument to specify a function that takes module and field name as text,
  90. and returns that field's value:
  91. ```python
  92. def extra(module, name):
  93. if module == 'collections' and name == 'OrderedDict':
  94. return collections.OrderedDict
  95. return None
  96. safe.load_with_extra('model.pt', extra_handler=extra)
  97. ```
  98. The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
  99. definitely unsafe.
  100. """
  101. from modules import shared
  102. try:
  103. if not shared.cmd_opts.disable_safe_unpickle:
  104. check_pt(filename, extra_handler)
  105. except pickle.UnpicklingError:
  106. print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
  107. print(traceback.format_exc(), file=sys.stderr)
  108. print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
  109. print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
  110. return None
  111. except Exception:
  112. print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
  113. print(traceback.format_exc(), file=sys.stderr)
  114. print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
  115. print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
  116. return None
  117. return unsafe_torch_load(filename, *args, **kwargs)
  118. class Extra:
  119. """
  120. A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
  121. (because it's not your code making the torch.load call). The intended use is like this:
  122. ```
  123. import torch
  124. from modules import safe
  125. def handler(module, name):
  126. if module == 'torch' and name in ['float64', 'float16']:
  127. return getattr(torch, name)
  128. return None
  129. with safe.Extra(handler):
  130. x = torch.load('model.pt')
  131. ```
  132. """
  133. def __init__(self, handler):
  134. self.handler = handler
  135. def __enter__(self):
  136. global global_extra_handler
  137. assert global_extra_handler is None, 'already inside an Extra() block'
  138. global_extra_handler = self.handler
  139. def __exit__(self, exc_type, exc_val, exc_tb):
  140. global global_extra_handler
  141. global_extra_handler = None
  142. unsafe_torch_load = torch.load
  143. torch.load = load
  144. global_extra_handler = None