123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- import torch
- from modules import prompt_parser, devices, sd_samplers_common
- from modules.shared import opts, state
- import modules.shared as shared
- from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
- from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
- from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
def catenate_conds(conds):
    """
    Concatenate a list of conditionings along dim 0.

    Entries are either all tensors, or dicts of tensors. For dicts, the result is a new dict
    holding, for every key of the first entry, the concatenation of that key across entries.
    """
    first = conds[0]
    if isinstance(first, dict):
        return {key: torch.cat([cond[key] for cond in conds]) for key in first}
    return torch.cat(conds)
def subscript_cond(cond, a, b):
    """
    Return rows a:b (along dim 0) of a conditioning.

    A dict conditioning is sliced key by key and returned as a new plain dict.
    """
    window = slice(a, b)
    if isinstance(cond, dict):
        return {key: vec[window] for key, vec in cond.items()}
    return cond[window]
def pad_cond(tensor, repeats, empty):
    """
    Append copies of the ``empty`` embedding to ``tensor`` along dim 1.

    Args:
        tensor: conditioning tensor, or a dict whose 'crossattn' entry is the tensor to pad.
            A dict is updated in place and the same dict object is returned.
        repeats: number of times ``empty`` is repeated along dim 1.
        empty: embedding used as padding; it is repeated with sizes
            (tensor.shape[0], repeats, 1) — presumably shaped (1, tokens, dim), TODO confirm.

    Returns:
        The padded tensor, or the (mutated) input dict.
    """
    if isinstance(tensor, dict):
        # only the cross-attention entry is padded; the dict itself is reused
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.

    Subclasses must provide ``inner_model``; the base property raises NotImplementedError.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None
        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # incremented at the end of every forward() call; passed to prompt_parser to resolve
        # the cond/uncond for the current step
        self.step = 0
        self.image_cfg_scale = None
        # set by forward() when the respective padding method changed cond/uncond on this step
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """The wrapped model called by forward(); must be supplied by a subclass."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """
        Apply classifier-free guidance to the model outputs.

        Args:
            x_out: model outputs; the last ``uncond.shape[0]`` rows are the uncond predictions,
                the rows before them are the cond predictions.
            conds_list: for each image i, a list of (row index into x_out, weight) pairs.
            uncond: uncond conditioning; only ``uncond.shape[0]`` is used.
            cond_scale: guidance scale multiplied into every weight.

        Returns:
            A new tensor with one row per uncond row:
            ``uncond_i + sum(weight * cond_scale * (cond - uncond_i))``.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """
        Guidance for edit (InstructPix2Pix-style) models.

        ``x_out`` is split into three equal chunks along dim 0: (cond, image-only cond, uncond).
        Text guidance uses ``cond_scale`` and image guidance uses ``self.image_cfg_scale``.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return the predicted clean latent; this base implementation returns ``x_out`` unchanged."""
        return x_out

    def update_inner_model(self):
        """
        Reset ``self.model_wrap`` and store freshly computed conds from ``self.p.get_conds()``
        in the sampler's extra args (read back by forward() after apply_refiner()).
        """
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """
        Equalize cond/uncond length along dim 1 by appending copies of the model's empty-prompt
        embedding to the shorter one (via pad_cond).

        The number of copies is the length difference floor-divided by the empty embedding's
        length along dim 1. Sets ``self.padded_cond_uncond`` when padding was applied.

        Returns:
            tuple: (cond, uncond), possibly padded.
        """
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads or truncates 'uncond' along dim 1 so that its length matches 'cond'.

        If 'uncond' is a dictionary, its 'crossattn' entry is the tensor adjusted (in place);
        otherwise 'uncond' itself is adjusted.
        If 'uncond' is shorter than 'cond' along dim 1, its last vector is repeated to make up
        the difference. If it is longer, it is truncated to the length of 'cond'.
        Sets ``self.padded_cond_uncond_v0`` when a change was made.

        Args:
            cond (torch.Tensor or DictWithShape): The conditioning whose dim-1 length 'uncond' is matched to.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the unchanged 'cond' and the adjusted 'uncond'.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """
        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """
        Run one classifier-free-guidance denoising step.

        Args:
            x: noisy latents, indexed per image (x[i]).
            sigma: noise levels, indexed per image (sigma[i]).
            uncond: negative-prompt conditioning, resolved for this step with
                prompt_parser.reconstruct_cond_batch.
            cond: positive (possibly AND-composed) conditioning, resolved for this step with
                prompt_parser.reconstruct_multicond_batch.
            cond_scale: guidance scale applied to (cond - uncond).
            s_min_uncond: when > 0, on odd steps where sigma[0] < s_min_uncond (non-edit models
                only) the uncond pass is skipped and cond predictions stand in for it.
            image_cond: image conditioning passed to the model as c_concat, or as c_adm for
                "crossattn-adm" models.

        Returns:
            The guided latents (after optional mask blending and cfg_after_cfg_callback).

        Raises:
            sd_samplers_common.InterruptedException: if the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if sd_samplers_common.apply_refiner(self, sigma):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        batch_size = len(conds_list)
        # number of (AND-composed) sub-prompts for each image
        repeats = [len(conds) for conds in conds_list]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        def repeat_per_cond(t):
            # one stacked block per image i holding a copy of t[i] for each of its sub-prompts
            return [torch.stack([t[i] for _ in range(n)]) for i, n in enumerate(repeats)]

        # Model input layout: all cond rows first, then one uncond chunk of batch_size rows.
        # Edit models get a second uncond chunk whose image conditioning is zeroed.
        if not is_edit_model:
            x_in = torch.cat(repeat_per_cond(x) + [x])
            sigma_in = torch.cat(repeat_per_cond(sigma) + [sigma])
            image_cond_in = torch.cat(repeat_per_cond(image_cond) + [image_uncond])
        else:
            x_in = torch.cat(repeat_per_cond(x) + [x, x])
            sigma_in = torch.cat(repeat_per_cond(sigma) + [sigma, sigma])
            image_cond_in = torch.cat(repeat_per_cond(image_cond) + [image_uncond, torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]
            # drop the uncond rows here too so the image conditioning stays row-aligned with x_in
            image_cond_in = image_cond_in[:-batch_size]

        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
        elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond(tensor, uncond)

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            if shared.opts.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # cond and uncond differ in length along dim 1, so they cannot share one model call:
            # run the cond rows in chunks, then the uncond rows separately.
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size * 2 if shared.opts.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])
                c_crossattn = subscript_cond(tensor, a, b)
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

            if not skip_uncond:
                # edit models have two uncond chunks at the end of x_in; both use the uncond text,
                # matching the [tensor, uncond, uncond] layout of the equal-length path
                num_uncond_rows = uncond.shape[0] * (2 if is_edit_model else 1)
                uncond_in = catenate_conds([uncond, uncond]) if is_edit_model else uncond
                x_out[-num_uncond_rows:] = self.inner_model(x_in[-num_uncond_rows:], sigma_in[-num_uncond_rows:], cond=make_condition_dict(uncond_in, image_cond_in[-num_uncond_rows:]))

        # row of the first sub-prompt prediction for each image
        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised
|