瀏覽代碼

read infotext params from the other extension for Lora if it's not active

AUTOMATIC 2 年之前
父節點
當前提交
2473bafa67
共有 2 個文件被更改,包括 36 次插入1 次删除
  1. 35 1
      extensions-builtin/Lora/lora.py
  2. 1 0
      extensions-builtin/Lora/scripts/lora_script.py

+ 35 - 1
extensions-builtin/Lora/lora.py

@@ -4,7 +4,7 @@ import re
 import torch
 from typing import Union
 
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
 
 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 
@@ -366,6 +366,40 @@ def list_available_loras():
         available_lora_aliases[entry.alias] = entry
 
 
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+        return  # if the other extension is active, it will handle those fields, no need to do anything
+
+    added = []
+
+    for k, v in params.items():
+        if not k.startswith("AddNet Model "):
+            continue
+
+        num = k[13:]
+
+        if params.get("AddNet Module " + num) != "LoRA":
+            continue
+
+        name = params.get("AddNet Model " + num)
+        if name is None:
+            continue
+
+        m = re_lora_name.match(name)
+        if m:
+            name = m.group(1)
+
+        multiplier = params.get("AddNet Weight A " + num, "1.0")
+
+        added.append(f"<lora:{name}:{multiplier}>")
+
+    if added:
+        params["Prompt"] += "\n" + "".join(added)
+
+
 available_loras = {}
 available_lora_aliases = {}
 loaded_loras = []

+ 1 - 0
extensions-builtin/Lora/scripts/lora_script.py

@@ -49,6 +49,7 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(lora.infotext_pasted)
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {