소스 검색

extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>

AUTOMATIC 2 년 전
부모
커밋
40ff6db532

BIN
html/card-no-preview.png


+ 11 - 0
html/extra-networks-card.html

@@ -0,0 +1,11 @@
+<div class='card' {preview_html} onclick='return cardClicked({prompt}, {allow_negative_prompt})'>
+	<div class='actions'>
+		<div class='additional'>
+			<ul>
+				<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
+			</ul>
+		</div>
+		<span class='name'>{name}</span>
+	</div>
+</div>
+

+ 8 - 0
html/extra-networks-no-cards.html

@@ -0,0 +1,8 @@
+<div class='nocards'>
+<h1>Nothing here. Add some content to the following directories:</h1>
+
+<ul>
+{dirs}
+</ul>
+</div>
+

+ 60 - 0
javascript/extraNetworks.js

@@ -0,0 +1,60 @@
+
+function setupExtraNetworksForTab(tabname){
+    gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
+
+    gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
+    gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
+}
+
+var activePromptTextarea = null;
+var activePositivePromptTextarea = null;
+
+function setupExtraNetworks(){
+    setupExtraNetworksForTab('txt2img')
+    setupExtraNetworksForTab('img2img')
+
+    function registerPrompt(id, isNegative){
+        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+        if (activePromptTextarea == null){
+            activePromptTextarea = textarea
+        }
+        if (activePositivePromptTextarea == null && ! isNegative){
+            activePositivePromptTextarea = textarea
+        }
+
+		textarea.addEventListener("focus", function(){
+            activePromptTextarea = textarea;
+            if(! isNegative)  activePositivePromptTextarea = textarea;
+		});
+    }
+
+    registerPrompt('txt2img_prompt')
+    registerPrompt('txt2img_neg_prompt', true)
+    registerPrompt('img2img_prompt')
+    registerPrompt('img2img_neg_prompt', true)
+}
+
+onUiLoaded(setupExtraNetworks)
+
+function cardClicked(textToAdd, allowNegativePrompt){
+    textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea
+
+    textarea.value = textarea.value + " " + textToAdd
+    updateInput(textarea)
+
+    return false
+}
+
+function saveCardPreview(event, tabname, filename){
+    textarea = gradioApp().querySelector("#" + tabname + '_preview_filename  > label > textarea')
+    button = gradioApp().getElementById(tabname + '_save_preview')
+
+    textarea.value = filename
+    updateInput(textarea)
+
+    button.click()
+
+    event.stopPropagation()
+    event.preventDefault()
+}

+ 2 - 0
javascript/hints.js

@@ -21,6 +21,8 @@ titles = {
     "\U0001F5D1": "Clear prompt",
     "\u{1f4cb}": "Apply selected styles to current prompt",
     "\u{1f4d2}": "Paste available values into the field",
+    "\u{1f3b4}": "Show extra networks",
+
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

+ 4 - 5
javascript/ui.js

@@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
     return [prompt, negative_prompt]
 }
 
-
-
 opts = {}
 onUiUpdate(function(){
 	if(Object.keys(opts).length != 0) return;
@@ -239,11 +237,14 @@ onUiUpdate(function(){
             return
         }
 
+
         prompt.parentElement.insertBefore(counter, prompt)
         counter.classList.add("token-counter")
         prompt.parentElement.style.position = "relative"
 
-		textarea.addEventListener("input", () => update_token_counter(id_button));
+		textarea.addEventListener("input", function(){
+		    update_token_counter(id_button);
+		});
     }
 
     registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@@ -261,10 +262,8 @@ onUiUpdate(function(){
             })
         }
     }
-
 })
 
