Pārlūkot izejas kodu

clear GenerationParametersList before batch

clears any generation parameters that are with the attribute to_be_clear_before_batch = True
prevent buildup of some parameters
w-e-w 9 mēneši atpakaļ
vecāks
revīzija
ac8c05398b
2 mainītis faili ar 18 papildinājumiem un 3 dzēšanām
  1. 10 3
      modules/processing.py
  2. 8 0
      modules/util.py

+ 10 - 3
modules/processing.py

@@ -457,7 +457,7 @@ class StableDiffusionProcessing:
             opts.emphasis,
         )
 
-    def apply_generation_params_states(self, generation_params_states):
+    def apply_generation_params_list(self, generation_params_states):
         """add and apply generation_params_states to self.extra_generation_params"""
         for key, value in generation_params_states.items():
             if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
@@ -465,6 +465,12 @@ class StableDiffusionProcessing:
             else:
                 self.extra_generation_params[key] = value
 
+    def clear_marked_generation_params(self):
+        """clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
+        for key, value in list(self.extra_generation_params.items()):
+            if getattr(value, 'to_be_clear_before_batch', False):
+                self.extra_generation_params.pop(key)
+
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ class StableDiffusionProcessing:
                 if len(cache) == 3:
                     generation_params_states, cached_cached_params = cache[2]
                     if cached_params == cached_cached_params:
-                        self.apply_generation_params_states(generation_params_states)
+                        self.apply_generation_params_list(generation_params_states)
                 return cache[1]
 
         cache = caches[0]
@@ -500,7 +506,7 @@ class StableDiffusionProcessing:
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
 
         generation_params_states = model_hijack.extract_generation_params_states()
-        self.apply_generation_params_states(generation_params_states)
+        self.apply_generation_params_list(generation_params_states)
         if len(cache) == 2:
             cache.append((generation_params_states, cached_params))
         else:
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted or state.stopping_generation:
                 break
 
+            p.clear_marked_generation_params()  # clean up some generation params are tagged to be cleared before batch
             sd_models.reload_model_weights()  # model can be changed for example by refiner
 
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

+ 8 - 0
modules/util.py

@@ -308,9 +308,17 @@ class GenerationParametersList(list):
     if return str, the value will be written to infotext, if return None will be ignored.
     """
 
+    def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._to_be_clear_before_batch = to_be_clear_before_batch
+
     def __call__(self, *args, **kwargs):
         return ', '.join(sorted(set(self), key=natural_sort_key))
 
+    @property
+    def to_be_clear_before_batch(self):
+        return self._to_be_clear_before_batch
+
     def __add__(self, other):
         if isinstance(other, GenerationParametersList):
             return self.__class__([*self, *other])