Browse Source

Added negative prompts to extra networks lora

Learwin 1 year ago
parent
commit
bc5ae74c7d

+ 12 - 2
extensions-builtin/Lora/ui_edit_user_metadata.py

@@ -54,12 +54,14 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.slider_preferred_weight = None
         self.edit_notes = None
 
-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
         user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
+        user_metadata["negative weight"] = negative_weight
         user_metadata["notes"] = notes
 
         self.write_user_metadata(name, user_metadata)
@@ -127,6 +129,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
+            float(user_metadata.get('negative weight', 0.0)),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]
@@ -162,7 +166,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.taginfo = gr.HighlightedText(label="Training dataset tags")
         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
-
+        self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
+        self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
         with gr.Row() as row_random_prompt:
             with gr.Column(scale=8):
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +203,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
+            self.slider_negative_weight,
             row_random_prompt,
             random_prompt,
         ]
@@ -211,7 +218,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
+            self.slider_negative_weight,
             self.edit_notes,
         ]
 
+
         self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)

+ 9 - 0
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -45,6 +45,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)
 
+        negative_prompt = item["user_metadata"].get("negative text")
+        preferred_negative_weight = item["user_metadata"].get("negative weight")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            neg_prompt = negative_prompt
+            if (preferred_negative_weight > 0):
+                neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')'
+            item["negative_prompt"] = quote_js(neg_prompt)  
+            
         sd_version = item["user_metadata"].get("sd version")
         if sd_version in network.SdVersion.__members__:
             item["sd_version"] = sd_version

+ 21 - 10
javascript/extraNetworks.js

@@ -185,8 +185,10 @@ onUiLoaded(setupExtraNetworks);
 var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
 var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
 
-function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
-    var m = text.match(re_extranet);
+var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
+var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
+function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
+    var m = text.match(isNeg ? re_extranet_neg : re_extranet);
     var replaced = false;
     var newTextareaText;
     if (m) {
@@ -194,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         var extraTextAfterNet = m[2];
         var partToSearch = m[1];
         var foundAtPosition = -1;
-        newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
-            m = found.match(re_extranet);
+        newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
+            m = found.match(isNeg ? re_extranet_neg : re_extranet);
             if (m[1] == partToSearch) {
                 replaced = true;
                 foundAtPosition = pos;
@@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
         });
 
         if (foundAtPosition >= 0) {
-            if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+            if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
                 newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
             }
             if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
@@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
     return false;
 }
 
-function cardClicked(tabname, textToAdd, allowNegativePrompt) {
-    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+function updatePromptArea(text, textArea, isNeg) {
 
-    if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
-        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+    if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
+        textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
     }
 
-    updateInput(textarea);
+    updateInput(textArea);
+}
+
+function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
+    if (textToAddNegative.length > 0) {
+        updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"))
+        updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true)
+    } else {
+        var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+        updatePromptArea(textToAdd, textarea)
+    }
 }
 
 function saveCardPreview(event, tabname, filename) {

+ 4 - 1
modules/ui_extra_networks.py

@@ -223,7 +223,10 @@ class ExtraNetworksPage:
 
         onclick = item.get("onclick", None)
         if onclick is None:
-            onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+            if "negative_prompt" in item:
+                onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+            else:
+                onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
 
         height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
         width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''