瀏覽代碼

clear GenerationParametersList before batch

clears any generation parameters that are with the attribute to_be_clear_before_batch = True
prevent buildup of some parameters
w-e-w 9 月之前
父節點
當前提交
ac8c05398b
共有 2 個文件被更改,包括 18 次插入3 次删除
  1. 10 3
      modules/processing.py
  2. 8 0
      modules/util.py

+ 10 - 3
modules/processing.py

@@ -457,7 +457,7 @@ class StableDiffusionProcessing:
             opts.emphasis,
             opts.emphasis,
         )
         )
 
 
-    def apply_generation_params_states(self, generation_params_states):
+    def apply_generation_params_list(self, generation_params_states):
         """add and apply generation_params_states to self.extra_generation_params"""
         """add and apply generation_params_states to self.extra_generation_params"""
         for key, value in generation_params_states.items():
         for key, value in generation_params_states.items():
             if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
             if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
@@ -465,6 +465,12 @@ class StableDiffusionProcessing:
             else:
             else:
                 self.extra_generation_params[key] = value
                 self.extra_generation_params[key] = value
 
 
+    def clear_marked_generation_params(self):
+        """clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
+        for key, value in list(self.extra_generation_params.items()):
+            if getattr(value, 'to_be_clear_before_batch', False):
+                self.extra_generation_params.pop(key)
+
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
         """
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ class StableDiffusionProcessing:
                 if len(cache) == 3:
                 if len(cache) == 3:
                     generation_params_states, cached_cached_params = cache[2]
                     generation_params_states, cached_cached_params = cache[2]
                     if cached_params == cached_cached_params:
                     if cached_params == cached_cached_params:
-                        self.apply_generation_params_states(generation_params_states)
+                        self.apply_generation_params_list(generation_params_states)
                 return cache[1]
                 return cache[1]
 
 
         cache = caches[0]
         cache = caches[0]
@@ -500,7 +506,7 @@ class StableDiffusionProcessing:
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
             cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
 
 
         generation_params_states = model_hijack.extract_generation_params_states()
         generation_params_states = model_hijack.extract_generation_params_states()
-        self.apply_generation_params_states(generation_params_states)
+        self.apply_generation_params_list(generation_params_states)
         if len(cache) == 2:
         if len(cache) == 2:
             cache.append((generation_params_states, cached_params))
             cache.append((generation_params_states, cached_params))
         else:
         else:
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted or state.stopping_generation:
             if state.interrupted or state.stopping_generation:
                 break
                 break
 
 
+            p.clear_marked_generation_params()  # clean up some generation params are tagged to be cleared before batch
             sd_models.reload_model_weights()  # model can be changed for example by refiner
             sd_models.reload_model_weights()  # model can be changed for example by refiner
 
 
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

+ 8 - 0
modules/util.py

@@ -308,9 +308,17 @@ class GenerationParametersList(list):
     if return str, the value will be written to infotext, if return None will be ignored.
     if return str, the value will be written to infotext, if return None will be ignored.
     """
     """
 
 
+    def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._to_be_clear_before_batch = to_be_clear_before_batch
+
     def __call__(self, *args, **kwargs):
     def __call__(self, *args, **kwargs):
         return ', '.join(sorted(set(self), key=natural_sort_key))
         return ', '.join(sorted(set(self), key=natural_sort_key))
 
 
+    @property
+    def to_be_clear_before_batch(self):
+        return self._to_be_clear_before_batch
+
     def __add__(self, other):
     def __add__(self, other):
         if isinstance(other, GenerationParametersList):
         if isinstance(other, GenerationParametersList):
             return self.__class__([*self, *other])
             return self.__class__([*self, *other])