|
@@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts
|
|
import modules.shared as shared
|
|
import modules.shared as shared
|
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
from modules.sd_hijack import model_hijack
|
|
from modules.sd_hijack import model_hijack
|
|
|
|
+from modules.helpers import debounce
|
|
import modules.ldsr_model
|
|
import modules.ldsr_model
|
|
import modules.scripts
|
|
import modules.scripts
|
|
import modules.gfpgan_model
|
|
import modules.gfpgan_model
|
|
@@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|
outputs=[seed, dummy_component]
|
|
outputs=[seed, dummy_component]
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+def update_token_counter(text):
|
|
|
|
+ tokens, token_count, max_length = model_hijack.tokenize(text)
|
|
|
|
+ style_class = ' class="red"' if (token_count > max_length) else ""
|
|
|
|
+ return f"<span {style_class}>{token_count}/{max_length}</span>"
|
|
|
|
|
|
def create_toprow(is_img2img):
|
|
def create_toprow(is_img2img):
|
|
id_part = "img2img" if is_img2img else "txt2img"
|
|
id_part = "img2img" if is_img2img else "txt2img"
|
|
@@ -339,15 +344,15 @@ def create_toprow(is_img2img):
|
|
with gr.Row():
|
|
with gr.Row():
|
|
with gr.Column(scale=80):
|
|
with gr.Column(scale=80):
|
|
with gr.Row():
|
|
with gr.Row():
|
|
- prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2)
|
|
|
|
|
|
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
|
|
|
+ prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False)
|
|
|
|
|
|
with gr.Column(scale=1, elem_id="roll_col"):
|
|
with gr.Column(scale=1, elem_id="roll_col"):
|
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
|
- token_output = gr.JSON(visible=False)
|
|
|
|
- if is_img2img: # only define the api function ONCE
|
|
|
|
- token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])
|
|
|
|
|
|
+ hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
|
|
|
+ hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
|
|
|
|
|
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|