-
 onOptionsChanged(function(){
     elem = gradioApp().getElementById('sd_checkpoint_hash')
     sd_checkpoint_hash = opts.sd_checkpoint_hash || ""

+ 3 - 4
modules/api/api.py

@@ -480,7 +480,7 @@ class Api:
     def train_hypernetwork(self, args: dict):
         try:
             shared.state.begin()
-            initial_hypernetwork = shared.loaded_hypernetwork
+            shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             filename = ''
@@ -491,16 +491,15 @@ class Api:
             except Exception as e:
                 error = e
             finally:
-                shared.loaded_hypernetwork = initial_hypernetwork
                 shared.sd_model.cond_stage_model.to(devices.device)
                 shared.sd_model.first_stage_model.to(devices.device)
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
                 shared.state.end()
-            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
         except AssertionError as msg:
             shared.state.end()
-            return TrainResponse(info = "train embedding error: {error}".format(error = error))
+            return TrainResponse(info="train embedding error: {error}".format(error=error))
 
     def get_memory(self):
         try:

+ 147 - 0
modules/extra_networks.py

@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+    extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+    extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+    def __init__(self, items=None):
+        self.items = items or []
+
+
+class ExtraNetwork:
+    def __init__(self, name):
+        self.name = name
+
+    def activate(self, p, params_list):
+        """
+        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+        Passes arguments related to this extra network in params_list.
+        User passes arguments by specifying this in his prompt:
+
+        <name:arg1:arg2:arg3>
+
+        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+        separated by colon.
+
+        Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
+        in this case, all effects of this extra networks should be disabled.
+
+        Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+        For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
+
+        params_list will be:
+
+        [
+            ExtraNetworkParams(items=["agm", "1.1"]),
+            ExtraNetworkParams(items=["ray"])
+        ]
+
+        """
+        raise NotImplementedError
+
+    def deactivate(self, p):
+        """
+        Called at the end of processing for housekeeping. No need to do anything here.
+        """
+
+        raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+    """call activate for extra networks in extra_network_data in specified order, then call
+    activate for all remaining registered networks with an empty argument list"""
+
+    for extra_network_name, extra_network_args in extra_network_data.items():
+        extra_network = extra_network_registry.get(extra_network_name, None)
+        if extra_network is None:
+            print(f"Skipping unknown extra network: {extra_network_name}")
+            continue
+
+        try:
+            extra_network.activate(p, extra_network_args)
+        except Exception as e:
+            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+    for extra_network_name, extra_network in extra_network_registry.items():
+        args = extra_network_data.get(extra_network_name, None)
+        if args is not None:
+            continue
+
+        try:
+            extra_network.activate(p, [])
+        except Exception as e:
+            errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+    """call deactivate for extra networks in extra_network_data in specified order, then call
+    deactivate for all remaining registered networks"""
+
+    for extra_network_name, extra_network_args in extra_network_data.items():
+        extra_network = extra_network_registry.get(extra_network_name, None)
+        if extra_network is None:
+            continue
+
+        try:
+            extra_network.deactivate(p)
+        except Exception as e:
+            errors.display(e, f"deactivating extra network {extra_network_name}")
+
+    for extra_network_name, extra_network in extra_network_registry.items():
+        args = extra_network_data.get(extra_network_name, None)
+        if args is not None:
+            continue
+
+        try:
+            extra_network.deactivate(p)
+        except Exception as e:
+            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+    res = defaultdict(list)
+
+    def found(m):
+        name = m.group(1)
+        args = m.group(2)
+
+        res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+        return ""
+
+    prompt = re.sub(re_extra_net, found, prompt)
+
+    return prompt, res
+
+
+def parse_prompts(prompts):
+    res = []
+    extra_data = None
+
+    for prompt in prompts:
+        updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+        if extra_data is None:
+            extra_data = parsed_extra_data
+
+        res.append(updated_prompt)
+
+    return res, extra_data
+

+ 21 - 0
modules/extra_networks_hypernet.py

@@ -0,0 +1,21 @@
+from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+    def __init__(self):
+        super().__init__('hypernet')
+
+    def activate(self, p, params_list):
+        names = []
+        multipliers = []
+        for params in params_list:
+            assert len(params.items) > 0
+
+            names.append(params.items[0])
+            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+        hypernetwork.load_hypernetworks(names, multipliers)
+
+    def deactivate(p, self):
+        pass

+ 3 - 9
modules/generation_parameters_copypaste.py

@@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
     from modules import ui
 
     settings_map = {
-        'sd_hypernetwork': 'Hypernet',
-        'sd_hypernetwork_strength': 'Hypernet strength',
         'CLIP_stop_at_last_layers': 'Clip skip',
         'inpainting_mask_weight': 'Conditional mask weight',
         'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Clip skip" not in res:
         res["Clip skip"] = "1"
 
-    if "Hypernet strength" not in res:
-        res["Hypernet strength"] = "1"
-
-    if "Hypernet" in res:
-        hypernet_name = res["Hypernet"]
-        hypernet_hash = res.get("Hypernet hash", None)
-        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+    hypernet = res.get("Hypernet", None)
+    if hypernet is not None:
+        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
 
     if "Hires resize-1" not in res:
         res["Hires resize-1"] = 0

+ 75 - 32
modules/hypernetworks/hypernetwork.py

@@ -25,7 +25,6 @@ from statistics import stdev, mean
 optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
-    multiplier = 1.0
     activation_dict = {
         "linear": torch.nn.Identity,
         "relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
                  add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
+        self.multiplier = 1.0
+
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
             state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+        return x + self.linear(x) * (self.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
         return layer_structure
 
 
-def apply_strength(value=None):
-    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
 #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
 def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
     if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork:
                 for param in layer.parameters():
                     param.requires_grad = mode
 
+    def to(self, device):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.to(device)
+
+        return self
+
+    def set_multiplier(self, multiplier):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.multiplier = multiplier
+
+        return self
+
     def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork:
             self.optimizer_state_dict = None
         if self.optimizer_state_dict:
             self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-            print("Loaded existing optimizer from checkpoint")
-            print(f"Optimizer name is {self.optimizer_name}")
+            if shared.opts.print_hypernet_extra:
+                print("Loaded existing optimizer from checkpoint")
+                print(f"Optimizer name is {self.optimizer_name}")
         else:
             self.optimizer_name = "AdamW"
-            print("No saved optimizer exists in checkpoint")
+            if shared.opts.print_hypernet_extra:
+                print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
     return res
 
 
-def load_hypernetwork(filename):
-    path = shared.hypernetworks.get(filename, None)
-    # Prevent any file named "None.pt" from being loaded.
-    if path is not None and filename != "None":
-        print(f"Loading hypernetwork {filename}")
-        try:
-            shared.loaded_hypernetwork = Hypernetwork()
-            shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+    path = shared.hypernetworks.get(name, None)
 
-        except Exception:
-            print(f"Error loading hypernetwork {path}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-    else:
-        if shared.loaded_hypernetwork is not None:
-            print("Unloading hypernetwork")
+    if path is None:
+        return None
+
+    hypernetwork = Hypernetwork()
+
+    try:
+        hypernetwork.load(path)
+    except Exception:
+        print(f"Error loading hypernetwork {path}", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        return None
+
+    return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+    already_loaded = {}
+
+    for hypernetwork in shared.loaded_hypernetworks:
+        if hypernetwork.name in names:
+            already_loaded[hypernetwork.name] = hypernetwork
 
-        shared.loaded_hypernetwork = None
+    shared.loaded_hypernetworks.clear()
+
+    for i, name in enumerate(names):
+        hypernetwork = already_loaded.get(name, None)
+        if hypernetwork is None:
+            hypernetwork = load_hypernetwork(name)
+
+        if hypernetwork is None:
+            continue
+
+        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+        shared.loaded_hypernetworks.append(hypernetwork)
 
 
 def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
     return applicable[0]
 
 
-def apply_hypernetwork(hypernetwork, context, layer=None):
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
 
     if hypernetwork_layers is None:
-        return context, context
+        return context_k, context_v
 
     if layer is not None:
         layer.hyper_k = hypernetwork_layers[0]
         layer.hyper_v = hypernetwork_layers[1]
 
-    context_k = hypernetwork_layers[0](context)
-    context_v = hypernetwork_layers[1](context)
+    context_k = hypernetwork_layers[0](context_k)
+    context_v = hypernetwork_layers[1](context_v)
+    return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+    context_k = context
+    context_v = context
+    for hypernetwork in hypernetworks:
+        context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
     return context_k, context_v
 
 
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     q = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+    context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
     k = self.to_k(context_k)
     v = self.to_v(context_v)
 
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     template_file = template_file.path
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
-    shared.loaded_hypernetwork = Hypernetwork()
-    shared.loaded_hypernetwork.load(path)
+    hypernetwork = Hypernetwork()
+    hypernetwork.load(path)
+    shared.loaded_hypernetworks = [hypernetwork]
 
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     else:
         images_dir = None
 
-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()
 
     initial_step = hypernetwork.step or 0

+ 2 - 3
modules/hypernetworks/ui.py

@@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared
 not_available = ["hardswish", "multiheadattention"]
 keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
 
+
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
 
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
 
 
 def train_hypernetwork(*args):
-
-    initial_hypernetwork = shared.loaded_hypernetwork
+    shared.loaded_hypernetworks = []
 
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
 
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        shared.loaded_hypernetwork = initial_hypernetwork
         shared.sd_model.cond_stage_model.to(devices.device)
         shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()

+ 13 - 11
modules/processing.py

@@ -13,7 +13,7 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
-        "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
-        "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
         "Batch size": (None if p.batch_size < 2 else p.batch_size),
         "Batch pos": (None if p.batch_size < 2 else position_in_batch),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     try:
         for k, v in p.override_settings.items():
             setattr(opts, k, v)
-            if k == 'sd_hypernetwork':
-                shared.reload_hypernetworks()  # make onchange call for changing hypernet
 
             if k == 'sd_model_checkpoint':
-                sd_models.reload_model_weights()  # make onchange call for changing SD model
+                sd_models.reload_model_weights()
 
             if k == 'sd_vae':
-                sd_vae.reload_vae_weights()  # make onchange call for changing VAE
+                sd_vae.reload_vae_weights()
 
         res = process_images_inner(p)
 
@@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         if p.override_settings_restore_afterwards:
             for k, v in stored_opts.items():
                 setattr(opts, k, v)
-                if k == 'sd_hypernetwork': shared.reload_hypernetworks()
-                if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
-                if k == 'sd_vae': sd_vae.reload_vae_weights()
+                if k == 'sd_model_checkpoint':
+                    sd_models.reload_model_weights()
+
+                if k == 'sd_vae':
+                    sd_vae.reload_vae_weights()
 
     return res
 
@@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         cache[0] = (required_prompts, steps)
         return cache[1]
 
+    p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
+
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
 
+            extra_networks.activate(p, extra_network_data)
+
         with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
             processed = Processed(p, [], p.seed, "")
             file.write(processed.infotext(p, 0))
@@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if opts.grid_save:
                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
+    extra_networks.deactivate(p, extra_network_data)
     devices.torch_gc()
 
     res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)

+ 5 - 5
modules/sd_hijack_optimizations.py

@@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     q_in = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
     del context, context_k, context_v, x
@@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     q_in = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
@@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
     q = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
     k = self.to_k(context_k) * self.scale
     v = self.to_v(context_v)
     del context, context_k, context_v, x
@@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     q = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
     k = self.to_k(context_k)
     v = self.to_v(context_v)
     del context, context_k, context_v, x
@@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q_in = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 

+ 16 - 5
modules/shared.py

@@ -23,6 +23,7 @@ demo = None
 sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file
 
 os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
 hypernetworks = {}
-loaded_hypernetwork = None
+loaded_hypernetworks = []
 
 
 def reload_hypernetworks():
@@ -153,8 +154,6 @@ def reload_hypernetworks():
     global hypernetworks
 
     hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
-    hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
 
 
 class State:
@@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
-    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
-    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
     "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
     "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
     "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -661,3 +658,17 @@ mem_mon.start()
 def listfiles(dirname):
     filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
     return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+    return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+    path = html_path(filename)
+
+    if os.path.exists(path):
+        with open(path, encoding="utf8") as file:
+            return file.read()
+
+    return ""

+ 2 - 0
modules/textual_inversion/textual_inversion.py

@@ -50,6 +50,7 @@ class Embedding:
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
         self.optimizer_state_dict = None
+        self.filename = None
 
     def save(self, filename):
         embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase:
         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
         embedding.vectors = vec.shape[0]
         embedding.shape = vec.shape[-1]
+        embedding.filename = path
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)

+ 32 - 18
modules/ui.py

@@ -20,7 +20,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 
@@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504'  # 🔄
 save_style_symbol = '\U0001f4be'  # 💾
 apply_style_symbol = '\U0001f4cb'  # 📋
 clear_prompt_symbol = '\U0001F5D1'  # 🗑️
+extra_networks_symbol = '\U0001F3B4'  # 🎴
 
 
 def plaintext_to_html(text):
@@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
 
 def update_token_counter(text, steps):
     try:
+        text, _ = extra_networks.parse_prompt(text)
+
         _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
         prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
 
@@ -354,10 +357,10 @@ def create_toprow(is_img2img):
                         negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
 
         with gr.Column(scale=1, elem_id="roll_col"):
-            paste = gr.Button(value=paste_symbol, elem_id="paste")
-            save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
-            prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
-            clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+            paste = ToolButton(value=paste_symbol, elem_id="paste")
+            clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+            extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+
             token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
             token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
             negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
@@ -395,11 +398,14 @@ def create_toprow(is_img2img):
                     outputs=[],
                 )
 
-            with gr.Row():
+            with gr.Row(elem_id=f"{id_part}_styles_row"):
                 prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                 create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
 
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
+                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply")
+                save_style = ToolButton(value=save_style_symbol, elem_id="style_create")
+
+    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
 
 
 def setup_progressbar(*args, **kwargs):
@@ -616,11 +622,15 @@ def create_ui():
     modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
+        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
 
         dummy_component = gr.Label(visible=False)
         txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
 
+        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+            from modules import ui_extra_networks
+            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
                 for category in ordered_ui_categories():
@@ -794,14 +804,20 @@ def create_ui():
             token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
             negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
 
+            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+
     modules.scripts.scripts_current = modules.scripts.scripts_img2img
     modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
 
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
-        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
+        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
 
         img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
 
+        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+            from modules import ui_extra_networks
+            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+
         with FormRow().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="img2img_settings"):
                 copy_image_buttons = []
@@ -1064,6 +1080,8 @@ def create_ui():
             token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
             negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
 
+            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
             img2img_paste_fields = [
                 (img2img_prompt, "Prompt"),
                 (img2img_negative_prompt, "Negative prompt"),
@@ -1666,10 +1684,8 @@ def create_ui():
                 download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                 reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
 
-            if os.path.exists("html/licenses.html"):
-                with open("html/licenses.html", encoding="utf8") as file:
-                    with gr.TabItem("Licenses"):
-                        gr.HTML(file.read(), elem_id="licenses")
+            with gr.TabItem("Licenses"):
+                gr.HTML(shared.html("licenses.html"), elem_id="licenses")
 
             gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
 
@@ -1756,11 +1772,9 @@ def create_ui():
         if os.path.exists(os.path.join(script_path, "notification.mp3")):
             audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
 
-        if os.path.exists("html/footer.html"):
-            with open("html/footer.html", encoding="utf8") as file:
-                footer = file.read()
-                footer = footer.format(versions=versions_html())
-                gr.HTML(footer, elem_id="footer")
+        footer = shared.html("footer.html")
+        footer = footer.format(versions=versions_html())
+        gr.HTML(footer, elem_id="footer")
 
         text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
         settings_submit.click(

+ 10 - 0
modules/ui_components.py

@@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent):
         return "button"
 
 
+class ToolButtonTop(gr.Button, gr.components.FormComponent):
+    """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        super().__init__(variant="tool-top", **kwargs)
+
+    def get_block_name(self):
+        return "button"
+
+
 class FormRow(gr.Row, gr.components.FormComponent):
     """Same as gr.Row but fits inside gradio forms"""
 

+ 149 - 0
modules/ui_extra_networks.py

@@ -0,0 +1,149 @@
+import os.path
+
+from modules import shared
+import gradio as gr
+import json
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+
+
+def register_page(page):
+    """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
+
+    extra_pages.append(page)
+
+
+class ExtraNetworksPage:
+    def __init__(self, title):
+        self.title = title
+        self.card_page = shared.html("extra-networks-card.html")
+        self.allow_negative_prompt = False
+
+    def refresh(self):
+        pass
+
+    def create_html(self, tabname):
+        items_html = ''
+
+        for item in self.list_items():
+            items_html += self.create_html_for_item(item, tabname)
+
+        if items_html == '':
+            dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+            items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+        res = "<div class='extra-network-cards'>" + items_html + "</div>"
+
+        return res
+
+    def list_items(self):
+        raise NotImplementedError()
+
+    def allowed_directories_for_previews(self):
+        return []
+
+    def create_html_for_item(self, item, tabname):
+        preview = item.get("preview", None)
+
+        args = {
+            "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+            "prompt": json.dumps(item["prompt"]),
+            "tabname": json.dumps(tabname),
+            "local_preview": json.dumps(item["local_preview"]),
+            "name": item["name"],
+            "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+        }
+
+        return self.card_page.format(**args)
+
+
+def intialize():
+    extra_pages.clear()
+
+
+class ExtraNetworksUi:
+    def __init__(self):
+        self.pages = None
+        self.stored_extra_pages = None
+
+        self.button_save_preview = None
+        self.preview_target_filename = None
+
+        self.tabname = None
+
+
+def create_ui(container, button, tabname):
+    ui = ExtraNetworksUi()
+    ui.pages = []
+    ui.stored_extra_pages = extra_pages.copy()
+    ui.tabname = tabname
+
+    with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+        button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+        button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+        for page in ui.stored_extra_pages:
+            with gr.Tab(page.title):
+                page_elem = gr.HTML(page.create_html(ui.tabname))
+                ui.pages.append(page_elem)
+
+    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+    button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+    button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
+    def refresh():
+        res = []
+
+        for pg in ui.stored_extra_pages:
+            pg.refresh()
+            res.append(pg.create_html(ui.tabname))
+
+        return res
+
+    button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+    return ui
+
+
+def path_is_parent(parent_path, child_path):
+    parent_path = os.path.abspath(parent_path)
+    child_path = os.path.abspath(child_path)
+
+    return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+
+
+def setup_ui(ui, gallery):
+    def save_preview(index, images, filename):
+        if len(images) == 0:
+            print("There is no image in gallery to save as a preview.")
+            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+        index = int(index)
+        index = 0 if index < 0 else index
+        index = len(images) - 1 if index >= len(images) else index
+
+        img_info = images[index if index >= 0 else 0]
+        image = image_from_url_text(img_info)
+
+        is_allowed = False
+        for extra_page in ui.stored_extra_pages:
+            if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+                is_allowed = True
+                break
+
+        assert is_allowed, f'writing to {filename} is not allowed'
+
+        image.save(filename)
+
+        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+    ui.button_save_preview.click(
+        fn=save_preview,
+        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+        outputs=[*ui.pages]
+    )

+ 34 - 0
modules/ui_extra_networks_hypernets.py

@@ -0,0 +1,34 @@
+import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Hypernetworks')
+
+    def refresh(self):
+        shared.reload_hypernetworks()
+
+    def list_items(self):
+        for name, path in shared.hypernetworks.items():
+            path, ext = os.path.splitext(path)
+            previews = [path + ".png", path + ".preview.png"]
+
+            preview = None
+            for file in previews:
+                if os.path.isfile(file):
+                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    break
+
+            yield {
+                "name": name,
+                "filename": path,
+                "preview": preview,
+                "prompt": f"<hypernet:{name}:1.0>",
+                "local_preview": path + ".png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return [shared.cmd_opts.hypernetwork_dir]
+

+ 32 - 0
modules/ui_extra_networks_textual_inversion.py

@@ -0,0 +1,32 @@
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Textual Inversion')
+        self.allow_negative_prompt = True
+
+    def refresh(self):
+        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+    def list_items(self):
+        for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+            path, ext = os.path.splitext(embedding.filename)
+            preview_file = path + ".preview.png"
+
+            preview = None
+            if os.path.isfile(preview_file):
+                preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+
+            yield {
+                "name": embedding.name,
+                "filename": embedding.filename,
+                "preview": preview,
+                "prompt": embedding.name,
+                "local_preview": path + ".preview.png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)

+ 12 - 1
script.js

@@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
 }
 
 uiUpdateCallbacks = []
+uiLoadedCallbacks = []
 uiTabChangeCallbacks = []
 optionsChangedCallbacks = []
 let uiCurrentTab = null
@@ -20,6 +21,9 @@ let uiCurrentTab = null
 function onUiUpdate(callback){
     uiUpdateCallbacks.push(callback)
 }
+function onUiLoaded(callback){
+    uiLoadedCallbacks.push(callback)
+}
 function onUiTabChange(callback){
     uiTabChangeCallbacks.push(callback)
 }
@@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
     queue.forEach(function(x){runCallback(x, m)})
 }
 
+var executedOnLoaded = false;
+
 document.addEventListener("DOMContentLoaded", function() {
     var mutationObserver = new MutationObserver(function(m){
+        if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+            executedOnLoaded = true;
+            executeCallbacks(uiLoadedCallbacks);
+        }
+
         executeCallbacks(uiUpdateCallbacks, m);
         const newTab = get_uiCurrentTab();
         if ( newTab && ( newTab !== uiCurrentTab ) ) {
@@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() {
 /**
  * Add a ctrl+enter as a shortcut to start a generation
  */
- document.addEventListener('keydown', function(e) {
+document.addEventListener('keydown', function(e) {
     var handled = false;
     if (e.key !== undefined) {
         if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;

+ 0 - 29
scripts/xy_grid.py

@@ -11,7 +11,6 @@ import modules.scripts as scripts
 import gradio as gr
 
 from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
-from modules.hypernetworks import hypernetwork
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs):
             raise RuntimeError(f"Unknown checkpoint: {x}")
 
 
-def apply_hypernetwork(p, x, xs):
-    if x.lower() in ["", "none"]:
-        name = None
-    else:
-        name = hypernetwork.find_closest_hypernetwork_name(x)
-        if not name:
-            raise RuntimeError(f"Unknown hypernetwork: {x}")
-    hypernetwork.load_hypernetwork(name)
-
-
-def apply_hypernetwork_strength(p, x, xs):
-    hypernetwork.apply_strength(x)
-
-
-def confirm_hypernetworks(p, xs):
-    for x in xs:
-        if x.lower() in ["", "none"]:
-            continue
-        if not hypernetwork.find_closest_hypernetwork_name(x):
-            raise RuntimeError(f"Unknown hypernetwork: {x}")
-
-
 def apply_clip_skip(p, x, xs):
     opts.data["CLIP_stop_at_last_layers"] = x
 
@@ -208,8 +185,6 @@ axis_options = [
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
     AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
-    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
-    AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
     AxisOption("Sigma min", float, apply_field("s_tmin")),
     AxisOption("Sigma max", float, apply_field("s_tmax")),
@@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
 class SharedSettingsStackHelper(object):
     def __enter__(self):
         self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
-        self.hypernetwork = opts.sd_hypernetwork
         self.vae = opts.sd_vae
   
     def __exit__(self, exc_type, exc_value, tb):
@@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object):
         modules.sd_models.reload_model_weights()
         modules.sd_vae.reload_vae_weights()
 
-        hypernetwork.load_hypernetwork(self.hypernetwork)
-        hypernetwork.apply_strength()
-
         opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
 
 

+ 102 - 88
style.css

@@ -132,13 +132,6 @@
 }
 
 #roll_col > button {
-    min-width: 2em;
-    min-height: 2em;
-    max-width: 2em;
-    max-height: 2em;
-    flex-grow: 0;
-    padding-left: 0.25em;
-    padding-right: 0.25em;
     margin: 0.1em 0;
 }
 
@@ -146,9 +139,10 @@
     min-width: 0 !important;
     max-width: 8em !important;
     margin-right: 1em;
+    gap: 0;
 }
 #interrogate, #deepbooru{
-    margin: 0em 0.25em 0.9em 0.25em;
+    margin: 0em 0.25em 0.5em 0.25em;
     min-width: 8em;
     max-width: 8em;
 }
@@ -157,8 +151,17 @@
     min-width: 8em !important;
 }
 
+#txt2img_styles_row, #img2img_styles_row{
+    gap: 0.25em;
+    margin-top: 0.5em;
+}
+
+#txt2img_styles_row > button, #img2img_styles_row > button{
+    margin: 0;
+}
+
 #txt2img_styles, #img2img_styles{
-    margin-top: 1em;
+    padding: 0;
 }
 
 #txt2img_styles ul, #img2img_styles ul{
@@ -635,17 +638,21 @@ canvas[key="mask"] {
     background-color: rgb(31 41 55 / var(--tw-bg-opacity));
 }
 
-.gr-button-tool{
+.gr-button-tool, .gr-button-tool-top{
     max-width: 2.5em;
     min-width: 2.5em !important;
     height: 2.4em;
-    margin: 1.6em 0.7em 0.55em 0;
 }
 
-#tab_modelmerger .gr-button-tool{
+.gr-button-tool{
     margin: 0.6em 0em 0.55em 0;
 }
 
+.gr-button-tool-top, #settings .gr-button-tool{
+    margin: 1.6em 0.7em 0.55em 0;
+}
+
+
 #modelmerger_results_container{
     margin-top: 1em;
     overflow: visible;
@@ -763,81 +770,88 @@ footer {
     line-height: 2.4em;
 }
 
-/* The following handles localization for right-to-left (RTL) languages like Arabic.
-The rtl media type will only be activated by the logic in javascript/localization.js.
-If you change anything above, you need to make sure it is RTL compliant by just running
-your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
-Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
-@media rtl {
-    /* this part was added manually */
-    :host {
-        direction: rtl;
-    }
-    select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
-        direction: ltr;
-    }
-    #script_list > label > select,
-    #x_type > label > select,
-    #y_type > label > select {
-        direction: rtl;
-    }
-    .gr-radio, .gr-checkbox{
-        margin-left: 0.25em;
-    }
+#txt2img_extra_networks, #img2img_extra_networks{
+    margin-top: -1em;
+}
 
-    /* automatically generated with few manual modifications */
-    .performance .time {
-        margin-right: unset;
-        margin-left: 0;
-    }
-    .justify-center.overflow-x-scroll {
-        justify-content: right;
-    }
-    .justify-center.overflow-x-scroll button:first-of-type {
-        margin-left: unset;
-        margin-right: auto;
-    }
-    .justify-center.overflow-x-scroll button:last-of-type {
-        margin-right: unset;
-        margin-left: auto;
-    }
-    #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
-        margin-right: unset;
-        margin-left: 8em;
-    }
-    #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
-        right: unset;
-        left: 0;
-    }
-    .progressDiv .progress{
-        padding: 0 0 0 8px;
-        text-align: left;
-    }
-    #lightboxModal{
-        left: unset;
-        right: 0;
-    }
-    .modalPrev, .modalNext{
-        border-radius: 3px 0 0 3px;
-    }
-    .modalNext {
-        right: unset;
-        left: 0;
-        border-radius: 0 3px 3px 0;
-    }
-    #imageARPreview{
-        left:unset;
-        right:0px;
-    }
-    #txt2img_skip, #img2img_skip{
-        right: unset;
-        left: 0px;
-    }
-    #context-menu{
-        box-shadow:-1px 1px 2px #CE6400;
-    }
-    .gr-box > div > div > input.gr-text-input{
-        right: unset;
-        left: 0.5em;
-    }
+.extra-networks > div > [id *= '_extra_']{
+    margin: 0.3em;
 }
