sd_samplers_cfg_denoiser.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import torch
  2. from modules import prompt_parser, sd_samplers_common
  3. from modules.shared import opts, state
  4. import modules.shared as shared
  5. from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
  6. from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
  7. from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
  8. def catenate_conds(conds):
  9. if not isinstance(conds[0], dict):
  10. return torch.cat(conds)
  11. return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
  12. def subscript_cond(cond, a, b):
  13. if not isinstance(cond, dict):
  14. return cond[a:b]
  15. return {key: vec[a:b] for key, vec in cond.items()}
  16. def pad_cond(tensor, repeats, empty):
  17. if not isinstance(tensor, dict):
  18. return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
  19. tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
  20. return tensor
  21. class CFGDenoiser(torch.nn.Module):
  22. """
  23. Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
  24. that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
  25. instead of one. Originally, the second prompt is just an empty string, but we use non-empty
  26. negative prompt.
  27. """
  28. def __init__(self, sampler):
  29. super().__init__()
  30. self.model_wrap = None
  31. self.mask = None
  32. self.nmask = None
  33. self.init_latent = None
  34. self.steps = None
  35. """number of steps as specified by user in UI"""
  36. self.total_steps = None
  37. """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
  38. self.step = 0
  39. self.image_cfg_scale = None
  40. self.padded_cond_uncond = False
  41. self.padded_cond_uncond_v0 = False
  42. self.sampler = sampler
  43. self.model_wrap = None
  44. self.p = None
  45. self.cond_scale_miltiplier = 1.0
  46. self.need_last_noise_uncond = False
  47. self.last_noise_uncond = None
  48. # NOTE: masking before denoising can cause the original latents to be oversmoothed
  49. # as the original latents do not have noise
  50. self.mask_before_denoising = False
  51. @property
  52. def inner_model(self):
  53. raise NotImplementedError()
  54. def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
  55. denoised_uncond = x_out[-uncond.shape[0]:]
  56. denoised = torch.clone(denoised_uncond)
  57. for i, conds in enumerate(conds_list):
  58. for cond_index, weight in conds:
  59. denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
  60. return denoised
  61. def combine_denoised_for_edit_model(self, x_out, cond_scale):
  62. out_cond, out_img_cond, out_uncond = x_out.chunk(3)
  63. denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
  64. return denoised
  65. def get_pred_x0(self, x_in, x_out, sigma):
  66. return x_out
  67. def update_inner_model(self):
  68. self.model_wrap = None
  69. c, uc = self.p.get_conds()
  70. self.sampler.sampler_extra_args['cond'] = c
  71. self.sampler.sampler_extra_args['uncond'] = uc
  72. def pad_cond_uncond(self, cond, uncond):
  73. empty = shared.sd_model.cond_stage_model_empty_prompt
  74. num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
  75. if num_repeats < 0:
  76. cond = pad_cond(cond, -num_repeats, empty)
  77. self.padded_cond_uncond = True
  78. elif num_repeats > 0:
  79. uncond = pad_cond(uncond, num_repeats, empty)
  80. self.padded_cond_uncond = True
  81. return cond, uncond
  82. def pad_cond_uncond_v0(self, cond, uncond):
  83. """
  84. Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
  85. If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
  86. If 'uncond' is a tensor, it is padded directly.
  87. If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
  88. is repeated to match the number of columns in 'cond'.
  89. If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
  90. to match the number of columns in 'cond'.
  91. Args:
  92. cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
  93. uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
  94. Returns:
  95. tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
  96. Note:
  97. This is the padding that was always used in DDIM before version 1.6.0
  98. """
  99. is_dict_cond = isinstance(uncond, dict)
  100. uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
  101. if uncond_vec.shape[1] < cond.shape[1]:
  102. last_vector = uncond_vec[:, -1:]
  103. last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
  104. uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
  105. self.padded_cond_uncond_v0 = True
  106. elif uncond_vec.shape[1] > cond.shape[1]:
  107. uncond_vec = uncond_vec[:, :cond.shape[1]]
  108. self.padded_cond_uncond_v0 = True
  109. if is_dict_cond:
  110. uncond['crossattn'] = uncond_vec
  111. else:
  112. uncond = uncond_vec
  113. return cond, uncond
  114. def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
  115. if state.interrupted or state.skipped:
  116. raise sd_samplers_common.InterruptedException
  117. if sd_samplers_common.apply_refiner(self, sigma):
  118. cond = self.sampler.sampler_extra_args['cond']
  119. uncond = self.sampler.sampler_extra_args['uncond']
  120. # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
  121. # so is_edit_model is set to False to support AND composition.
  122. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
  123. conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
  124. uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
  125. assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
  126. # If we use masks, blending between the denoised and original latent images occurs here.
  127. def apply_blend(current_latent):
  128. blended_latent = current_latent * self.nmask + self.init_latent * self.mask
  129. if self.p.scripts is not None:
  130. from modules import scripts
  131. mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
  132. self.p.scripts.on_mask_blend(self.p, mba)
  133. blended_latent = mba.blended_latent
  134. return blended_latent
  135. # Blend in the original latents (before)
  136. if self.mask_before_denoising and self.mask is not None:
  137. x = apply_blend(x)
  138. batch_size = len(conds_list)
  139. repeats = [len(conds_list[i]) for i in range(batch_size)]
  140. if shared.sd_model.model.conditioning_key == "crossattn-adm":
  141. image_uncond = torch.zeros_like(image_cond)
  142. make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
  143. else:
  144. image_uncond = image_cond
  145. if isinstance(uncond, dict):
  146. make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
  147. else:
  148. make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
  149. if not is_edit_model:
  150. x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
  151. sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
  152. image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
  153. else:
  154. x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
  155. sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
  156. image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
  157. denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
  158. cfg_denoiser_callback(denoiser_params)
  159. x_in = denoiser_params.x
  160. image_cond_in = denoiser_params.image_cond
  161. sigma_in = denoiser_params.sigma
  162. tensor = denoiser_params.text_cond
  163. uncond = denoiser_params.text_uncond
  164. skip_uncond = False
  165. if shared.opts.skip_early_cond != 0. and self.step / self.total_steps <= shared.opts.skip_early_cond:
  166. skip_uncond = True
  167. self.p.extra_generation_params["Skip Early CFG"] = shared.opts.skip_early_cond
  168. elif (self.step % 2 or shared.opts.s_min_uncond_all) and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
  169. skip_uncond = True
  170. self.p.extra_generation_params["NGMS"] = s_min_uncond
  171. if shared.opts.s_min_uncond_all:
  172. self.p.extra_generation_params["NGMS all steps"] = shared.opts.s_min_uncond_all
  173. if skip_uncond:
  174. x_in = x_in[:-batch_size]
  175. sigma_in = sigma_in[:-batch_size]
  176. self.padded_cond_uncond = False
  177. self.padded_cond_uncond_v0 = False
  178. if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
  179. tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
  180. elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
  181. tensor, uncond = self.pad_cond_uncond(tensor, uncond)
  182. if tensor.shape[1] == uncond.shape[1] or skip_uncond:
  183. if is_edit_model:
  184. cond_in = catenate_conds([tensor, uncond, uncond])
  185. elif skip_uncond:
  186. cond_in = tensor
  187. else:
  188. cond_in = catenate_conds([tensor, uncond])
  189. if shared.opts.batch_cond_uncond:
  190. x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
  191. else:
  192. x_out = torch.zeros_like(x_in)
  193. for batch_offset in range(0, x_out.shape[0], batch_size):
  194. a = batch_offset
  195. b = a + batch_size
  196. x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
  197. else:
  198. x_out = torch.zeros_like(x_in)
  199. batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
  200. for batch_offset in range(0, tensor.shape[0], batch_size):
  201. a = batch_offset
  202. b = min(a + batch_size, tensor.shape[0])
  203. if not is_edit_model:
  204. c_crossattn = subscript_cond(tensor, a, b)
  205. else:
  206. c_crossattn = torch.cat([tensor[a:b]], uncond)
  207. x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
  208. if not skip_uncond:
  209. x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
  210. denoised_image_indexes = [x[0][0] for x in conds_list]
  211. if skip_uncond:
  212. fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
  213. x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
  214. denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
  215. cfg_denoised_callback(denoised_params)
  216. if self.need_last_noise_uncond:
  217. self.last_noise_uncond = torch.clone(x_out[-uncond.shape[0]:])
  218. if is_edit_model:
  219. denoised = self.combine_denoised_for_edit_model(x_out, cond_scale * self.cond_scale_miltiplier)
  220. elif skip_uncond:
  221. denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
  222. else:
  223. denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale * self.cond_scale_miltiplier)
  224. # Blend in the original latents (after)
  225. if not self.mask_before_denoising and self.mask is not None:
  226. denoised = apply_blend(denoised)
  227. self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
  228. if opts.live_preview_content == "Prompt":
  229. preview = self.sampler.last_latent
  230. elif opts.live_preview_content == "Negative prompt":
  231. preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
  232. else:
  233. preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
  234. sd_samplers_common.store_latent(preview)
  235. after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
  236. cfg_after_cfg_callback(after_cfg_callback_params)
  237. denoised = after_cfg_callback_params.x
  238. self.step += 1
  239. return denoised