Browse Source

add before_token_counter callback and use it for prompt comments

AUTOMATIC1111 1 year ago
parent
commit
b7f45e67dc
3 changed files with 43 additions and 1 deletions
  1. 11 1
      modules/processing_scripts/comments.py
  2. 26 0
      modules/script_callbacks.py
  3. 6 0
      modules/ui.py

+ 11 - 1
modules/processing_scripts/comments.py

@@ -1,4 +1,4 @@
-from modules import scripts, shared
+from modules import scripts, shared, script_callbacks
 import re
 import re
 
 
 
 
@@ -27,6 +27,16 @@ class ScriptStripComments(scripts.Script):
         p.main_negative_prompt = strip_comments(p.main_negative_prompt)
         p.main_negative_prompt = strip_comments(p.main_negative_prompt)
 
 
 
 
+def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
+    if not shared.opts.enable_prompt_comments:
+        return
+
+    params.prompt = strip_comments(params.prompt)
+
+
+script_callbacks.on_before_token_counter(before_token_counter)
+
+
 shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
 shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
     "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
     "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
 }))
 }))

+ 26 - 0
modules/script_callbacks.py

@@ -1,3 +1,4 @@
+import dataclasses
 import inspect
 import inspect
 import os
 import os
 from collections import namedtuple
 from collections import namedtuple
@@ -106,6 +107,15 @@ class ImageGridLoopParams:
         self.rows = rows
         self.rows = rows
 
 
 
 
+@dataclasses.dataclass
+class BeforeTokenCounterParams:
+    prompt: str
+    steps: int
+    styles: list
+
+    is_positive: bool = True
+
+
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 callback_map = dict(
 callback_map = dict(
     callbacks_app_started=[],
     callbacks_app_started=[],
@@ -128,6 +138,7 @@ callback_map = dict(
     callbacks_on_reload=[],
     callbacks_on_reload=[],
     callbacks_list_optimizers=[],
     callbacks_list_optimizers=[],
     callbacks_list_unets=[],
     callbacks_list_unets=[],
+    callbacks_before_token_counter=[],
 )
 )
 
 
 
 
@@ -309,6 +320,14 @@ def list_unets_callback():
     return res
     return res
 
 
 
 
+def before_token_counter_callback(params: BeforeTokenCounterParams):
+    for c in callback_map['callbacks_before_token_counter']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'before_token_counter')
+
+
 def add_callback(callbacks, fun):
 def add_callback(callbacks, fun):
     stack = [x for x in inspect.stack() if x.filename != __file__]
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if stack else 'unknown file'
     filename = stack[0].filename if stack else 'unknown file'
@@ -483,3 +502,10 @@ def on_list_unets(callback):
     The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
     The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
 
 
     add_callback(callback_map['callbacks_list_unets'], callback)
     add_callback(callback_map['callbacks_list_unets'], callback)
+
+
+def on_before_token_counter(callback):
+    """register a function to be called when UI is counting tokens for a prompt.
+    The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
+
+    add_callback(callback_map['callbacks_before_token_counter'], callback)

+ 6 - 0
modules/ui.py

@@ -152,6 +152,12 @@ def connect_clear_prompt(button):
 
 
 
 
 def update_token_counter(text, steps, styles, *, is_positive=True):
 def update_token_counter(text, steps, styles, *, is_positive=True):
+    params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
+    script_callbacks.before_token_counter_callback(params)
+    text = params.prompt
+    steps = params.steps
+    styles = params.styles
+    is_positive = params.is_positive
 
 
     if shared.opts.include_styles_into_token_counters:
     if shared.opts.include_styles_into_token_counters:
         apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
         apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt