Selaa lähdekoodia

switched the token counter to use hidden buttons instead of api call

Liam 2 vuotta sitten
vanhempi
commit
e5707b66d6
4 muutettua tiedostoa jossa 25 lisäystä ja 64 poistoa
  1. 0 13
      javascript/helpers.js
  2. 15 45
      javascript/ui.js
  3. 1 2
      modules/sd_hijack.py
  4. 9 4
      modules/ui.py

+ 0 - 13
javascript/helpers.js

@@ -1,13 +0,0 @@
-// helper functions
-
-function debounce(func, wait_time) {
-	let timeout;
-	return function wrapped(...args) {
-		let call_function = () => {
-			clearTimeout(timeout);
-			func(...args)
-		}
-		clearTimeout(timeout);
-		timeout = setTimeout(call_function, wait_time);
-	};
-}

+ 15 - 45
javascript/ui.js

@@ -182,51 +182,21 @@ onUiUpdate(function(){
     });
     });
 
 
     json_elem.parentElement.style.display="none"
     json_elem.parentElement.style.display="none"
-
-	let debounce_time = 800
-	if (!txt2img_textarea) {
-		txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea")
-		txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time))
-	}
-	if (!img2img_textarea) {
-		img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea")
-		img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time))
-    }
 })
 })
 
 
+let wait_time = 800
+let token_timeout;
+function txt2img_token_counter(text) {
+	return update_token_counter("txt2img_token_button", text);
+}
+
+function img2img_token_counter(text) {
+	return update_token_counter("img2img_token_button", text);
+}
 
 
-let txt2img_textarea, img2img_textarea = undefined;
-function submit_prompt_text(source, e) {
-	let prompt_text;
-	if (source == "txt2img")
-		prompt_text = txt2img_textarea.value;
-	else if (source == "img2img")
-		prompt_text = img2img_textarea.value;
-	if (!prompt_text)
-		return;
-	params = {
-		method: "POST",
-		headers: {
-			"Accept": "application/json",
-			"Content-type": "application/json"
-		},
-		body: JSON.stringify({data:[prompt_text]})
-	}
-	fetch('http://127.0.0.1:7860/api/tokenize/', params)
-	.then((response) => response.json())
-	.then((data) => {
-		if (data?.data.length) {
-			let response_json = data.data[0]
-			if (elem = gradioApp().getElementById(source+"_token_counter")) {
-				if (response_json.token_count > response_json.max_length)
-					elem.classList.add("red");
-				else
-					elem.classList.remove("red");
-				elem.innerText = response_json.token_count + "/" + response_json.max_length;
-			}
-		}
-	})
-	.catch((error) => {
-		console.error('Error:', error);
-	});
-}
+function update_token_counter(button_id, text) {
+	if (token_timeout)
+		clearTimeout(token_timeout);
+	token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+	return [];
+}

+ 1 - 2
modules/sd_hijack.py

@@ -273,8 +273,7 @@ class StableDiffusionModelHijack:
     def tokenize(self, text):
     def tokenize(self, text):
         max_length = self.clip.max_length - 2
         max_length = self.clip.max_length - 2
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length}
-        
+        return remade_batch_tokens[0], token_count, max_length
 
 
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
     def __init__(self, wrapped, hijack):

+ 9 - 4
modules/ui.py

@@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.shared as shared
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.sd_hijack import model_hijack
 from modules.sd_hijack import model_hijack
+from modules.helpers import debounce
 import modules.ldsr_model
 import modules.ldsr_model
 import modules.scripts
 import modules.scripts
 import modules.gfpgan_model
 import modules.gfpgan_model
@@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
         outputs=[seed, dummy_component]
         outputs=[seed, dummy_component]
     )
     )
 
 
+def update_token_counter(text):
+    tokens, token_count, max_length = model_hijack.tokenize(text)
+    style_class = ' class="red"' if (token_count > max_length) else ""
+    return f"<span {style_class}>{token_count}/{max_length}</span>"
 
 
 def create_toprow(is_img2img):
 def create_toprow(is_img2img):
     id_part = "img2img" if is_img2img else "txt2img"
     id_part = "img2img" if is_img2img else "txt2img"
@@ -339,15 +344,15 @@ def create_toprow(is_img2img):
             with gr.Row():
             with gr.Row():
                 with gr.Column(scale=80):
                 with gr.Column(scale=80):
                     with gr.Row():
                     with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2)
+                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
+                        prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False)
 
 
                 with gr.Column(scale=1, elem_id="roll_col"):
                 with gr.Column(scale=1, elem_id="roll_col"):
                     roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
                     roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
                     paste = gr.Button(value=paste_symbol, elem_id="paste")
                     paste = gr.Button(value=paste_symbol, elem_id="paste")
                     token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
                     token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
-                    token_output = gr.JSON(visible=False)
-                    if is_img2img:  # only define the api function ONCE
-                        token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])
+                    hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+                    hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
 
 
                 with gr.Column(scale=10, elem_id="style_pos_col"):
                 with gr.Column(scale=10, elem_id="style_pos_col"):
                     prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
                     prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)