|
@@ -53,6 +53,7 @@ class CFGDenoiser(torch.nn.Module):
|
|
self.step = 0
|
|
self.step = 0
|
|
self.image_cfg_scale = None
|
|
self.image_cfg_scale = None
|
|
self.padded_cond_uncond = False
|
|
self.padded_cond_uncond = False
|
|
|
|
+ self.padded_cond_uncond_v0 = False
|
|
self.sampler = sampler
|
|
self.sampler = sampler
|
|
self.model_wrap = None
|
|
self.model_wrap = None
|
|
self.p = None
|
|
self.p = None
|
|
@@ -91,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
|
|
self.sampler.sampler_extra_args['cond'] = c
|
|
self.sampler.sampler_extra_args['cond'] = c
|
|
self.sampler.sampler_extra_args['uncond'] = uc
|
|
self.sampler.sampler_extra_args['uncond'] = uc
|
|
|
|
|
|
|
|
+ def pad_cond_uncond(self, cond, uncond):
|
|
|
|
+ empty = shared.sd_model.cond_stage_model_empty_prompt
|
|
|
|
+ num_repeats = (cond.shape[1] - cond.shape[1]) // empty.shape[1]
|
|
|
|
+
|
|
|
|
+ if num_repeats < 0:
|
|
|
|
+ cond = pad_cond(cond, -num_repeats, empty)
|
|
|
|
+ self.padded_cond_uncond = True
|
|
|
|
+ elif num_repeats > 0:
|
|
|
|
+ uncond = pad_cond(uncond, num_repeats, empty)
|
|
|
|
+ self.padded_cond_uncond = True
|
|
|
|
+
|
|
|
|
+ return cond, uncond
|
|
|
|
+
|
|
|
|
+ def pad_cond_uncond_v0(self, cond, uncond):
|
|
|
|
+ """
|
|
|
|
+ Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
|
|
|
|
+
|
|
|
|
+ If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
|
|
|
|
+ If 'uncond' is a tensor, it is padded directly.
|
|
|
|
+
|
|
|
|
+ If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
|
|
|
|
+ is repeated to match the number of columns in 'cond'.
|
|
|
|
+
|
|
|
|
+ If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
|
|
|
|
+ to match the number of columns in 'cond'.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
|
|
|
|
+ uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
|
|
|
|
+
|
|
|
|
+ Note:
|
|
|
|
+ This is the padding that was always used in DDIM before version 1.6.0
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ is_dict_cond = isinstance(uncond, dict)
|
|
|
|
+ uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
|
|
|
|
+
|
|
|
|
+ if uncond_vec.shape[1] < cond.shape[1]:
|
|
|
|
+ last_vector = uncond_vec[:, -1:]
|
|
|
|
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
|
|
|
|
+ uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
|
|
|
|
+ self.padded_cond_uncond_v0 = True
|
|
|
|
+ elif uncond_vec.shape[1] > cond.shape[1]:
|
|
|
|
+ uncond_vec = uncond_vec[:, :cond.shape[1]]
|
|
|
|
+ self.padded_cond_uncond_v0 = True
|
|
|
|
+
|
|
|
|
+ if is_dict_cond:
|
|
|
|
+ uncond['crossattn'] = uncond_vec
|
|
|
|
+ else:
|
|
|
|
+ uncond = uncond_vec
|
|
|
|
+
|
|
|
|
+ return cond, uncond
|
|
|
|
+
|
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
|
if state.interrupted or state.skipped:
|
|
if state.interrupted or state.skipped:
|
|
raise sd_samplers_common.InterruptedException
|
|
raise sd_samplers_common.InterruptedException
|
|
@@ -162,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
|
|
sigma_in = sigma_in[:-batch_size]
|
|
sigma_in = sigma_in[:-batch_size]
|
|
|
|
|
|
self.padded_cond_uncond = False
|
|
self.padded_cond_uncond = False
|
|
|
|
+ self.padded_cond_uncond_v0 = False
|
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
|
- empty = shared.sd_model.cond_stage_model_empty_prompt
|
|
|
|
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
|
|
|
-
|
|
|
|
- if num_repeats < 0:
|
|
|
|
- tensor = pad_cond(tensor, -num_repeats, empty)
|
|
|
|
- self.padded_cond_uncond = True
|
|
|
|
- elif num_repeats > 0:
|
|
|
|
- uncond = pad_cond(uncond, num_repeats, empty)
|
|
|
|
- self.padded_cond_uncond = True
|
|
|
|
|
|
+ tensor, uncond = self.pad_cond_uncond(tensor, uncond)
|
|
|
|
+ elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
|
|
|
|
+ tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
|
|
|
|
|
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
|
if is_edit_model:
|
|
if is_edit_model:
|