# sd_disable_initialization.py

import ldm.modules.encoders.modules
import open_clip
import torch
import transformers.utils.hub

from modules import shared


class ReplaceHelper:
    def __init__(self):
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original

    def restore(self):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)

        self.replaced.clear()
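

# The sketch below is not part of the original module; it just illustrates the
# replace/restore pattern that every context manager in this file builds on.
# The `_replace_helper_example` name is a hypothetical helper added for illustration.
def _replace_helper_example():
    helper = ReplaceHelper()

    # Temporarily turn torch's kaiming init into a no-op; the original function is
    # remembered so it can be put back later.
    original = helper.replace(torch.nn.init, 'kaiming_uniform_', lambda *args, **kwargs: None)
    assert original is not None

    try:
        torch.nn.init.kaiming_uniform_(torch.empty(2, 2))  # does nothing while patched
    finally:
        helper.restore()  # torch.nn.init.kaiming_uniform_ is the real function again
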
class DisableInitialization(ReplaceHelper):
    """
    When an object of this class enters a `with` block, it starts:
    - preventing torch's layer initialization functions from doing anything
    - changing CLIP and OpenCLIP so that they do not download model weights
    - changing CLIP so that it does not make requests to check whether there is a newer version of a file you already have

    When it leaves the block, it reverts everything to how it was before.

    Use it like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __init__(self, disable_clip=True):
        super().__init__()
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None

        self.replaced.append((obj, field, original))
        setattr(obj, field, func)

        return original
    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
            res.name_or_path = pretrained_model_name_or_path
            return res

        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug
            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)

        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):

            # this file is always 404, prevent making request
            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':
                return None

            try:
                res = original(url, *args, local_files_only=True, **kwargs)
                if res is None:
                    res = original(url, *args, local_files_only=False, **kwargs)
                return res
            except Exception:
                return original(url, *args, local_files_only=False, **kwargs)

        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)

        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)

        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):
            return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)

        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
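

# A minimal sketch (not from the original file) of what DisableInitialization buys you:
# inside the block, layers are still constructed, but the random-init kernels patched
# above are no-ops, so construction is cheaper. That is safe in this codebase because
# checkpoint weights are loaded over the freshly built layers right afterwards.
# `disable_clip=False` keeps the sketch self-contained by leaving the CLIP paths alone.
def _disable_initialization_example():
    with DisableInitialization(disable_clip=False):
        layer = torch.nn.Linear(1024, 1024)  # weight/bias left uninitialized on purpose
    return layer
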
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on the meta device,
    which results in those parameters having no values and taking no memory. model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta below when loading params from the state dict.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def set_device(x):
            x["device"] = "meta"
            return x

        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
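

# A minimal sketch of the effect (not part of the original file; it assumes the webui
# environment is available, since __enter__ consults shared.cmd_opts and returns early
# when the RAM optimization is disabled there).
def _initialize_on_meta_example():
    with InitializeOnMeta():
        layer = torch.nn.Linear(4096, 4096)

    # The weight exists structurally but holds no data and takes no memory; it can only
    # be given real values later by LoadStateDictOnMeta below, since a plain model.to()
    # cannot copy anything out of meta tensors.
    assert layer.weight.is_meta
    return layer
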
class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows reading parameters from a state_dict into a model that has some of its parameters on the meta device.
    As those parameters are read from the state_dict, they are deleted from it, so by the end the state_dict is mostly empty, to save memory.

    Meant to be used together with InitializeOnMeta above.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device, weight_dtype_conversion=None):
        super().__init__()
        self.state_dict = state_dict
        self.device = device
        self.weight_dtype_conversion = weight_dtype_conversion or {}
        self.default_dtype = self.weight_dtype_conversion.get('')

    def get_weight_dtype(self, key):
        key_first_term, _ = key.split('.', 1)
        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
            used_param_keys = []

            for name, param in module._parameters.items():
                if param is None:
                    continue

                key = prefix + name
                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                    used_param_keys.append(key)

                if param.is_meta:
                    dtype = sd_param.dtype if sd_param is not None else param.dtype
                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

            for name in module._buffers:
                key = prefix + name

                sd_param = sd.pop(key, None)
                if sd_param is not None:
                    state_dict[key] = sd_param
                    used_param_keys.append(key)

            original(module, state_dict, prefix, *args, **kwargs)

            for key in used_param_keys:
                state_dict.pop(key, None)

        def load_state_dict(original, module, state_dict, strict=True):
            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
            all weights on the meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

            The dangerous thing about this is that if _load_from_state_dict is not called (if some exotic module overrides
            the function and does not call the original), the state dict will simply fail to load because the weights
            would be on the meta device.
            """

            if state_dict == sd:
                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

            original(module, state_dict, strict=strict)

        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
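

# A hedged end-to-end sketch combining the two meta-device helpers, modelled on the
# usage shown in the docstrings above. `instantiate_from_config`, `sd_config` and
# `state_dict` are placeholders passed in by the caller, and the dtype mapping is an
# assumed example, not a definition from this file.
def _meta_loading_example(instantiate_from_config, sd_config, state_dict):
    with InitializeOnMeta():
        model = instantiate_from_config(sd_config.model)  # parameters land on "meta": no RAM used yet

    # Keys are routed per top-level submodule name; '' is the fallback dtype.
    dtype_map = {'first_stage_model': torch.float32, '': torch.float16}
    with LoadStateDictOnMeta(state_dict, device=torch.device("cpu"), weight_dtype_conversion=dtype_map):
        model.load_state_dict(state_dict, strict=False)  # real tensors are created directly from the checkpoint

    return model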