+
+.extra-network-cards .nocards{
+    margin: 1.25em 0.5em 0.5em 0.5em;
+}
+
+.extra-network-cards .nocards h1{
+    font-size: 1.5em;
+    margin-bottom: 1em;
+}
+
+.extra-network-cards .nocards li{
+    margin-left: 0.5em;
+}
+
+.extra-network-cards .card{
+    display: inline-block;
+    margin: 0.5em;
+    width: 16em;
+    height: 24em;
+    box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
+    border-radius: 0.2em;
+    position: relative;
+
+    background-size: auto 100%;
+    background-position: center;
+    overflow: hidden;
+    cursor: pointer;
+
+    background-image: url('./file=html/card-no-preview.png')
+}
+
+.extra-network-cards .card:hover{
+    box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);
+}
+
+.extra-network-cards .card .actions .additional{
+    display: none;
+}
+
+.extra-network-cards .card .actions{
+    position: absolute;
+    bottom: 0;
+    left: 0;
+    right: 0;
+    padding: 0.5em;
+    color: white;
+    background: rgba(0,0,0,0.5);
+    box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);
+    text-shadow: 0 0 0.2em black;
+}
+
+.extra-network-cards .card .actions:hover{
+    box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
+}
+
+.extra-network-cards .card .actions .name{
+    font-size: 1.7em;
+    font-weight: bold;
+    line-break: anywhere;
+}
+
+.extra-network-cards .card .actions:hover .additional{
+    display: block;
+}
+
+.extra-network-cards .card ul{
+    margin: 0.25em 0 0.75em 0.25em;
+    cursor: unset;
+}
+
+.extra-network-cards .card ul a{
+    cursor: pointer;
+}
+
+.extra-network-cards .card ul a:hover{
+    color: red;
+}
+

+ 22 - 4
webui.py

@@ -9,16 +9,18 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
-from modules import import_hook, errors
+from modules import import_hook, errors, extra_networks
+from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
 import torch
+
 # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
 if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
 
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
 import modules.codeformer_model as codeformer
 import modules.extras
 import modules.face_restoration
@@ -84,10 +86,17 @@ def initialize():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
-    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
 
+    shared.reload_hypernetworks()
+
+    ui_extra_networks.intialize()
+    ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+    ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+    extra_networks.initialize()
+    extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
     if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
 
         try:
@@ -209,6 +218,15 @@ def webui():
 
         modules.sd_models.list_models()
 
+        shared.reload_hypernetworks()
+
+        ui_extra_networks.intialize()
+        ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
+        ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+
+        extra_networks.initialize()
+        extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
+
 
 if __name__ == "__main__":
     if cmd_opts.nowebui: