|
@@ -457,7 +457,7 @@ class StableDiffusionProcessing:
|
|
|
opts.emphasis,
|
|
|
)
|
|
|
|
|
|
- def apply_generation_params_states(self, generation_params_states):
|
|
|
+ def apply_generation_params_list(self, generation_params_states):
|
|
|
"""add and apply generation_params_states to self.extra_generation_params"""
|
|
|
for key, value in generation_params_states.items():
|
|
|
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
|
|
@@ -465,6 +465,12 @@ class StableDiffusionProcessing:
|
|
|
else:
|
|
|
self.extra_generation_params[key] = value
|
|
|
|
|
|
+ def clear_marked_generation_params(self):
|
|
|
+ """clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
|
|
|
+ for key, value in list(self.extra_generation_params.items()):
|
|
|
+ if getattr(value, 'to_be_clear_before_batch', False):
|
|
|
+ self.extra_generation_params.pop(key)
|
|
|
+
|
|
|
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
|
|
"""
|
|
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
|
@@ -491,7 +497,7 @@ class StableDiffusionProcessing:
|
|
|
if len(cache) == 3:
|
|
|
generation_params_states, cached_cached_params = cache[2]
|
|
|
if cached_params == cached_cached_params:
|
|
|
- self.apply_generation_params_states(generation_params_states)
|
|
|
+ self.apply_generation_params_list(generation_params_states)
|
|
|
return cache[1]
|
|
|
|
|
|
cache = caches[0]
|
|
@@ -500,7 +506,7 @@ class StableDiffusionProcessing:
|
|
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
|
|
|
|
|
generation_params_states = model_hijack.extract_generation_params_states()
|
|
|
- self.apply_generation_params_states(generation_params_states)
|
|
|
+ self.apply_generation_params_list(generation_params_states)
|
|
|
if len(cache) == 2:
|
|
|
cache.append((generation_params_states, cached_params))
|
|
|
else:
|
|
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|
|
if state.interrupted or state.stopping_generation:
|
|
|
break
|
|
|
|
|
|
+ p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
|
|
|
sd_models.reload_model_weights() # model can be changed for example by refiner
|
|
|
|
|
|
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|