sd_samplers_cfg_denoiser.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import torch
  2. from modules import prompt_parser, devices, sd_samplers_common
  3. from modules.shared import opts, state
  4. import modules.shared as shared
  5. from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
  6. from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
  7. from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
  8. def catenate_conds(conds):
  9. if not isinstance(conds[0], dict):
  10. return torch.cat(conds)
  11. return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
  12. def subscript_cond(cond, a, b):
  13. if not isinstance(cond, dict):
  14. return cond[a:b]
  15. return {key: vec[a:b] for key, vec in cond.items()}
  16. def pad_cond(tensor, repeats, empty):
  17. if not isinstance(tensor, dict):
  18. return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
  19. tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
  20. return tensor
  21. class CFGDenoiser(torch.nn.Module):
  22. """
  23. Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
  24. that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
  25. instead of one. Originally, the second prompt is just an empty string, but we use non-empty
  26. negative prompt.
  27. """
  28. def __init__(self, sampler):
  29. super().__init__()
  30. self.model_wrap = None
  31. self.mask = None
  32. self.nmask = None
  33. self.init_latent = None
  34. self.steps = None
  35. """number of steps as specified by user in UI"""
  36. self.total_steps = None
  37. """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
  38. self.step = 0
  39. self.image_cfg_scale = None
  40. self.padded_cond_uncond = False
  41. self.padded_cond_uncond_v0 = False
  42. self.sampler = sampler
  43. self.model_wrap = None
  44. self.p = None
  45. # NOTE: masking before denoising can cause the original latents to be oversmoothed
  46. # as the original latents do not have noise
  47. self.mask_before_denoising = False
  48. @property
  49. def inner_model(self):
  50. raise NotImplementedError()
  51. def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
  52. denoised_uncond = x_out[-uncond.shape[0]:]
  53. denoised = torch.clone(denoised_uncond)
  54. for i, conds in enumerate(conds_list):
  55. for cond_index, weight in conds:
  56. denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
  57. return denoised
  58. def combine_denoised_for_edit_model(self, x_out, cond_scale):
  59. out_cond, out_img_cond, out_uncond = x_out.chunk(3)
  60. denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
  61. return denoised
  62. def get_pred_x0(self, x_in, x_out, sigma):
  63. return x_out
  64. def update_inner_model(self):
  65. self.model_wrap = None
  66. c, uc = self.p.get_conds()
  67. self.sampler.sampler_extra_args['cond'] = c
  68. self.sampler.sampler_extra_args['uncond'] = uc
  69. def pad_cond_uncond(self, cond, uncond):
  70. empty = shared.sd_model.cond_stage_model_empty_prompt
  71. num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
  72. if num_repeats < 0:
  73. cond = pad_cond(cond, -num_repeats, empty)
  74. self.padded_cond_uncond = True
  75. elif num_repeats > 0:
  76. uncond = pad_cond(uncond, num_repeats, empty)
  77. self.padded_cond_uncond = True
  78. return cond, uncond
  79. def pad_cond_uncond_v0(self, cond, uncond):
  80. """
  81. Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
  82. If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
  83. If 'uncond' is a tensor, it is padded directly.
  84. If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
  85. is repeated to match the number of columns in 'cond'.
  86. If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
  87. to match the number of columns in 'cond'.
  88. Args:
  89. cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
  90. uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
  91. Returns:
  92. tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
  93. Note:
  94. This is the padding that was always used in DDIM before version 1.6.0
  95. """
  96. is_dict_cond = isinstance(uncond, dict)
  97. uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
  98. if uncond_vec.shape[1] < cond.shape[1]:
  99. last_vector = uncond_vec[:, -1:]
  100. last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
  101. uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
  102. self.padded_cond_uncond_v0 = True
  103. elif uncond_vec.shape[1] > cond.shape[1]:
  104. uncond_vec = uncond_vec[:, :cond.shape[1]]
  105. self.padded_cond_uncond_v0 = True
  106. if is_dict_cond:
  107. uncond['crossattn'] = uncond_vec
  108. else:
  109. uncond = uncond_vec
  110. return cond, uncond
  111. def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
  112. if state.interrupted or state.skipped:
  113. raise sd_samplers_common.InterruptedException
  114. if sd_samplers_common.apply_refiner(self, sigma):
  115. cond = self.sampler.sampler_extra_args['cond']
  116. uncond = self.sampler.sampler_extra_args['uncond']
  117. # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
  118. # so is_edit_model is set to False to support AND composition.
  119. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
  120. conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
  121. uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
  122. assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
  123. # If we use masks, blending between the denoised and original latent images occurs here.
  124. def apply_blend(current_latent):
  125. blended_latent = current_latent * self.nmask + self.init_latent * self.mask
  126. if self.p.scripts is not None:
  127. from modules import scripts
  128. mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
  129. self.p.scripts.on_mask_blend(self.p, mba)
  130. blended_latent = mba.blended_latent
  131. return blended_latent
  132. # Blend in the original latents (before)
  133. if self.mask_before_denoising and self.mask is not None:
  134. x = apply_blend(x)
  135. batch_size = len(conds_list)
  136. repeats = [len(conds_list[i]) for i in range(batch_size)]
  137. if shared.sd_model.model.conditioning_key == "crossattn-adm":
  138. image_uncond = torch.zeros_like(image_cond)
  139. make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
  140. else:
  141. image_uncond = image_cond
  142. if isinstance(uncond, dict):
  143. make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
  144. else:
  145. make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
  146. if not is_edit_model:
  147. x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
  148. sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
  149. image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
  150. else:
  151. x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
  152. sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
  153. image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
  154. denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
  155. cfg_denoiser_callback(denoiser_params)
  156. x_in = denoiser_params.x
  157. image_cond_in = denoiser_params.image_cond
  158. sigma_in = denoiser_params.sigma
  159. tensor = denoiser_params.text_cond
  160. uncond = denoiser_params.text_uncond
  161. skip_uncond = False
  162. # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
  163. if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
  164. skip_uncond = True
  165. x_in = x_in[:-batch_size]
  166. sigma_in = sigma_in[:-batch_size]
  167. self.padded_cond_uncond = False
  168. self.padded_cond_uncond_v0 = False
  169. if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
  170. tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
  171. elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
  172. tensor, uncond = self.pad_cond_uncond(tensor, uncond)
  173. if tensor.shape[1] == uncond.shape[1] or skip_uncond:
  174. if is_edit_model:
  175. cond_in = catenate_conds([tensor, uncond, uncond])
  176. elif skip_uncond:
  177. cond_in = tensor
  178. else:
  179. cond_in = catenate_conds([tensor, uncond])
  180. if shared.opts.batch_cond_uncond:
  181. x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
  182. else:
  183. x_out = torch.zeros_like(x_in)
  184. for batch_offset in range(0, x_out.shape[0], batch_size):
  185. a = batch_offset
  186. b = a + batch_size
  187. x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
  188. else:
  189. x_out = torch.zeros_like(x_in)
  190. batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
  191. for batch_offset in range(0, tensor.shape[0], batch_size):
  192. a = batch_offset
  193. b = min(a + batch_size, tensor.shape[0])
  194. if not is_edit_model:
  195. c_crossattn = subscript_cond(tensor, a, b)
  196. else:
  197. c_crossattn = torch.cat([tensor[a:b]], uncond)
  198. x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
  199. if not skip_uncond:
  200. x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
  201. denoised_image_indexes = [x[0][0] for x in conds_list]
  202. if skip_uncond:
  203. fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
  204. x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
  205. denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
  206. cfg_denoised_callback(denoised_params)
  207. devices.test_for_nans(x_out, "unet")
  208. if is_edit_model:
  209. denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
  210. elif skip_uncond:
  211. denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
  212. else:
  213. denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
  214. # Blend in the original latents (after)
  215. if not self.mask_before_denoising and self.mask is not None:
  216. denoised = apply_blend(denoised)
  217. self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
  218. if opts.live_preview_content == "Prompt":
  219. preview = self.sampler.last_latent
  220. elif opts.live_preview_content == "Negative prompt":
  221. preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
  222. else:
  223. preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
  224. sd_samplers_common.store_latent(preview)
  225. after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
  226. cfg_after_cfg_callback(after_cfg_callback_params)
  227. denoised = after_cfg_callback_params.x
  228. self.step += 1
  229. return denoised